Skip to content

Commit 3277b0d

Browse files
committed
Handle STOMP messages to user destination in order
Closes gh-31395
1 parent 9eb39e1 commit 3277b0d

File tree

9 files changed

+218
-24
lines changed

9 files changed

+218
-24
lines changed

framework-docs/modules/ROOT/pages/web/websocket/stomp/ordered-messages.adoc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ written to WebSocket sessions. As the channel is backed by a `ThreadPoolExecutor
66
are processed in different threads, and the resulting sequence received by the client may
77
not match the exact order of publication.
88

9-
If this is an issue, enable the `setPreservePublishOrder` flag, as the following example shows:
9+
To enable ordered publishing, set the `setPreservePublishOrder` flag as follows:
1010

1111
[source,java,indent=0,subs="verbatim,quotes"]
1212
----
@@ -47,5 +47,22 @@ When the flag is set, messages within the same client session are published to t
4747
`clientOutboundChannel` one at a time, so that the order of publication is guaranteed.
4848
Note that this incurs a small performance overhead, so you should enable it only if it is required.
4949

50+
The same also applies to messages from the client, which are sent to the `clientInboundChannel`,
51+
from where they are handled according to their destination prefix. As the channel is backed by
52+
a `ThreadPoolExecutor`, messages are processed in different threads, and the resulting sequence
53+
of handling may not match the exact order in which they were received.
5054

55+
To enable ordered publishing, set the `setPreserveReceiveOrder` flag as follows:
5156

57+
[source,java,indent=0,subs="verbatim,quotes"]
58+
----
59+
@Configuration
60+
@EnableWebSocketMessageBroker
61+
public class MyConfig implements WebSocketMessageBrokerConfigurer {
62+
63+
@Override
64+
public void registerStompEndpoints(StompEndpointRegistry registry) {
65+
registry.setPreserveReceiveOrder(true);
66+
}
67+
}
68+
----

spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ protected void doSend(String destination, Message<?> message) {
157157
if (simpAccessor.isMutable()) {
158158
simpAccessor.setDestination(destination);
159159
simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
160-
simpAccessor.setImmutable();
160+
// ImmutableMessageChannelInterceptor will make it immutable
161161
sendInternal(message);
162162
return;
163163
}

spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ else if (channel instanceof ExecutorSubscribableChannel execChannel) {
159159
}
160160
}
161161

162+
/**
163+
* Whether the channel has been {@link #configureInterceptor configured}
164+
* with an interceptor for sequential handling.
165+
* @since 6.1
166+
*/
167+
public static boolean supportsOrderedMessages(MessageChannel channel) {
168+
return (channel instanceof ExecutorSubscribableChannel ch &&
169+
ch.getInterceptors().stream().anyMatch(CallbackTaskInterceptor.class::isInstance));
170+
}
171+
162172
/**
163173
* Obtain the task to release the next message, if found.
164174
*/

spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,17 @@ public UserDestinationResult resolveDestination(Message<?> message) {
131131
}
132132
String user = parseResult.getUser();
133133
String sourceDest = parseResult.getSourceDestination();
134+
Set<String> sessionIds = parseResult.getSessionIds();
134135
Set<String> targetSet = new HashSet<>();
135-
for (String sessionId : parseResult.getSessionIds()) {
136+
for (String sessionId : sessionIds) {
136137
String actualDest = parseResult.getActualDestination();
137138
String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user);
138139
if (targetDest != null) {
139140
targetSet.add(targetDest);
140141
}
141142
}
142143
String subscribeDest = parseResult.getSubscribeDestination();
143-
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user);
144+
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user, sessionIds);
144145
}
145146

146147
@Nullable

spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,13 +17,18 @@
1717
package org.springframework.messaging.simp.user;
1818

1919
import java.util.Arrays;
20+
import java.util.Iterator;
2021
import java.util.List;
22+
import java.util.Map;
23+
import java.util.Set;
24+
import java.util.concurrent.ConcurrentHashMap;
2125

2226
import org.apache.commons.logging.Log;
2327

2428
import org.springframework.context.SmartLifecycle;
2529
import org.springframework.lang.Nullable;
2630
import org.springframework.messaging.Message;
31+
import org.springframework.messaging.MessageChannel;
2732
import org.springframework.messaging.MessageHandler;
2833
import org.springframework.messaging.MessageHeaders;
2934
import org.springframework.messaging.MessagingException;
@@ -33,6 +38,7 @@
3338
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
3439
import org.springframework.messaging.simp.SimpMessageType;
3540
import org.springframework.messaging.simp.SimpMessagingTemplate;
41+
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
3642
import org.springframework.messaging.support.MessageBuilder;
3743
import org.springframework.messaging.support.MessageHeaderAccessor;
3844
import org.springframework.messaging.support.MessageHeaderInitializer;
@@ -61,7 +67,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
6167

6268
private final UserDestinationResolver destinationResolver;
6369

64-
private final MessageSendingOperations<String> messagingTemplate;
70+
private final SendHelper sendHelper;
6571

6672
@Nullable
6773
private BroadcastHandler broadcastHandler;
@@ -91,7 +97,7 @@ public UserDestinationMessageHandler(
9197

9298
this.clientInboundChannel = clientInboundChannel;
9399
this.brokerChannel = brokerChannel;
94-
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
100+
this.sendHelper = new SendHelper(clientInboundChannel, brokerChannel);
95101
this.destinationResolver = destinationResolver;
96102
}
97103

@@ -112,7 +118,7 @@ public UserDestinationResolver getUserDestinationResolver() {
112118
*/
113119
public void setBroadcastDestination(@Nullable String destination) {
114120
this.broadcastHandler = (StringUtils.hasText(destination) ?
115-
new BroadcastHandler(this.messagingTemplate, destination) : null);
121+
new BroadcastHandler(this.sendHelper.getMessagingTemplate(), destination) : null);
116122
}
117123

118124
/**
@@ -128,7 +134,7 @@ public String getBroadcastDestination() {
128134
* broker channel.
129135
*/
130136
public MessageSendingOperations<String> getBrokerMessagingTemplate() {
131-
return this.messagingTemplate;
137+
return this.sendHelper.getMessagingTemplate();
132138
}
133139

134140
/**
@@ -193,6 +199,7 @@ public void handleMessage(Message<?> sourceMessage) throws MessagingException {
193199

194200
UserDestinationResult result = this.destinationResolver.resolveDestination(message);
195201
if (result == null) {
202+
this.sendHelper.checkDisconnect(message);
196203
return;
197204
}
198205

@@ -215,9 +222,8 @@ public void handleMessage(Message<?> sourceMessage) throws MessagingException {
215222
if (logger.isTraceEnabled()) {
216223
logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations());
217224
}
218-
for (String target : result.getTargetDestinations()) {
219-
this.messagingTemplate.send(target, message);
220-
}
225+
226+
this.sendHelper.send(result, message);
221227
}
222228

223229
private void initHeaders(SimpMessageHeaderAccessor headerAccessor) {
@@ -232,6 +238,63 @@ public String toString() {
232238
}
233239

234240

241+
private static class SendHelper {
242+
243+
private final MessageChannel brokerChannel;
244+
245+
private final MessageSendingOperations<String> messagingTemplate;
246+
247+
@Nullable
248+
private final Map<String, MessageSendingOperations<String>> orderedMessagingTemplates;
249+
250+
SendHelper(MessageChannel clientInboundChannel, MessageChannel brokerChannel) {
251+
this.brokerChannel = brokerChannel;
252+
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
253+
if (OrderedMessageChannelDecorator.supportsOrderedMessages(clientInboundChannel)) {
254+
this.orderedMessagingTemplates = new ConcurrentHashMap<>();
255+
OrderedMessageChannelDecorator.configureInterceptor(brokerChannel, true);
256+
}
257+
else {
258+
this.orderedMessagingTemplates = null;
259+
}
260+
}
261+
262+
public MessageSendingOperations<String> getMessagingTemplate() {
263+
return this.messagingTemplate;
264+
}
265+
266+
public void send(UserDestinationResult destinationResult, Message<?> message) throws MessagingException {
267+
Set<String> sessionIds = destinationResult.getSessionIds();
268+
Iterator<String> itr = (sessionIds != null ? sessionIds.iterator() : null);
269+
270+
for (String target : destinationResult.getTargetDestinations()) {
271+
String sessionId = (itr != null ? itr.next() : null);
272+
getTemplateToUse(sessionId).send(target, message);
273+
}
274+
}
275+
276+
private MessageSendingOperations<String> getTemplateToUse(@Nullable String sessionId) {
277+
if (this.orderedMessagingTemplates != null && sessionId != null) {
278+
return this.orderedMessagingTemplates.computeIfAbsent(sessionId, id ->
279+
new SimpMessagingTemplate(new OrderedMessageChannelDecorator(this.brokerChannel, logger)));
280+
}
281+
return this.messagingTemplate;
282+
}
283+
284+
public void checkDisconnect(Message<?> message) {
285+
if (this.orderedMessagingTemplates != null) {
286+
MessageHeaders headers = message.getHeaders();
287+
if (SimpMessageHeaderAccessor.getMessageType(headers) == SimpMessageType.DISCONNECT) {
288+
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
289+
if (sessionId != null) {
290+
this.orderedMessagingTemplates.remove(sessionId);
291+
}
292+
}
293+
}
294+
}
295+
}
296+
297+
235298
/**
236299
* A handler that broadcasts locally unresolved messages to the broker and
237300
* also handles similar broadcasts received from the broker.

spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.messaging.simp.user;
1818

19+
import java.util.Collections;
1920
import java.util.Set;
2021

2122
import org.springframework.lang.Nullable;
@@ -40,10 +41,23 @@ public class UserDestinationResult {
4041
@Nullable
4142
private final String user;
4243

44+
private final Set<String> sessionIds;
45+
4346

4447
public UserDestinationResult(String sourceDestination, Set<String> targetDestinations,
4548
String subscribeDestination, @Nullable String user) {
4649

50+
this(sourceDestination, targetDestinations, subscribeDestination, user, null);
51+
}
52+
53+
/**
54+
* Additional constructor with the session id for each targetDestination.
55+
* @since 6.1
56+
*/
57+
public UserDestinationResult(
58+
String sourceDestination, Set<String> targetDestinations,
59+
String subscribeDestination, @Nullable String user, @Nullable Set<String> sessionIds) {
60+
4761
Assert.notNull(sourceDestination, "'sourceDestination' must not be null");
4862
Assert.notNull(targetDestinations, "'targetDestinations' must not be null");
4963
Assert.notNull(subscribeDestination, "'subscribeDestination' must not be null");
@@ -52,6 +66,7 @@ public UserDestinationResult(String sourceDestination, Set<String> targetDestina
5266
this.targetDestinations = targetDestinations;
5367
this.subscribeDestination = subscribeDestination;
5468
this.user = user;
69+
this.sessionIds = (sessionIds != null ? sessionIds : Collections.emptySet());
5570
}
5671

5772

@@ -96,6 +111,13 @@ public String getUser() {
96111
return this.user;
97112
}
98113

114+
/**
115+
* Return the session id for the targetDestination.
116+
*/
117+
@Nullable
118+
public Set<String> getSessionIds() {
119+
return this.sessionIds;
120+
}
99121

100122
@Override
101123
public String toString() {

spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessagingTemplateTests.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ public void convertAndSendWithMutableSimpMessageHeaders() {
158158
Message<byte[]> message = messages.get(0);
159159

160160
assertThat(message.getHeaders()).isSameAs(headers);
161-
assertThat(accessor.isMutable()).isFalse();
162161
}
163162

164163
@Test
@@ -190,7 +189,6 @@ public void doSendWithMutableHeaders() {
190189
Message<byte[]> sentMessage = messages.get(0);
191190

192191
assertThat(sentMessage).isSameAs(message);
193-
assertThat(accessor.isMutable()).isFalse();
194192
}
195193

196194
@Test

spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.concurrent.CompletableFuture;
2525
import java.util.stream.Stream;
2626

27+
import jakarta.servlet.Filter;
2728
import org.apache.commons.logging.Log;
2829
import org.apache.commons.logging.LogFactory;
2930
import org.junit.jupiter.api.AfterEach;
@@ -35,6 +36,7 @@
3536
import org.springframework.context.Lifecycle;
3637
import org.springframework.context.annotation.Bean;
3738
import org.springframework.context.annotation.Configuration;
39+
import org.springframework.lang.Nullable;
3840
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
3941
import org.springframework.web.socket.client.WebSocketClient;
4042
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
@@ -85,11 +87,18 @@ static Stream<Arguments> argumentsFactory() {
8587
protected AnnotationConfigWebApplicationContext wac;
8688

8789

88-
protected void setup(WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
90+
protected void setup(WebSocketTestServer server, WebSocketClient client, TestInfo info) throws Exception {
91+
setup(server, null, client, info);
92+
}
93+
94+
protected void setup(
95+
WebSocketTestServer server, @Nullable Filter filter, WebSocketClient client, TestInfo info)
96+
throws Exception {
97+
8998
this.server = server;
90-
this.webSocketClient = webSocketClient;
99+
this.webSocketClient = client;
91100

92-
logger.debug("Setting up '" + testInfo.getTestMethod().get().getName() + "', client=" +
101+
logger.debug("Setting up '" + info.getTestMethod().get().getName() + "', client=" +
93102
this.webSocketClient.getClass().getSimpleName() + ", server=" +
94103
this.server.getClass().getSimpleName());
95104

@@ -102,7 +111,12 @@ protected void setup(WebSocketTestServer server, WebSocketClient webSocketClient
102111
}
103112

104113
this.server.setup();
105-
this.server.deployConfig(this.wac);
114+
if (filter != null) {
115+
this.server.deployConfig(this.wac, filter);
116+
}
117+
else {
118+
this.server.deployConfig(this.wac);
119+
}
106120
this.server.start();
107121

108122
this.wac.setServletContext(this.server.getServletContext());

0 commit comments

Comments
 (0)