diff --git a/cirq-core/cirq/devices/grid_qubit.py b/cirq-core/cirq/devices/grid_qubit.py index e91565ff8c5..6344a88cfff 100644 --- a/cirq-core/cirq/devices/grid_qubit.py +++ b/cirq-core/cirq/devices/grid_qubit.py @@ -19,7 +19,7 @@ import numpy as np -from cirq import _compat, ops, protocols +from cirq import ops, protocols if TYPE_CHECKING: import cirq @@ -29,9 +29,43 @@ class _BaseGridQid(ops.Qid): """The Base class for `GridQid` and `GridQubit`.""" - def __init__(self, row: int, col: int): - self._row = row - self._col = col + _row: int + _col: int + _dimension: int + _hash: Optional[int] = None + + def __getstate__(self): + # Don't save hash when pickling; see #3777. + state = self.__dict__ + if "_hash" in state: + state = state.copy() + del state["_hash"] + return state + + def __hash__(self) -> int: + if self._hash is None: + self._hash = hash((self._row, self._col, self._dimension)) + return self._hash + + def __eq__(self, other): + # Explicitly implemented for performance (vs delegating to Qid). + if isinstance(other, _BaseGridQid): + return ( + self._row == other._row + and self._col == other._col + and self._dimension == other._dimension + ) + return NotImplemented + + def __ne__(self, other): + # Explicitly implemented for performance (vs delegating to Qid). + if isinstance(other, _BaseGridQid): + return ( + self._row != other._row + or self._col != other._col + or self._dimension != other._dimension + ) + return NotImplemented def _comparison_key(self): return self._row, self._col @@ -44,6 +78,10 @@ def row(self) -> int: def col(self) -> int: return self._col + @property + def dimension(self) -> int: + return self._dimension + def with_dimension(self, dimension: int) -> 'GridQid': return GridQid(self._row, self._col, dimension=dimension) @@ -149,13 +187,10 @@ def __init__(self, row: int, col: int, *, dimension: int) -> None: dimension: The dimension of the qid's Hilbert space, i.e. the number of quantum levels. """ - super().__init__(row, col) - self._dimension = dimension self.validate_dimension(dimension) - - @property - def dimension(self): - return self._dimension + self._row = row + self._col = col + self._dimension = dimension def _with_row_col(self, row: int, col: int) -> 'GridQid': return GridQid(row, col, dimension=self.dimension) @@ -288,35 +323,11 @@ class GridQubit(_BaseGridQid): cirq.GridQubit(5, 4) """ - def __getstate__(self): - # Don't save hash when pickling; see #3777. - state = self.__dict__ - hash_key = _compat._method_cache_name(self.__hash__) - if hash_key in state: - state = state.copy() - del state[hash_key] - return state - - @_compat.cached_method - def __hash__(self) -> int: - # Explicitly cached for performance (vs delegating to Qid). - return super().__hash__() + _dimension = 2 - def __eq__(self, other): - # Explicitly implemented for performance (vs delegating to Qid). - if isinstance(other, GridQubit): - return self._row == other._row and self._col == other._col - return NotImplemented - - def __ne__(self, other): - # Explicitly implemented for performance (vs delegating to Qid). - if isinstance(other, GridQubit): - return self._row != other._row or self._col != other._col - return NotImplemented - - @property - def dimension(self) -> int: - return 2 + def __init__(self, row: int, col: int) -> None: + self._row = row + self._col = col def _with_row_col(self, row: int, col: int): return GridQubit(row, col) diff --git a/cirq-core/cirq/devices/grid_qubit_test.py b/cirq-core/cirq/devices/grid_qubit_test.py index cce810c3930..2f642806ddd 100644 --- a/cirq-core/cirq/devices/grid_qubit_test.py +++ b/cirq-core/cirq/devices/grid_qubit_test.py @@ -19,7 +19,6 @@ import pytest import cirq -from cirq import _compat def test_init(): @@ -45,8 +44,7 @@ def test_pickled_hash(): q = cirq.GridQubit(3, 4) q_bad = cirq.GridQubit(3, 4) _ = hash(q_bad) # compute hash to ensure it is cached. - hash_key = _compat._method_cache_name(cirq.GridQubit.__hash__) - setattr(q_bad, hash_key, getattr(q_bad, hash_key) + 1) + q_bad._hash = q_bad._hash + 1 assert q_bad == q assert hash(q_bad) != hash(q) data = pickle.dumps(q_bad) diff --git a/cirq-core/cirq/devices/line_qubit.py b/cirq-core/cirq/devices/line_qubit.py index 2937558a9ef..2f9bf6a6bca 100644 --- a/cirq-core/cirq/devices/line_qubit.py +++ b/cirq-core/cirq/devices/line_qubit.py @@ -27,19 +27,48 @@ class _BaseLineQid(ops.Qid): """The base class for `LineQid` and `LineQubit`.""" - def __init__(self, x: int) -> None: - """Initializes a line qubit at the given x coordinate.""" - self._x = x + _x: int + _dimension: int + _hash: Optional[int] = None + + def __getstate__(self): + # Don't save hash when pickling; see #3777. + state = self.__dict__ + if "_hash" in state: + state = state.copy() + del state["_hash"] + return state + + def __hash__(self) -> int: + if self._hash is None: + self._hash = hash((self._x, self._dimension)) + return self._hash + + def __eq__(self, other): + # Explicitly implemented for performance (vs delegating to Qid). + if isinstance(other, _BaseLineQid): + return self._x == other._x and self._dimension == other._dimension + return NotImplemented + + def __ne__(self, other): + # Explicitly implemented for performance (vs delegating to Qid). + if isinstance(other, _BaseLineQid): + return self._x != other._x or self._dimension != other._dimension + return NotImplemented def _comparison_key(self): - return self.x + return self._x @property def x(self) -> int: return self._x + @property + def dimension(self) -> int: + return self._dimension + def with_dimension(self, dimension: int) -> 'LineQid': - return LineQid(self.x, dimension) + return LineQid(self._x, dimension) def is_adjacent(self, other: 'cirq.Qid') -> bool: """Determines if two qubits are adjacent line qubits. @@ -49,7 +78,7 @@ def is_adjacent(self, other: 'cirq.Qid') -> bool: Returns: True iff other and self are adjacent. """ - return isinstance(other, _BaseLineQid) and abs(self.x - other.x) == 1 + return isinstance(other, _BaseLineQid) and abs(self._x - other._x) == 1 def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQid']: """Returns qubits that are potential neighbors to this LineQubit @@ -57,11 +86,7 @@ def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQ Args: qids: optional Iterable of qubits to constrain neighbors to. """ - neighbors = set() - for q in [self - 1, self + 1]: - if qids is None or q in qids: - neighbors.add(q) - return neighbors + return {q for q in [self - 1, self + 1] if qids is None or q in qids} @abc.abstractmethod def _with_x(self, x: int) -> Self: @@ -69,29 +94,29 @@ def _with_x(self, x: int) -> Self: def __add__(self, other: Union[int, Self]) -> Self: if isinstance(other, _BaseLineQid): - if self.dimension != other.dimension: + if self._dimension != other._dimension: raise TypeError( "Can only add LineQids with identical dimension. " - f"Got {self.dimension} and {other.dimension}" + f"Got {self._dimension} and {other._dimension}" ) - return self._with_x(x=self.x + other.x) + return self._with_x(x=self._x + other._x) if not isinstance(other, int): raise TypeError(f"Can only add ints and {type(self).__name__}. Instead was {other}") - return self._with_x(self.x + other) + return self._with_x(self._x + other) def __sub__(self, other: Union[int, Self]) -> Self: if isinstance(other, _BaseLineQid): - if self.dimension != other.dimension: + if self._dimension != other._dimension: raise TypeError( "Can only subtract LineQids with identical dimension. " - f"Got {self.dimension} and {other.dimension}" + f"Got {self._dimension} and {other._dimension}" ) - return self._with_x(x=self.x - other.x) + return self._with_x(x=self._x - other._x) if not isinstance(other, int): raise TypeError( f"Can only subtract ints and {type(self).__name__}. Instead was {other}" ) - return self._with_x(self.x - other) + return self._with_x(self._x - other) def __radd__(self, other: int) -> Self: return self + other @@ -100,16 +125,16 @@ def __rsub__(self, other: int) -> Self: return -self + other def __neg__(self) -> Self: - return self._with_x(-self.x) + return self._with_x(-self._x) def __complex__(self) -> complex: - return complex(self.x) + return complex(self._x) def __float__(self) -> float: - return float(self.x) + return float(self._x) def __int__(self) -> int: - return int(self.x) + return int(self._x) class LineQid(_BaseLineQid): @@ -137,16 +162,12 @@ def __init__(self, x: int, dimension: int) -> None: dimension: The dimension of the qid's Hilbert space, i.e. the number of quantum levels. """ - super().__init__(x) - self._dimension = dimension self.validate_dimension(dimension) - - @property - def dimension(self): - return self._dimension + self._x = x + self._dimension = dimension def _with_x(self, x: int) -> 'LineQid': - return LineQid(x, dimension=self.dimension) + return LineQid(x, dimension=self._dimension) @staticmethod def range(*range_args, dimension: int) -> List['LineQid']: @@ -192,15 +213,15 @@ def for_gate(val: Any, start: int = 0, step: int = 1) -> List['LineQid']: return LineQid.for_qid_shape(qid_shape(val), start=start, step=step) def __repr__(self) -> str: - return f"cirq.LineQid({self.x}, dimension={self.dimension})" + return f"cirq.LineQid({self._x}, dimension={self._dimension})" def __str__(self) -> str: - return f"q({self.x}) (d={self.dimension})" + return f"q({self._x}) (d={self._dimension})" def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' ) -> 'cirq.CircuitDiagramInfo': - return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x} (d={self.dimension})",)) + return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x} (d={self._dimension})",)) def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['x', 'dimension']) @@ -223,9 +244,15 @@ class LineQubit(_BaseLineQid): """ - @property - def dimension(self) -> int: - return 2 + _dimension = 2 + + def __init__(self, x: int) -> None: + """Initializes a line qubit at the given x coordinate. + + Args: + x: The x coordinate. + """ + self._x = x def _with_x(self, x: int) -> 'LineQubit': return LineQubit(x) @@ -234,7 +261,7 @@ def _cmp_tuple(self): cls = LineQid if type(self) is LineQubit else type(self) # Must be the same as Qid._cmp_tuple but with cls in place of # type(self). - return (cls.__name__, repr(cls), self._comparison_key(), self.dimension) + return (cls.__name__, repr(cls), self._comparison_key(), self._dimension) @staticmethod def range(*range_args) -> List['LineQubit']: @@ -249,15 +276,15 @@ def range(*range_args) -> List['LineQubit']: return [LineQubit(i) for i in range(*range_args)] def __repr__(self) -> str: - return f"cirq.LineQubit({self.x})" + return f"cirq.LineQubit({self._x})" def __str__(self) -> str: - return f"q({self.x})" + return f"q({self._x})" def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' ) -> 'cirq.CircuitDiagramInfo': - return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x}",)) + return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x}",)) def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['x']) diff --git a/cirq-core/cirq/ops/named_qubit.py b/cirq-core/cirq/ops/named_qubit.py index 76bf2391fca..6024c91dcb7 100644 --- a/cirq-core/cirq/ops/named_qubit.py +++ b/cirq-core/cirq/ops/named_qubit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING from cirq import protocols from cirq.ops import raw_types @@ -26,17 +26,52 @@ class _BaseNamedQid(raw_types.Qid): """The base class for `NamedQid` and `NamedQubit`.""" - def __init__(self, name: str) -> None: - self._name = name - self._comp_key = _pad_digits(name) + _name: str + _dimension: int + _comp_key: Optional[str] = None + _hash: Optional[int] = None + + def __getstate__(self): + # Don't save hash when pickling; see #3777. + state = self.__dict__ + if "_hash" in state or "_comp_key" in state: + state = state.copy() + if "_hash" in state: + del state["_hash"] + if "_comp_key" in state: + del state["_comp_key"] + return state + + def __hash__(self) -> int: + if self._hash is None: + self._hash = hash((self._name, self._dimension)) + return self._hash + + def __eq__(self, other): + # Explicitly implemented for performance (vs delegating to Qid). + if isinstance(other, _BaseNamedQid): + return self._name == other._name and self._dimension == other._dimension + return NotImplemented + + def __ne__(self, other): + # Explicitly implemented for performance (vs delegating to Qid). + if isinstance(other, _BaseNamedQid): + return self._name != other._name or self._dimension != other._dimension + return NotImplemented def _comparison_key(self): + if self._comp_key is None: + self._comp_key = _pad_digits(self._name) return self._comp_key @property def name(self) -> str: return self._name + @property + def dimension(self) -> int: + return self._dimension + def with_dimension(self, dimension: int) -> 'NamedQid': return NamedQid(self._name, dimension=dimension) @@ -59,19 +94,15 @@ def __init__(self, name: str, dimension: int) -> None: dimension: The dimension of the qid's Hilbert space, i.e. the number of quantum levels. """ - super().__init__(name) - self._dimension = dimension self.validate_dimension(dimension) - - @property - def dimension(self) -> int: - return self._dimension + self._name = name + self._dimension = dimension def __repr__(self) -> str: - return f'cirq.NamedQid({self.name!r}, dimension={self.dimension})' + return f'cirq.NamedQid({self._name!r}, dimension={self._dimension})' def __str__(self) -> str: - return f'{self.name} (d={self.dimension})' + return f'{self._name} (d={self._dimension})' @staticmethod def range(*args, prefix: str, dimension: int) -> List['NamedQid']: @@ -95,7 +126,7 @@ def range(*args, prefix: str, dimension: int) -> List['NamedQid']: Returns: A list of ``NamedQid``\\s. """ - return [NamedQid(prefix + str(i), dimension=dimension) for i in range(*args)] + return [NamedQid(f"{prefix}{i}", dimension=dimension) for i in range(*args)] def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['name', 'dimension']) @@ -110,14 +141,20 @@ class NamedQubit(_BaseNamedQid): wire for 'qubit3' will correctly come before 'qubit22'. """ - @property - def dimension(self) -> int: - return 2 + _dimension = 2 + + def __init__(self, name: str) -> None: + """Initializes a `NamedQubit` with a given name. + + Args: + name: The name. + """ + self._name = name def _cmp_tuple(self): cls = NamedQid if type(self) is NamedQubit else type(self) # Must be same as Qid._cmp_tuple but with cls in place of type(self). - return (cls.__name__, repr(cls), self._comparison_key(), self.dimension) + return (cls.__name__, repr(cls), self._comparison_key(), self._dimension) def __str__(self) -> str: return self._name @@ -146,7 +183,7 @@ def range(*args, prefix: str) -> List['NamedQubit']: Returns: A list of ``NamedQubit``\\s. """ - return [NamedQubit(prefix + str(i)) for i in range(*args)] + return [NamedQubit(f"{prefix}{i}") for i in range(*args)] def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['name'])