@@ -14,13 +14,25 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.ChangeTracking.Internal;
1414/// </summary>
1515public sealed class StringDictionaryComparer < TDictionary , TElement > : ValueComparer < object > , IInfrastructure < ValueComparer >
1616{
17+ private static readonly bool UseOldBehavior35239 =
18+ AppContext . TryGetSwitch ( "Microsoft.EntityFrameworkCore.Issue35239" , out var enabled35239 ) && enabled35239 ;
19+
1720 private static readonly MethodInfo CompareMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
21+ nameof ( Compare ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( object ) , typeof ( Func < TElement , TElement , bool > ) ] ) ! ;
22+
23+ private static readonly MethodInfo LegacyCompareMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
1824 nameof ( Compare ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( object ) , typeof ( ValueComparer ) ] ) ! ;
1925
2026 private static readonly MethodInfo GetHashCodeMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
27+ nameof ( GetHashCode ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( IEnumerable ) , typeof ( Func < TElement , int > ) ] ) ! ;
28+
29+ private static readonly MethodInfo LegacyGetHashCodeMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
2130 nameof ( GetHashCode ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( IEnumerable ) , typeof ( ValueComparer ) ] ) ! ;
2231
2332 private static readonly MethodInfo SnapshotMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
33+ nameof ( Snapshot ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( Func < TElement , TElement > ) ] ) ! ;
34+
35+ private static readonly MethodInfo LegacySnapshotMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
2436 nameof ( Snapshot ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( ValueComparer ) ] ) ! ;
2537
2638 /// <summary>
@@ -52,14 +64,56 @@ ValueComparer IInfrastructure<ValueComparer>.Instance
5264 var prm1 = Expression . Parameter ( typeof ( object ) , "a" ) ;
5365 var prm2 = Expression . Parameter ( typeof ( object ) , "b" ) ;
5466
67+ if ( UseOldBehavior35239 )
68+ {
69+ // (a, b) => Compare(a, b, new Comparer(...))
70+ return Expression . Lambda < Func < object ? , object ? , bool > > (
71+ Expression . Call (
72+ LegacyCompareMethod ,
73+ prm1 ,
74+ prm2 ,
75+ #pragma warning disable EF9100
76+ elementComparer . ConstructorExpression ) ,
77+ #pragma warning restore EF9100
78+ prm1 ,
79+ prm2 ) ;
80+ }
81+
82+ // we check the compatibility between element type we expect on the Equals methods
83+ // vs what we actually get from the element comparer
84+ // if the expected is assignable from actual we can just do simple call...
85+ if ( typeof ( TElement ) . IsAssignableFrom ( elementComparer . Type ) )
86+ {
87+ // (a, b) => Compare(a, b, elementComparer.Equals)
88+ return Expression . Lambda < Func < object ? , object ? , bool > > (
89+ Expression . Call (
90+ CompareMethod ,
91+ prm1 ,
92+ prm2 ,
93+ elementComparer . EqualsExpression ) ,
94+ prm1 ,
95+ prm2 ) ;
96+ }
97+
98+ // ...otherwise we need to rewrite the actual lambda (as we can't change the expected signature)
99+ // in that case we are rewriting the inner lambda parameters to TElement and cast to the element comparer
100+ // type argument in the body, so that semantics of the element comparison func don't change
101+ var newInnerPrm1 = Expression . Parameter ( typeof ( TElement ) , "a" ) ;
102+ var newInnerPrm2 = Expression . Parameter ( typeof ( TElement ) , "b" ) ;
103+
104+ var newEqualsExpressionBody = elementComparer . ExtractEqualsBody (
105+ Expression . Convert ( newInnerPrm1 , elementComparer . Type ) ,
106+ Expression . Convert ( newInnerPrm2 , elementComparer . Type ) ) ;
107+
55108 return Expression . Lambda < Func < object ? , object ? , bool > > (
56109 Expression . Call (
57110 CompareMethod ,
58111 prm1 ,
59112 prm2 ,
60- #pragma warning disable EF9100
61- elementComparer . ConstructorExpression ) ,
62- #pragma warning restore EF9100
113+ Expression . Lambda (
114+ newEqualsExpressionBody ,
115+ newInnerPrm1 ,
116+ newInnerPrm2 ) ) ,
63117 prm1 ,
64118 prm2 ) ;
65119 }
@@ -68,32 +122,144 @@ private static Expression<Func<object, int>> GetHashCodeLambda(ValueComparer ele
68122 {
69123 var prm = Expression . Parameter ( typeof ( object ) , "o" ) ;
70124
125+ if ( UseOldBehavior35239 )
126+ {
127+ // o => GetHashCode((IEnumerable)o, new Comparer(...))
128+ return Expression . Lambda < Func < object , int > > (
129+ Expression . Call (
130+ LegacyGetHashCodeMethod ,
131+ Expression . Convert (
132+ prm ,
133+ typeof ( IEnumerable ) ) ,
134+ #pragma warning disable EF9100
135+ elementComparer . ConstructorExpression ) ,
136+ #pragma warning restore EF9100
137+ prm ) ;
138+ }
139+
140+ if ( typeof ( TElement ) . IsAssignableFrom ( elementComparer . Type ) )
141+ {
142+ // o => GetHashCode((IEnumerable)o, elementComparer.GetHashCode)
143+ return Expression . Lambda < Func < object , int > > (
144+ Expression . Call (
145+ GetHashCodeMethod ,
146+ Expression . Convert (
147+ prm ,
148+ typeof ( IEnumerable ) ) ,
149+ elementComparer . HashCodeExpression ) ,
150+ prm ) ;
151+ }
152+
153+ var newInnerPrm = Expression . Parameter ( typeof ( TElement ) , "o" ) ;
154+
155+ var newInnerBody = elementComparer . ExtractHashCodeBody (
156+ Expression . Convert (
157+ newInnerPrm ,
158+ elementComparer . Type ) ) ;
159+
71160 return Expression . Lambda < Func < object , int > > (
72161 Expression . Call (
73162 GetHashCodeMethod ,
74163 Expression . Convert (
75164 prm ,
76165 typeof ( IEnumerable ) ) ,
77- #pragma warning disable EF9100
78- elementComparer . ConstructorExpression ) ,
79- #pragma warning restore EF9100
166+ Expression . Lambda (
167+ newInnerBody ,
168+ newInnerPrm ) ) ,
80169 prm ) ;
81170 }
82171
83172 private static Expression < Func < object , object > > SnapshotLambda ( ValueComparer elementComparer )
84173 {
85174 var prm = Expression . Parameter ( typeof ( object ) , "source" ) ;
86175
176+ if ( UseOldBehavior35239 )
177+ {
178+ // source => Snapshot(source, new Comparer(..))
179+ return Expression . Lambda < Func < object , object > > (
180+ Expression . Call (
181+ LegacySnapshotMethod ,
182+ prm ,
183+ #pragma warning disable EF9100
184+ elementComparer . ConstructorExpression ) ,
185+ #pragma warning restore EF9100
186+ prm ) ;
187+ }
188+
189+ // TElement is both argument and return type so the types need to be the same
190+ if ( typeof ( TElement ) == elementComparer . Type )
191+ {
192+ // source => Snapshot(source, elementComparer.Snapshot)
193+ return Expression . Lambda < Func < object , object > > (
194+ Expression . Call (
195+ SnapshotMethod ,
196+ prm ,
197+ elementComparer . SnapshotExpression ) ,
198+ prm ) ;
199+ }
200+
201+ var newInnerPrm = Expression . Parameter ( typeof ( TElement ) , "source" ) ;
202+
203+ var newInnerBody = elementComparer . ExtractSnapshotBody (
204+ Expression . Convert (
205+ newInnerPrm ,
206+ elementComparer . Type ) ) ;
207+
208+ // note we need to also convert the result of inner lambda back to TElement
87209 return Expression . Lambda < Func < object , object > > (
88210 Expression . Call (
89211 SnapshotMethod ,
90212 prm ,
91- #pragma warning disable EF9100
92- elementComparer . ConstructorExpression ) ,
93- #pragma warning restore EF9100
213+ Expression . Lambda (
214+ Expression . Convert (
215+ newInnerBody ,
216+ typeof ( TElement ) ) ,
217+ newInnerPrm ) ) ,
94218 prm ) ;
95219 }
96220
221+ private static bool Compare ( object ? a , object ? b , Func < TElement ? , TElement ? , bool > elementCompare )
222+ {
223+ if ( ReferenceEquals ( a , b ) )
224+ {
225+ return true ;
226+ }
227+
228+ if ( a is null )
229+ {
230+ return b is null ;
231+ }
232+
233+ if ( b is null )
234+ {
235+ return false ;
236+ }
237+
238+ if ( a is IReadOnlyDictionary < string , TElement ? > aDictionary && b is IReadOnlyDictionary < string , TElement ? > bDictionary )
239+ {
240+ if ( aDictionary . Count != bDictionary . Count )
241+ {
242+ return false ;
243+ }
244+
245+ foreach ( var pair in aDictionary )
246+ {
247+ if ( ! bDictionary . TryGetValue ( pair . Key , out var bValue )
248+ || ! elementCompare ( pair . Value , bValue ) )
249+ {
250+ return false ;
251+ }
252+ }
253+
254+ return true ;
255+ }
256+
257+ throw new InvalidOperationException (
258+ CosmosStrings . BadDictionaryType (
259+ ( a is IDictionary < string , TElement ? > ? b : a ) . GetType ( ) . ShortDisplayName ( ) ,
260+ typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
261+ }
262+
97263 private static bool Compare ( object ? a , object ? b , ValueComparer elementComparer )
98264 {
99265 if ( ReferenceEquals ( a , b ) )
@@ -136,6 +302,27 @@ private static bool Compare(object? a, object? b, ValueComparer elementComparer)
136302 typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , elementComparer . Type ) . ShortDisplayName ( ) ) ) ;
137303 }
138304
305+ private static int GetHashCode ( IEnumerable source , Func < TElement ? , int > elementGetHashCode )
306+ {
307+ if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
308+ {
309+ throw new InvalidOperationException (
310+ CosmosStrings . BadDictionaryType (
311+ source . GetType ( ) . ShortDisplayName ( ) ,
312+ typeof ( IList < > ) . MakeGenericType ( typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
313+ }
314+
315+ var hash = new HashCode ( ) ;
316+
317+ foreach ( var pair in sourceDictionary )
318+ {
319+ hash . Add ( pair . Key ) ;
320+ hash . Add ( pair . Value == null ? 0 : elementGetHashCode ( pair . Value ) ) ;
321+ }
322+
323+ return hash . ToHashCode ( ) ;
324+ }
325+
139326 private static int GetHashCode ( IEnumerable source , ValueComparer elementComparer )
140327 {
141328 if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
@@ -157,6 +344,25 @@ private static int GetHashCode(IEnumerable source, ValueComparer elementComparer
157344 return hash . ToHashCode ( ) ;
158345 }
159346
347+ private static IReadOnlyDictionary < string , TElement ? > Snapshot ( object source , Func < TElement ? , TElement ? > elementSnapshot )
348+ {
349+ if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
350+ {
351+ throw new InvalidOperationException (
352+ CosmosStrings . BadDictionaryType (
353+ source . GetType ( ) . ShortDisplayName ( ) ,
354+ typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
355+ }
356+
357+ var snapshot = new Dictionary < string , TElement ? > ( ) ;
358+ foreach ( var pair in sourceDictionary )
359+ {
360+ snapshot [ pair . Key ] = pair . Value == null ? default : ( TElement ? ) elementSnapshot ( pair . Value ) ;
361+ }
362+
363+ return snapshot ;
364+ }
365+
160366 private static IReadOnlyDictionary < string , TElement ? > Snapshot ( object source , ValueComparer elementComparer )
161367 {
162368 if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
0 commit comments