@@ -99,9 +99,18 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
9999 if ( ! diagnostic . Properties . TryGetValue ( Constants . Properties . Replacement , out var replacement ) || replacement is null )
100100 return ;
101101
102+ // We want to remove the 'await' in front of 'Assert.CollectionAsync' since 'Assert.Single' doesn't need it
103+ var nodeToReplace = invocation . Parent ;
104+ if ( assertMethodName == Constants . Asserts . CollectionAsync && nodeToReplace is AwaitExpressionSyntax )
105+ nodeToReplace = nodeToReplace . Parent ;
106+
107+ // Can't replace something that's not a standlone expression
108+ if ( nodeToReplace is not ExpressionStatementSyntax )
109+ return ;
110+
102111 context . RegisterCodeFix (
103112 XunitCodeAction . Create (
104- ct => UseSingleMethod ( context . Document , invocation , assertMethodName , replacement , ct ) ,
113+ ct => UseSingleMethod ( context . Document , invocation , nodeToReplace , assertMethodName , replacement , ct ) ,
105114 Key_UseSingleMethod ,
106115 "Use Assert.{0}" , replacement
107116 ) ,
@@ -112,6 +121,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
112121 static async Task < Document > UseSingleMethod (
113122 Document document ,
114123 InvocationExpressionSyntax invocation ,
124+ SyntaxNode nodeToReplace ,
115125 string assertMethodName ,
116126 string replacementMethod ,
117127 CancellationToken cancellationToken )
@@ -131,10 +141,6 @@ static async Task<Document> UseSingleMethod(
131141 . WithArgumentList ( ArgumentList ( SeparatedList ( [ Argument ( invocation . ArgumentList . Arguments [ 0 ] . Expression ) ] ) ) )
132142 . WithExpression ( memberAccess . WithName ( IdentifierName ( replacementMethod ) ) ) ;
133143
134- // We want to replace the whole expression, because it may include an unnecessary await, as we may be
135- // converting from Assert.CollectionAsync (which needs await) to Assert.Single (which does not).
136- var nodeToReplace = invocation . FirstAncestorOrSelf < ExpressionStatementSyntax > ( ) ?? invocation . Parent ;
137-
138144 if ( invocation . ArgumentList . Arguments [ 1 ] . Expression is SimpleLambdaExpressionSyntax lambdaExpression )
139145 {
140146 var originalParameterName = lambdaExpression . Parameter . Identifier . Text ;
0 commit comments