@@ -14,13 +14,25 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.ChangeTracking.Internal;
1414/// </summary>
1515public sealed class StringDictionaryComparer < TDictionary , TElement > : ValueComparer < object > , IInfrastructure < ValueComparer >
1616{
17+ private static readonly bool UseOldBehavior35239 =
18+ AppContext . TryGetSwitch ( "Microsoft.EntityFrameworkCore.Issue35239" , out var enabled35239 ) && enabled35239 ;
19+
1720 private static readonly MethodInfo CompareMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
21+ nameof ( Compare ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( object ) , typeof ( Func < TElement , TElement , bool > ) ] ) ! ;
22+
23+ private static readonly MethodInfo LegacyCompareMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
1824 nameof ( Compare ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( object ) , typeof ( ValueComparer ) ] ) ! ;
1925
2026 private static readonly MethodInfo GetHashCodeMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
27+ nameof ( GetHashCode ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( IEnumerable ) , typeof ( Func < TElement , int > ) ] ) ! ;
28+
29+ private static readonly MethodInfo LegacyGetHashCodeMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
2130 nameof ( GetHashCode ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( IEnumerable ) , typeof ( ValueComparer ) ] ) ! ;
2231
2332 private static readonly MethodInfo SnapshotMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
33+ nameof ( Snapshot ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( Func < TElement , TElement > ) ] ) ! ;
34+
35+ private static readonly MethodInfo LegacySnapshotMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
2436 nameof ( Snapshot ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( ValueComparer ) ] ) ! ;
2537
2638 /// <summary>
@@ -52,9 +64,23 @@ ValueComparer IInfrastructure<ValueComparer>.Instance
5264 var prm1 = Expression . Parameter ( typeof ( object ) , "a" ) ;
5365 var prm2 = Expression . Parameter ( typeof ( object ) , "b" ) ;
5466
67+ if ( elementComparer is ValueComparer < TElement > && ! UseOldBehavior35239 )
68+ {
69+ // (a, b) => Compare(a, b, elementComparer.Equals)
70+ return Expression . Lambda < Func < object ? , object ? , bool > > (
71+ Expression . Call (
72+ CompareMethod ,
73+ prm1 ,
74+ prm2 ,
75+ elementComparer . EqualsExpression ) ,
76+ prm1 ,
77+ prm2 ) ;
78+ }
79+
80+ // (a, b) => Compare(a, b, new Comparer(...))
5581 return Expression . Lambda < Func < object ? , object ? , bool > > (
5682 Expression . Call (
57- CompareMethod ,
83+ LegacyCompareMethod ,
5884 prm1 ,
5985 prm2 ,
6086#pragma warning disable EF9100
@@ -68,9 +94,23 @@ private static Expression<Func<object, int>> GetHashCodeLambda(ValueComparer ele
6894 {
6995 var prm = Expression . Parameter ( typeof ( object ) , "o" ) ;
7096
97+ if ( elementComparer is ValueComparer < TElement > && ! UseOldBehavior35239 )
98+ {
99+ // o => GetHashCode((IEnumerable)o, elementComparer.GetHashCode)
100+ return Expression . Lambda < Func < object , int > > (
101+ Expression . Call (
102+ GetHashCodeMethod ,
103+ Expression . Convert (
104+ prm ,
105+ typeof ( IEnumerable ) ) ,
106+ elementComparer . HashCodeExpression ) ,
107+ prm ) ;
108+ }
109+
110+ // o => GetHashCode((IEnumerable)o, new Comparer(...))
71111 return Expression . Lambda < Func < object , int > > (
72112 Expression . Call (
73- GetHashCodeMethod ,
113+ LegacyGetHashCodeMethod ,
74114 Expression . Convert (
75115 prm ,
76116 typeof ( IEnumerable ) ) ,
@@ -84,16 +124,70 @@ private static Expression<Func<object, object>> SnapshotLambda(ValueComparer ele
84124 {
85125 var prm = Expression . Parameter ( typeof ( object ) , "source" ) ;
86126
127+ if ( elementComparer is ValueComparer < TElement > && ! UseOldBehavior35239 )
128+ {
129+ // source => Snapshot(source, elementComparer.Snapshot)
130+ return Expression . Lambda < Func < object , object > > (
131+ Expression . Call (
132+ SnapshotMethod ,
133+ prm ,
134+ elementComparer . SnapshotExpression ) ,
135+ prm ) ;
136+ }
137+
138+ // source => Snapshot(source, new Comparer(..))
87139 return Expression . Lambda < Func < object , object > > (
88140 Expression . Call (
89- SnapshotMethod ,
141+ LegacySnapshotMethod ,
90142 prm ,
91143#pragma warning disable EF9100
92144 elementComparer . ConstructorExpression ) ,
93145#pragma warning restore EF9100
94146 prm ) ;
95147 }
96148
149+ private static bool Compare ( object ? a , object ? b , Func < TElement ? , TElement ? , bool > elementCompare )
150+ {
151+ if ( ReferenceEquals ( a , b ) )
152+ {
153+ return true ;
154+ }
155+
156+ if ( a is null )
157+ {
158+ return b is null ;
159+ }
160+
161+ if ( b is null )
162+ {
163+ return false ;
164+ }
165+
166+ if ( a is IReadOnlyDictionary < string , TElement ? > aDictionary && b is IReadOnlyDictionary < string , TElement ? > bDictionary )
167+ {
168+ if ( aDictionary . Count != bDictionary . Count )
169+ {
170+ return false ;
171+ }
172+
173+ foreach ( var pair in aDictionary )
174+ {
175+ if ( ! bDictionary . TryGetValue ( pair . Key , out var bValue )
176+ || ! elementCompare ( pair . Value , bValue ) )
177+ {
178+ return false ;
179+ }
180+ }
181+
182+ return true ;
183+ }
184+
185+ throw new InvalidOperationException (
186+ CosmosStrings . BadDictionaryType (
187+ ( a is IDictionary < string , TElement ? > ? b : a ) . GetType ( ) . ShortDisplayName ( ) ,
188+ typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
189+ }
190+
97191 private static bool Compare ( object ? a , object ? b , ValueComparer elementComparer )
98192 {
99193 if ( ReferenceEquals ( a , b ) )
@@ -136,6 +230,27 @@ private static bool Compare(object? a, object? b, ValueComparer elementComparer)
136230 typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , elementComparer . Type ) . ShortDisplayName ( ) ) ) ;
137231 }
138232
233+ private static int GetHashCode ( IEnumerable source , Func < TElement ? , int > elementGetHashCode )
234+ {
235+ if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
236+ {
237+ throw new InvalidOperationException (
238+ CosmosStrings . BadDictionaryType (
239+ source . GetType ( ) . ShortDisplayName ( ) ,
240+ typeof ( IList < > ) . MakeGenericType ( typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
241+ }
242+
243+ var hash = new HashCode ( ) ;
244+
245+ foreach ( var pair in sourceDictionary )
246+ {
247+ hash . Add ( pair . Key ) ;
248+ hash . Add ( pair . Value == null ? 0 : elementGetHashCode ( pair . Value ) ) ;
249+ }
250+
251+ return hash . ToHashCode ( ) ;
252+ }
253+
139254 private static int GetHashCode ( IEnumerable source , ValueComparer elementComparer )
140255 {
141256 if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
@@ -157,6 +272,25 @@ private static int GetHashCode(IEnumerable source, ValueComparer elementComparer
157272 return hash . ToHashCode ( ) ;
158273 }
159274
275+ private static IReadOnlyDictionary < string , TElement ? > Snapshot ( object source , Func < TElement ? , TElement ? > elementSnapshot )
276+ {
277+ if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
278+ {
279+ throw new InvalidOperationException (
280+ CosmosStrings . BadDictionaryType (
281+ source . GetType ( ) . ShortDisplayName ( ) ,
282+ typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
283+ }
284+
285+ var snapshot = new Dictionary < string , TElement ? > ( ) ;
286+ foreach ( var pair in sourceDictionary )
287+ {
288+ snapshot [ pair . Key ] = pair . Value == null ? default : ( TElement ? ) elementSnapshot ( pair . Value ) ;
289+ }
290+
291+ return snapshot ;
292+ }
293+
160294 private static IReadOnlyDictionary < string , TElement ? > Snapshot ( object source , ValueComparer elementComparer )
161295 {
162296 if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
0 commit comments