1414
1515import abc
1616import functools
17+ import weakref
1718from typing import Any , Dict , Iterable , List , Optional , Tuple , Set , TYPE_CHECKING , Union
1819from typing_extensions import Self
1920
@@ -34,14 +35,6 @@ class _BaseGridQid(ops.Qid):
3435 _dimension : int
3536 _hash : Optional [int ] = None
3637
37- def __getstate__ (self ):
38- # Don't save hash when pickling; see #3777.
39- state = self .__dict__
40- if "_hash" in state :
41- state = state .copy ()
42- del state ["_hash" ]
43- return state
44-
4538 def __hash__ (self ) -> int :
4639 if self ._hash is None :
4740 self ._hash = hash ((self ._row , self ._col , self ._dimension ))
@@ -50,7 +43,7 @@ def __hash__(self) -> int:
5043 def __eq__ (self , other ):
5144 # Explicitly implemented for performance (vs delegating to Qid).
5245 if isinstance (other , _BaseGridQid ):
53- return (
46+ return self is other or (
5447 self ._row == other ._row
5548 and self ._col == other ._col
5649 and self ._dimension == other ._dimension
@@ -60,7 +53,7 @@ def __eq__(self, other):
6053 def __ne__ (self , other ):
6154 # Explicitly implemented for performance (vs delegating to Qid).
6255 if isinstance (other , _BaseGridQid ):
63- return (
56+ return self is not other and (
6457 self ._row != other ._row
6558 or self ._col != other ._col
6659 or self ._dimension != other ._dimension
@@ -178,22 +171,36 @@ class GridQid(_BaseGridQid):
178171 cirq.GridQid(5, 4, dimension=2)
179172 """
180173
181- def __init__ (self , row : int , col : int , * , dimension : int ) -> None :
182- """Initializes a grid qid at the given row, col coordinate
174+ # Cache of existing GridQid instances, returned by __new__ if available.
175+ # Holds weak references so instances can still be garbage collected.
176+ _cache = weakref .WeakValueDictionary [Tuple [int , int , int ], 'cirq.GridQid' ]()
177+
178+ def __new__ (cls , row : int , col : int , * , dimension : int ) -> 'cirq.GridQid' :
179+ """Creates a grid qid at the given row, col coordinate
183180
184181 Args:
185182 row: the row coordinate
186183 col: the column coordinate
187184 dimension: The dimension of the qid's Hilbert space, i.e.
188185 the number of quantum levels.
189186 """
190- self .validate_dimension (dimension )
191- self ._row = row
192- self ._col = col
193- self ._dimension = dimension
187+ key = (row , col , dimension )
188+ inst = cls ._cache .get (key )
189+ if inst is None :
190+ cls .validate_dimension (dimension )
191+ inst = super ().__new__ (cls )
192+ inst ._row = row
193+ inst ._col = col
194+ inst ._dimension = dimension
195+ cls ._cache [key ] = inst
196+ return inst
197+
198+ def __getnewargs_ex__ (self ):
199+ """Returns a tuple of (args, kwargs) to pass to __new__ when unpickling."""
200+ return (self ._row , self ._col ), {"dimension" : self ._dimension }
194201
195202 def _with_row_col (self , row : int , col : int ) -> 'GridQid' :
196- return GridQid (row , col , dimension = self .dimension )
203+ return GridQid (row , col , dimension = self ._dimension )
197204
198205 @staticmethod
199206 def square (diameter : int , top : int = 0 , left : int = 0 , * , dimension : int ) -> List ['GridQid' ]:
@@ -290,16 +297,16 @@ def from_diagram(diagram: str, dimension: int) -> List['GridQid']:
290297 return [GridQid (* c , dimension = dimension ) for c in coords ]
291298
292299 def __repr__ (self ) -> str :
293- return f"cirq.GridQid({ self ._row } , { self ._col } , dimension={ self .dimension } )"
300+ return f"cirq.GridQid({ self ._row } , { self ._col } , dimension={ self ._dimension } )"
294301
295302 def __str__ (self ) -> str :
296- return f"q({ self ._row } , { self ._col } ) (d={ self .dimension } )"
303+ return f"q({ self ._row } , { self ._col } ) (d={ self ._dimension } )"
297304
298305 def _circuit_diagram_info_ (
299306 self , args : 'cirq.CircuitDiagramInfoArgs'
300307 ) -> 'cirq.CircuitDiagramInfo' :
301308 return protocols .CircuitDiagramInfo (
302- wire_symbols = (f"({ self ._row } , { self ._col } ) (d={ self .dimension } )" ,)
309+ wire_symbols = (f"({ self ._row } , { self ._col } ) (d={ self ._dimension } )" ,)
303310 )
304311
305312 def _json_dict_ (self ) -> Dict [str , Any ]:
@@ -325,11 +332,31 @@ class GridQubit(_BaseGridQid):
325332
326333 _dimension = 2
327334
328- def __init__ ( self , row : int , col : int ) -> None :
329- self . _row = row
330- self . _col = col
335+ # Cache of existing GridQubit instances, returned by __new__ if available.
336+ # Holds weak references so instances can still be garbage collected.
337+ _cache = weakref . WeakValueDictionary [ Tuple [ int , int ], 'cirq.GridQubit' ]()
331338
332- def _with_row_col (self , row : int , col : int ):
339+ def __new__ (cls , row : int , col : int ) -> 'cirq.GridQubit' :
340+ """Creates a grid qubit at the given row, col coordinate
341+
342+ Args:
343+ row: the row coordinate
344+ col: the column coordinate
345+ """
346+ key = (row , col )
347+ inst = cls ._cache .get (key )
348+ if inst is None :
349+ inst = super ().__new__ (cls )
350+ inst ._row = row
351+ inst ._col = col
352+ cls ._cache [key ] = inst
353+ return inst
354+
355+ def __getnewargs__ (self ):
356+ """Returns a tuple of args to pass to __new__ when unpickling."""
357+ return (self ._row , self ._col )
358+
359+ def _with_row_col (self , row : int , col : int ) -> 'GridQubit' :
333360 return GridQubit (row , col )
334361
335362 def _cmp_tuple (self ):
0 commit comments