Skip to content

Commit 33c2573

Browse files
authored
Cache Qid instances for common types (#6371)
Review: @dstrain115
1 parent 6d437c4 commit 33c2573

File tree

6 files changed

+195
-67
lines changed

6 files changed

+195
-67
lines changed

cirq-core/cirq/devices/grid_qubit.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import abc
1616
import functools
17+
import weakref
1718
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TYPE_CHECKING, Union
1819
from typing_extensions import Self
1920

@@ -34,14 +35,6 @@ class _BaseGridQid(ops.Qid):
3435
_dimension: int
3536
_hash: Optional[int] = None
3637

37-
def __getstate__(self):
38-
# Don't save hash when pickling; see #3777.
39-
state = self.__dict__
40-
if "_hash" in state:
41-
state = state.copy()
42-
del state["_hash"]
43-
return state
44-
4538
def __hash__(self) -> int:
4639
if self._hash is None:
4740
self._hash = hash((self._row, self._col, self._dimension))
@@ -50,7 +43,7 @@ def __hash__(self) -> int:
5043
def __eq__(self, other):
5144
# Explicitly implemented for performance (vs delegating to Qid).
5245
if isinstance(other, _BaseGridQid):
53-
return (
46+
return self is other or (
5447
self._row == other._row
5548
and self._col == other._col
5649
and self._dimension == other._dimension
@@ -60,7 +53,7 @@ def __eq__(self, other):
6053
def __ne__(self, other):
6154
# Explicitly implemented for performance (vs delegating to Qid).
6255
if isinstance(other, _BaseGridQid):
63-
return (
56+
return self is not other and (
6457
self._row != other._row
6558
or self._col != other._col
6659
or self._dimension != other._dimension
@@ -178,22 +171,36 @@ class GridQid(_BaseGridQid):
178171
cirq.GridQid(5, 4, dimension=2)
179172
"""
180173

181-
def __init__(self, row: int, col: int, *, dimension: int) -> None:
182-
"""Initializes a grid qid at the given row, col coordinate
174+
# Cache of existing GridQid instances, returned by __new__ if available.
175+
# Holds weak references so instances can still be garbage collected.
176+
_cache = weakref.WeakValueDictionary[Tuple[int, int, int], 'cirq.GridQid']()
177+
178+
def __new__(cls, row: int, col: int, *, dimension: int) -> 'cirq.GridQid':
179+
"""Creates a grid qid at the given row, col coordinate
183180
184181
Args:
185182
row: the row coordinate
186183
col: the column coordinate
187184
dimension: The dimension of the qid's Hilbert space, i.e.
188185
the number of quantum levels.
189186
"""
190-
self.validate_dimension(dimension)
191-
self._row = row
192-
self._col = col
193-
self._dimension = dimension
187+
key = (row, col, dimension)
188+
inst = cls._cache.get(key)
189+
if inst is None:
190+
cls.validate_dimension(dimension)
191+
inst = super().__new__(cls)
192+
inst._row = row
193+
inst._col = col
194+
inst._dimension = dimension
195+
cls._cache[key] = inst
196+
return inst
197+
198+
def __getnewargs_ex__(self):
199+
"""Returns a tuple of (args, kwargs) to pass to __new__ when unpickling."""
200+
return (self._row, self._col), {"dimension": self._dimension}
194201

195202
def _with_row_col(self, row: int, col: int) -> 'GridQid':
196-
return GridQid(row, col, dimension=self.dimension)
203+
return GridQid(row, col, dimension=self._dimension)
197204

198205
@staticmethod
199206
def square(diameter: int, top: int = 0, left: int = 0, *, dimension: int) -> List['GridQid']:
@@ -290,16 +297,16 @@ def from_diagram(diagram: str, dimension: int) -> List['GridQid']:
290297
return [GridQid(*c, dimension=dimension) for c in coords]
291298

292299
def __repr__(self) -> str:
293-
return f"cirq.GridQid({self._row}, {self._col}, dimension={self.dimension})"
300+
return f"cirq.GridQid({self._row}, {self._col}, dimension={self._dimension})"
294301

295302
def __str__(self) -> str:
296-
return f"q({self._row}, {self._col}) (d={self.dimension})"
303+
return f"q({self._row}, {self._col}) (d={self._dimension})"
297304

298305
def _circuit_diagram_info_(
299306
self, args: 'cirq.CircuitDiagramInfoArgs'
300307
) -> 'cirq.CircuitDiagramInfo':
301308
return protocols.CircuitDiagramInfo(
302-
wire_symbols=(f"({self._row}, {self._col}) (d={self.dimension})",)
309+
wire_symbols=(f"({self._row}, {self._col}) (d={self._dimension})",)
303310
)
304311

305312
def _json_dict_(self) -> Dict[str, Any]:
@@ -325,11 +332,31 @@ class GridQubit(_BaseGridQid):
325332

326333
_dimension = 2
327334

328-
def __init__(self, row: int, col: int) -> None:
329-
self._row = row
330-
self._col = col
335+
# Cache of existing GridQubit instances, returned by __new__ if available.
336+
# Holds weak references so instances can still be garbage collected.
337+
_cache = weakref.WeakValueDictionary[Tuple[int, int], 'cirq.GridQubit']()
331338

332-
def _with_row_col(self, row: int, col: int):
339+
def __new__(cls, row: int, col: int) -> 'cirq.GridQubit':
340+
"""Creates a grid qubit at the given row, col coordinate
341+
342+
Args:
343+
row: the row coordinate
344+
col: the column coordinate
345+
"""
346+
key = (row, col)
347+
inst = cls._cache.get(key)
348+
if inst is None:
349+
inst = super().__new__(cls)
350+
inst._row = row
351+
inst._col = col
352+
cls._cache[key] = inst
353+
return inst
354+
355+
def __getnewargs__(self):
356+
"""Returns a tuple of args to pass to __new__ when unpickling."""
357+
return (self._row, self._col)
358+
359+
def _with_row_col(self, row: int, col: int) -> 'GridQubit':
333360
return GridQubit(row, col)
334361

335362
def _cmp_tuple(self):

cirq-core/cirq/devices/grid_qubit_test.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,29 @@ def test_eq():
4040
eq.make_equality_group(lambda: cirq.GridQid(0, 0, dimension=3))
4141

4242

43-
def test_pickled_hash():
44-
q = cirq.GridQubit(3, 4)
45-
q_bad = cirq.GridQubit(3, 4)
43+
def test_grid_qubit_pickled_hash():
44+
# Use a large number that is unlikely to be used by any other tests.
45+
row, col = 123456789, 2345678910
46+
q_bad = cirq.GridQubit(row, col)
47+
cirq.GridQubit._cache.pop((row, col))
48+
q = cirq.GridQubit(row, col)
49+
_test_qid_pickled_hash(q, q_bad)
50+
51+
52+
def test_grid_qid_pickled_hash():
53+
# Use a large number that is unlikely to be used by any other tests.
54+
row, col = 123456789, 2345678910
55+
q_bad = cirq.GridQid(row, col, dimension=3)
56+
cirq.GridQid._cache.pop((row, col, 3))
57+
q = cirq.GridQid(row, col, dimension=3)
58+
_test_qid_pickled_hash(q, q_bad)
59+
60+
61+
def _test_qid_pickled_hash(q: 'cirq.Qid', q_bad: 'cirq.Qid') -> None:
62+
"""Test that hashes are not pickled with Qid instances."""
63+
assert q_bad is not q
4664
_ = hash(q_bad) # compute hash to ensure it is cached.
47-
q_bad._hash = q_bad._hash + 1
65+
q_bad._hash = q_bad._hash + 1 # type: ignore[attr-defined]
4866
assert q_bad == q
4967
assert hash(q_bad) != hash(q)
5068
data = pickle.dumps(q_bad)

cirq-core/cirq/devices/line_qubit.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
import abc
1616
import functools
17-
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TYPE_CHECKING, Union
17+
import weakref
18+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
1819
from typing_extensions import Self
1920

2021
from cirq import ops, protocols
@@ -31,14 +32,6 @@ class _BaseLineQid(ops.Qid):
3132
_dimension: int
3233
_hash: Optional[int] = None
3334

34-
def __getstate__(self):
35-
# Don't save hash when pickling; see #3777.
36-
state = self.__dict__
37-
if "_hash" in state:
38-
state = state.copy()
39-
del state["_hash"]
40-
return state
41-
4235
def __hash__(self) -> int:
4336
if self._hash is None:
4437
self._hash = hash((self._x, self._dimension))
@@ -47,13 +40,15 @@ def __hash__(self) -> int:
4740
def __eq__(self, other):
4841
# Explicitly implemented for performance (vs delegating to Qid).
4942
if isinstance(other, _BaseLineQid):
50-
return self._x == other._x and self._dimension == other._dimension
43+
return self is other or (self._x == other._x and self._dimension == other._dimension)
5144
return NotImplemented
5245

5346
def __ne__(self, other):
5447
# Explicitly implemented for performance (vs delegating to Qid).
5548
if isinstance(other, _BaseLineQid):
56-
return self._x != other._x or self._dimension != other._dimension
49+
return self is not other and (
50+
self._x != other._x or self._dimension != other._dimension
51+
)
5752
return NotImplemented
5853

5954
def _comparison_key(self):
@@ -154,17 +149,31 @@ class LineQid(_BaseLineQid):
154149
155150
"""
156151

157-
def __init__(self, x: int, dimension: int) -> None:
152+
# Cache of existing LineQid instances, returned by __new__ if available.
153+
# Holds weak references so instances can still be garbage collected.
154+
_cache = weakref.WeakValueDictionary[Tuple[int, int], 'cirq.LineQid']()
155+
156+
def __new__(cls, x: int, dimension: int) -> 'cirq.LineQid':
158157
"""Initializes a line qid at the given x coordinate.
159158
160159
Args:
161160
x: The x coordinate.
162161
dimension: The dimension of the qid's Hilbert space, i.e.
163162
the number of quantum levels.
164163
"""
165-
self.validate_dimension(dimension)
166-
self._x = x
167-
self._dimension = dimension
164+
key = (x, dimension)
165+
inst = cls._cache.get(key)
166+
if inst is None:
167+
cls.validate_dimension(dimension)
168+
inst = super().__new__(cls)
169+
inst._x = x
170+
inst._dimension = dimension
171+
cls._cache[key] = inst
172+
return inst
173+
174+
def __getnewargs__(self):
175+
"""Returns a tuple of args to pass to __new__ when unpickling."""
176+
return (self._x, self._dimension)
168177

169178
def _with_x(self, x: int) -> 'LineQid':
170179
return LineQid(x, dimension=self._dimension)
@@ -246,13 +255,26 @@ class LineQubit(_BaseLineQid):
246255

247256
_dimension = 2
248257

249-
def __init__(self, x: int) -> None:
250-
"""Initializes a line qubit at the given x coordinate.
258+
# Cache of existing LineQubit instances, returned by __new__ if available.
259+
# Holds weak references so instances can still be garbage collected.
260+
_cache = weakref.WeakValueDictionary[int, 'cirq.LineQubit']()
261+
262+
def __new__(cls, x: int) -> 'cirq.LineQubit':
263+
"""Initializes a line qid at the given x coordinate.
251264
252265
Args:
253266
x: The x coordinate.
254267
"""
255-
self._x = x
268+
inst = cls._cache.get(x)
269+
if inst is None:
270+
inst = super().__new__(cls)
271+
inst._x = x
272+
cls._cache[x] = inst
273+
return inst
274+
275+
def __getnewargs__(self):
276+
"""Returns a tuple of args to pass to __new__ when unpickling."""
277+
return (self._x,)
256278

257279
def _with_x(self, x: int) -> 'LineQubit':
258280
return LineQubit(x)

cirq-core/cirq/devices/line_qubit_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616

1717
import cirq
18+
from cirq.devices.grid_qubit_test import _test_qid_pickled_hash
1819

1920

2021
def test_init():
@@ -67,6 +68,24 @@ def test_cmp_failure():
6768
_ = cirq.LineQid(1, 3) < 0
6869

6970

71+
def test_line_qubit_pickled_hash():
72+
# Use a large number that is unlikely to be used by any other tests.
73+
x = 1234567891011
74+
q_bad = cirq.LineQubit(x)
75+
cirq.LineQubit._cache.pop(x)
76+
q = cirq.LineQubit(x)
77+
_test_qid_pickled_hash(q, q_bad)
78+
79+
80+
def test_line_qid_pickled_hash():
81+
# Use a large number that is unlikely to be used by any other tests.
82+
x = 1234567891011
83+
q_bad = cirq.LineQid(x, dimension=3)
84+
cirq.LineQid._cache.pop((x, 3))
85+
q = cirq.LineQid(x, dimension=3)
86+
_test_qid_pickled_hash(q, q_bad)
87+
88+
7089
def test_is_adjacent():
7190
assert cirq.LineQubit(1).is_adjacent(cirq.LineQubit(2))
7291
assert cirq.LineQubit(1).is_adjacent(cirq.LineQubit(0))

0 commit comments

Comments
 (0)