11// Licensed to the .NET Foundation under one or more agreements.
22// The .NET Foundation licenses this file to you under the MIT license.
33
4- using System ;
5- using System . Collections . Generic ;
6- using System . Threading . Tasks ;
4+ using System . Net . WebSockets ;
75using Microsoft . AspNetCore . Http . Connections ;
6+ using Microsoft . AspNetCore . Http . Connections . Client ;
7+ using Microsoft . AspNetCore . InternalTesting ;
88using Microsoft . AspNetCore . SignalR . Client ;
99using Microsoft . AspNetCore . SignalR . Protocol ;
1010using Microsoft . AspNetCore . SignalR . Tests ;
11- using Microsoft . AspNetCore . InternalTesting ;
1211using Microsoft . Extensions . DependencyInjection ;
1312using Microsoft . Extensions . Logging ;
14- using Xunit ;
1513
1614namespace Microsoft . AspNetCore . SignalR . StackExchangeRedis . Tests ;
1715
@@ -213,7 +211,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
213211 }
214212 }
215213
216- private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null )
214+ [ ConditionalTheory ]
215+ [ SkipIfDockerNotPresent ]
216+ [ InlineData ( "messagepack" ) ]
217+ [ InlineData ( "json" ) ]
218+ public async Task StatefulReconnectPreservesMessageFromOtherServer ( string protocolName )
219+ {
220+ using ( StartVerifiableLog ( ) )
221+ {
222+ var protocol = HubProtocolHelpers . GetHubProtocol ( protocolName ) ;
223+
224+ ClientWebSocket innerWs = null ;
225+ WebSocketWrapper ws = null ;
226+ TaskCompletionSource reconnectTcs = null ;
227+ TaskCompletionSource startedReconnectTcs = null ;
228+
229+ var connection = CreateConnection ( _serverFixture . FirstServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ,
230+ customizeConnection : builder =>
231+ {
232+ builder . WithStatefulReconnect ( ) ;
233+ builder . Services . Configure < HttpConnectionOptions > ( o =>
234+ {
235+ // Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
236+ // Which will trigger the stateful reconnect flow
237+ o . WebSocketFactory = async ( context , token ) =>
238+ {
239+ if ( reconnectTcs is null )
240+ {
241+ reconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
242+ startedReconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
243+ }
244+ else
245+ {
246+ startedReconnectTcs . SetResult ( ) ;
247+ // We only want to wait on the reconnect, not the initial connection attempt
248+ await reconnectTcs . Task . DefaultTimeout ( ) ;
249+ }
250+
251+ innerWs = new ClientWebSocket ( ) ;
252+ ws = new WebSocketWrapper ( innerWs ) ;
253+ await innerWs . ConnectAsync ( context . Uri , token ) ;
254+
255+ _ = Task . Run ( async ( ) =>
256+ {
257+ try
258+ {
259+ while ( innerWs . State == WebSocketState . Open )
260+ {
261+ var buffer = new byte [ 1024 ] ;
262+ var res = await innerWs . ReceiveAsync ( buffer , default ) ;
263+ ws . SetReceiveResult ( ( res , buffer . AsMemory ( 0 , res . Count ) ) ) ;
264+ }
265+ }
266+ // Log but ignore receive errors, that likely just means the connection closed
267+ catch ( Exception ex )
268+ {
269+ Logger . LogInformation ( ex , "Error while reading from inner websocket" ) ;
270+ }
271+ } ) ;
272+
273+ return ws ;
274+ } ;
275+ } ) ;
276+ } ) ;
277+ var secondConnection = CreateConnection ( _serverFixture . SecondServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ) ;
278+
279+ var tcs = new TaskCompletionSource < string > ( ) ;
280+ connection . On < string > ( "SendToAll" , message => tcs . TrySetResult ( message ) ) ;
281+
282+ var tcs2 = new TaskCompletionSource < string > ( ) ;
283+ secondConnection . On < string > ( "SendToAll" , message => tcs2 . TrySetResult ( message ) ) ;
284+
285+ await connection . StartAsync ( ) . DefaultTimeout ( ) ;
286+ await secondConnection . StartAsync ( ) . DefaultTimeout ( ) ;
287+
288+ // Close first connection before the second connection sends a message to all clients
289+ await ws . CloseOutputAsync ( WebSocketCloseStatus . InternalServerError , statusDescription : null , default ) ;
290+ await startedReconnectTcs . Task . DefaultTimeout ( ) ;
291+
292+ // Send to all clients, since both clients are on different servers this means the backplane will be used
293+ // And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
294+ // But is on a different server from the original message sender.
295+ await secondConnection . SendAsync ( "SendToAll" , "test message" ) . DefaultTimeout ( ) ;
296+
297+ // Check that second connection still receives the message
298+ Assert . Equal ( "test message" , await tcs2 . Task . DefaultTimeout ( ) ) ;
299+ Assert . False ( tcs . Task . IsCompleted ) ;
300+
301+ // allow first connection to reconnect
302+ reconnectTcs . SetResult ( ) ;
303+
304+ // Check that first connection received the message once it reconnected
305+ Assert . Equal ( "test message" , await tcs . Task . DefaultTimeout ( ) ) ;
306+
307+ await connection . DisposeAsync ( ) . DefaultTimeout ( ) ;
308+ }
309+ }
310+
311+ private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null ,
312+ Action < IHubConnectionBuilder > customizeConnection = null )
217313 {
218314 var hubConnectionBuilder = new HubConnectionBuilder ( )
219315 . WithLoggerFactory ( loggerFactory )
@@ -227,6 +323,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran
227323
228324 hubConnectionBuilder . Services . AddSingleton ( protocol ) ;
229325
326+ customizeConnection ? . Invoke ( hubConnectionBuilder ) ;
327+
230328 return hubConnectionBuilder . Build ( ) ;
231329 }
232330
@@ -255,4 +353,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
255353 }
256354 }
257355 }
356+
357+ internal sealed class WebSocketWrapper : WebSocket
358+ {
359+ private readonly WebSocket _inner ;
360+ private TaskCompletionSource < ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) > _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
361+
362+ public WebSocketWrapper ( WebSocket inner )
363+ {
364+ _inner = inner ;
365+ }
366+
367+ public override WebSocketCloseStatus ? CloseStatus => _inner . CloseStatus ;
368+
369+ public override string CloseStatusDescription => _inner . CloseStatusDescription ;
370+
371+ public override WebSocketState State => _inner . State ;
372+
373+ public override string SubProtocol => _inner . SubProtocol ;
374+
375+ public override void Abort ( )
376+ {
377+ _inner . Abort ( ) ;
378+ }
379+
380+ public override Task CloseAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
381+ {
382+ return _inner . CloseAsync ( closeStatus , statusDescription , cancellationToken ) ;
383+ }
384+
385+ public override Task CloseOutputAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
386+ {
387+ _receiveTcs . TrySetException ( new IOException ( "force reconnect" ) ) ;
388+ return Task . CompletedTask ;
389+ }
390+
391+ public override void Dispose ( )
392+ {
393+ _inner . Dispose ( ) ;
394+ }
395+
396+ public void SetReceiveResult ( ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) result )
397+ {
398+ _receiveTcs . SetResult ( result ) ;
399+ }
400+
401+ public override async Task < WebSocketReceiveResult > ReceiveAsync ( ArraySegment < byte > buffer , CancellationToken cancellationToken )
402+ {
403+ var res = await _receiveTcs . Task ;
404+ // Handle zero-byte reads
405+ if ( buffer . Count == 0 )
406+ {
407+ return res . Item1 ;
408+ }
409+ _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
410+ res . Item2 . CopyTo ( buffer ) ;
411+ return res . Item1 ;
412+ }
413+
414+ public override Task SendAsync ( ArraySegment < byte > buffer , WebSocketMessageType messageType , bool endOfMessage , CancellationToken cancellationToken )
415+ {
416+ return _inner . SendAsync ( buffer , messageType , endOfMessage , cancellationToken ) ;
417+ }
418+ }
258419}
0 commit comments