2828 Optional ,
2929 overload ,
3030 Sequence ,
31- Set ,
3231 Tuple ,
3332 Type ,
3433 Union ,
@@ -221,10 +220,22 @@ class CirqEncoder(json.JSONEncoder):
221220 See https://github.com/quantumlib/Cirq/issues/2014
222221 """
223222
223+ def __init__ (self , * args , ** kwargs ) -> None :
224+ super ().__init__ (* args , ** kwargs )
225+ self ._memo : dict [Any , dict ] = {}
226+
224227 def default (self , o ):
225228 # Object with custom method?
226229 if hasattr (o , '_json_dict_' ):
227- return _json_dict_with_cirq_type (o )
230+ json_dict = _json_dict_with_cirq_type (o )
231+ if isinstance (o , SerializableByKey ):
232+ if ref := self ._memo .get (o ):
233+ return ref
234+ key = len (self ._memo )
235+ ref = {"cirq_type" : "REF" , "key" : key }
236+ self ._memo [o ] = ref
237+ return {"cirq_type" : "VAL" , "key" : key , "val" : json_dict }
238+ return json_dict
228239
229240 # Sympy object? (Must come before general number checks.)
230241 # TODO: More support for sympy
@@ -306,27 +317,46 @@ def default(self, o):
306317 return super ().default (o ) # pragma: no cover
307318
308319
309- def _cirq_object_hook (d , resolvers : Sequence [JsonResolver ], context_map : Dict [str , Any ]):
310- if 'cirq_type' not in d :
311- return d
320+ class ObjectHook :
321+ """Callable to be used as object_hook during deserialization."""
322+
323+ LEGACY_CONTEXT_TYPES = {'_ContextualSerialization' , '_SerializedKey' , '_SerializedContext' }
324+
325+ def __init__ (self , resolvers : Sequence [JsonResolver ]) -> None :
326+ self .resolvers = resolvers
327+ self .memo : Dict [int , SerializableByKey ] = {}
328+ self .context_map : Dict [int , SerializableByKey ] = {}
312329
313- if d ['cirq_type' ] == '_SerializedKey' :
314- return _SerializedKey .read_from_context (context_map , ** d )
330+ def __call__ (self , d ):
331+ cirq_type = d .get ('cirq_type' )
332+ if cirq_type is None :
333+ return d
315334
316- if d ['cirq_type' ] == '_SerializedContext' :
317- _SerializedContext .update_context (context_map , ** d )
318- return None
335+ if cirq_type == 'VAL' :
336+ obj = d ['val' ]
337+ self .memo [d ['key' ]] = obj
338+ return obj
319339
320- if d [ ' cirq_type' ] == '_ContextualSerialization ' :
321- return _ContextualSerialization . deserialize_with_context ( ** d )
340+ if cirq_type == 'REF ' :
341+ return self . memo [ d [ 'key' ]]
322342
323- cls = factory_from_json (d ['cirq_type' ], resolvers = resolvers )
324- from_json_dict = getattr (cls , '_from_json_dict_' , None )
325- if from_json_dict is not None :
326- return from_json_dict (** d )
343+ # Deserialize from legacy "contextual serialization" format
344+ if cirq_type in self .LEGACY_CONTEXT_TYPES :
345+ if cirq_type == '_SerializedKey' :
346+ return self .context_map [d ['key' ]]
347+ if cirq_type == '_SerializedContext' :
348+ self .context_map [d ['key' ]] = d ['obj' ]
349+ return None
350+ if cirq_type == '_ContextualSerialization' :
351+ return d ['object_dag' ][- 1 ]
327352
328- del d ['cirq_type' ]
329- return cls (** d )
353+ cls = factory_from_json (cirq_type , resolvers = self .resolvers )
354+ from_json_dict = getattr (cls , '_from_json_dict_' , None )
355+ if from_json_dict is not None :
356+ return from_json_dict (** d )
357+
358+ del d ['cirq_type' ]
359+ return cls (** d )
330360
331361
332362class SerializableByKey (SupportsJSON ):
@@ -338,137 +368,6 @@ class SerializableByKey(SupportsJSON):
338368 """
339369
340370
341- class _SerializedKey (SupportsJSON ):
342- """Internal object for holding a SerializableByKey key.
343-
344- This is a private type used in contextual serialization. Its deserialization
345- is context-dependent, and is not expected to match the original; in other
346- words, `cls._from_json_dict_(obj._json_dict_())` does not return
347- the original `obj` for this type.
348- """
349-
350- def __init__ (self , key : str ):
351- self .key = key
352-
353- def _json_dict_ (self ):
354- return obj_to_dict_helper (self , ['key' ])
355-
356- @classmethod
357- def _from_json_dict_ (cls , ** kwargs ):
358- raise TypeError (f'Internal error: { cls } should never deserialize with _from_json_dict_.' )
359-
360- @classmethod
361- def read_from_context (cls , context_map , key , ** kwargs ):
362- return context_map [key ]
363-
364-
365- class _SerializedContext (SupportsJSON ):
366- """Internal object for a single SerializableByKey key-to-object mapping.
367-
368- This is a private type used in contextual serialization. Its deserialization
369- is context-dependent, and is not expected to match the original; in other
370- words, `cls._from_json_dict_(obj._json_dict_())` does not return
371- the original `obj` for this type.
372- """
373-
374- def __init__ (self , obj : SerializableByKey , uid : int ):
375- self .key = uid
376- self .obj = obj
377-
378- def _json_dict_ (self ):
379- return obj_to_dict_helper (self , ['key' , 'obj' ])
380-
381- @classmethod
382- def _from_json_dict_ (cls , ** kwargs ):
383- raise TypeError (f'Internal error: { cls } should never deserialize with _from_json_dict_.' )
384-
385- @classmethod
386- def update_context (cls , context_map , key , obj , ** kwargs ):
387- context_map .update ({key : obj })
388-
389-
390- class _ContextualSerialization (SupportsJSON ):
391- """Internal object for serializing an object with its context.
392-
393- This is a private type used in contextual serialization. Its deserialization
394- is context-dependent, and is not expected to match the original; in other
395- words, `cls._from_json_dict_(obj._json_dict_())` does not return
396- the original `obj` for this type.
397- """
398-
399- def __init__ (self , obj : Any ):
400- # Context information and the wrapped object are stored together in
401- # `object_dag` to ensure consistent serialization ordering.
402- self .object_dag = []
403- context = []
404- for sbk in get_serializable_by_keys (obj ):
405- if sbk not in context :
406- context .append (sbk )
407- new_sc = _SerializedContext (sbk , len (context ))
408- self .object_dag .append (new_sc )
409- self .object_dag += [obj ]
410-
411- def _json_dict_ (self ):
412- return obj_to_dict_helper (self , ['object_dag' ])
413-
414- @classmethod
415- def _from_json_dict_ (cls , ** kwargs ):
416- raise TypeError (f'Internal error: { cls } should never deserialize with _from_json_dict_.' )
417-
418- @classmethod
419- def deserialize_with_context (cls , object_dag , ** kwargs ):
420- # The last element of object_dag is the object to be deserialized.
421- return object_dag [- 1 ]
422-
423-
424- def has_serializable_by_keys (obj : Any ) -> bool :
425- """Returns true if obj contains one or more SerializableByKey objects."""
426- if isinstance (obj , SerializableByKey ):
427- return True
428- json_dict = getattr (obj , '_json_dict_' , lambda : None )()
429- if isinstance (json_dict , Dict ):
430- return any (has_serializable_by_keys (v ) for v in json_dict .values ())
431-
432- # Handle primitive container types.
433- if isinstance (obj , Dict ):
434- return any (has_serializable_by_keys (elem ) for pair in obj .items () for elem in pair )
435-
436- if hasattr (obj , '__iter__' ) and not isinstance (obj , str ):
437- # Return False on TypeError because some numpy values
438- # (like np.array(1)) have iterable methods
439- # yet return a TypeError when there is an attempt to iterate over them
440- try :
441- return any (has_serializable_by_keys (elem ) for elem in obj )
442- except TypeError :
443- return False
444- return False
445-
446-
447- def get_serializable_by_keys (obj : Any ) -> List [SerializableByKey ]:
448- """Returns all SerializableByKeys contained by obj.
449-
450- Objects are ordered such that nested objects appear before the object they
451- are nested inside. This is required to ensure SerializableByKeys are only
452- fully defined once in serialization.
453- """
454- result = []
455- if isinstance (obj , SerializableByKey ):
456- result .append (obj )
457- json_dict = getattr (obj , '_json_dict_' , lambda : None )()
458- if isinstance (json_dict , Dict ):
459- for v in json_dict .values ():
460- result = get_serializable_by_keys (v ) + result
461- if result :
462- return result
463-
464- # Handle primitive container types.
465- if isinstance (obj , Dict ):
466- return [sbk for pair in obj .items () for sbk in get_serializable_by_keys (pair )]
467- if hasattr (obj , '__iter__' ) and not isinstance (obj , str ):
468- return [sbk for v in obj for sbk in get_serializable_by_keys (v )]
469- return []
470-
471-
472371def json_namespace (type_obj : Type ) -> str :
473372 """Returns a namespace for JSON serialization of `type_obj`.
474373
@@ -610,37 +509,12 @@ def to_json(
610509 party classes, prefer adding the `_json_dict_` magic method
611510 to your classes rather than overriding this default.
612511 """
613- if has_serializable_by_keys (obj ):
614- obj = _ContextualSerialization (obj )
615-
616- class ContextualEncoder (cls ): # type: ignore
617- """An encoder with a context map for concise serialization."""
618-
619- # These lists populate gradually during serialization. An object
620- # with components defined in 'context' will represent those
621- # components using their keys instead of inline definition.
622- seen : Set [str ] = set ()
623-
624- def default (self , o ):
625- if not isinstance (o , SerializableByKey ):
626- return super ().default (o )
627- for candidate in obj .object_dag [:- 1 ]:
628- if candidate .obj == o :
629- if not candidate .key in ContextualEncoder .seen :
630- ContextualEncoder .seen .add (candidate .key )
631- return _json_dict_with_cirq_type (candidate .obj )
632- else :
633- return _json_dict_with_cirq_type (_SerializedKey (candidate .key ))
634- raise ValueError ("Object mutated during serialization." ) # pragma: no cover
635-
636- cls = ContextualEncoder
637-
638512 if file_or_fn is None :
639513 return json .dumps (obj , indent = indent , separators = separators , cls = cls )
640514
641515 if isinstance (file_or_fn , (str , pathlib .Path )):
642516 with open (file_or_fn , 'w' ) as actually_a_file :
643- json .dump (obj , actually_a_file , indent = indent , cls = cls )
517+ json .dump (obj , actually_a_file , indent = indent , separators = separators , cls = cls )
644518 return None
645519
646520 json .dump (obj , file_or_fn , indent = indent , separators = separators , cls = cls )
@@ -682,10 +556,7 @@ def read_json(
682556 if resolvers is None :
683557 resolvers = DEFAULT_RESOLVERS
684558
685- context_map : Dict [str , 'SerializableByKey' ] = {}
686-
687- def obj_hook (x ):
688- return _cirq_object_hook (x , resolvers , context_map )
559+ obj_hook = ObjectHook (resolvers )
689560
690561 if json_text is not None :
691562 return json .loads (json_text , object_hook = obj_hook )
0 commit comments