Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 49 additions & 38 deletions cirq-core/cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

from cirq import _compat, ops, protocols
from cirq import ops, protocols

if TYPE_CHECKING:
import cirq
Expand All @@ -29,9 +29,43 @@
class _BaseGridQid(ops.Qid):
"""The Base class for `GridQid` and `GridQubit`."""

def __init__(self, row: int, col: int):
self._row = row
self._col = col
_row: int
_col: int
_dimension: int
_hash: Optional[int] = None

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
if "_hash" in state:
state = state.copy()
del state["_hash"]
return state

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self._row, self._col, self._dimension))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just compute this in the constructor? It should be immutable, right? That would remove the need for an if statement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, but since construction is more common than hashing I think it would be good to keep the constructor as fast as possible. I played around a bit with adding a __new__ method so we can actually cache instances instead of allocating new ones each time the constructor is called, but that interacts with pickling in strange ways that I haven't worked out yet.

Copy link
Contributor Author

@maffoo maffoo Nov 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a comparison of timings with lazy hashing (this PR) vs eager hashing where as you suggest the hash is computed in the constructor and there's no conditional in __hash__. It does indeed make hashing faster, but slows down the constructor quite a bit:

Operation Lazy Hash [ns] Eager Hash [ns] t_eager / t_lazy
%timeit cirq.GridQubit(1, 2) 294 430 1.46x
%timeit hash(cirq.GridQubit(1, 2)) 652 589 0.90x
q = cirq.GridQubit(1, 2); %timeit hash(q) 163 135 0.83x
%timeit cirq.LineQubit(3) 254 369 1.45x
%timeit hash(cirq.LineQubit(3) 560 506 0.90x
q = cirq.LineQubit(3); %timeit hash(q) 163 144 0.88x
%timeit cirq.NamedQubit("abc") 249 361 1.45x
%timeit hash(cirq.NamedQubit("abc") 568 502 0.88x
q = cirq.NamedQubit("abc"); %timeit hash(q) 159 137 0.86x

So eager hashing speeds up the __hash__ calls by between 10-20%, but slows down the constructors by almost 50%. If we can figure out __new__ to cache qid instances, then I think eager hashing would be a clear win, but for now I'd suggest we stick with lazy.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks for the thorough analysis. This makes sense.

return self._hash

def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return (
self._row == other._row
and self._col == other._col
and self._dimension == other._dimension
)
return NotImplemented

def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return (
self._row != other._row
or self._col != other._col
or self._dimension != other._dimension
)
return NotImplemented

def _comparison_key(self):
return self._row, self._col
Expand All @@ -44,6 +78,10 @@ def row(self) -> int:
def col(self) -> int:
return self._col

@property
def dimension(self) -> int:
return self._dimension

def with_dimension(self, dimension: int) -> 'GridQid':
return GridQid(self._row, self._col, dimension=dimension)

Expand Down Expand Up @@ -149,13 +187,10 @@ def __init__(self, row: int, col: int, *, dimension: int) -> None:
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
super().__init__(row, col)
self._dimension = dimension
self.validate_dimension(dimension)

@property
def dimension(self):
return self._dimension
self._row = row
self._col = col
self._dimension = dimension

def _with_row_col(self, row: int, col: int) -> 'GridQid':
return GridQid(row, col, dimension=self.dimension)
Expand Down Expand Up @@ -288,35 +323,11 @@ class GridQubit(_BaseGridQid):
cirq.GridQubit(5, 4)
"""

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
hash_key = _compat._method_cache_name(self.__hash__)
if hash_key in state:
state = state.copy()
del state[hash_key]
return state

@_compat.cached_method
def __hash__(self) -> int:
# Explicitly cached for performance (vs delegating to Qid).
return super().__hash__()
_dimension = 2

def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, GridQubit):
return self._row == other._row and self._col == other._col
return NotImplemented

def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, GridQubit):
return self._row != other._row or self._col != other._col
return NotImplemented

@property
def dimension(self) -> int:
return 2
def __init__(self, row: int, col: int) -> None:
self._row = row
self._col = col

def _with_row_col(self, row: int, col: int):
return GridQubit(row, col)
Expand Down
4 changes: 1 addition & 3 deletions cirq-core/cirq/devices/grid_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pytest

import cirq
from cirq import _compat


def test_init():
Expand All @@ -45,8 +44,7 @@ def test_pickled_hash():
q = cirq.GridQubit(3, 4)
q_bad = cirq.GridQubit(3, 4)
_ = hash(q_bad) # compute hash to ensure it is cached.
hash_key = _compat._method_cache_name(cirq.GridQubit.__hash__)
setattr(q_bad, hash_key, getattr(q_bad, hash_key) + 1)
q_bad._hash = q_bad._hash + 1
assert q_bad == q
assert hash(q_bad) != hash(q)
data = pickle.dumps(q_bad)
Expand Down
107 changes: 67 additions & 40 deletions cirq-core/cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,48 @@
class _BaseLineQid(ops.Qid):
"""The base class for `LineQid` and `LineQubit`."""

def __init__(self, x: int) -> None:
"""Initializes a line qubit at the given x coordinate."""
self._x = x
_x: int
_dimension: int
_hash: Optional[int] = None

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
if "_hash" in state:
state = state.copy()
del state["_hash"]
return state

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self._x, self._dimension))
return self._hash

def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x == other._x and self._dimension == other._dimension
return NotImplemented

def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x != other._x or self._dimension != other._dimension
return NotImplemented

def _comparison_key(self):
return self.x
return self._x

@property
def x(self) -> int:
return self._x

@property
def dimension(self) -> int:
return self._dimension

def with_dimension(self, dimension: int) -> 'LineQid':
return LineQid(self.x, dimension)
return LineQid(self._x, dimension)

def is_adjacent(self, other: 'cirq.Qid') -> bool:
"""Determines if two qubits are adjacent line qubits.
Expand All @@ -49,49 +78,45 @@ def is_adjacent(self, other: 'cirq.Qid') -> bool:

Returns: True iff other and self are adjacent.
"""
return isinstance(other, _BaseLineQid) and abs(self.x - other.x) == 1
return isinstance(other, _BaseLineQid) and abs(self._x - other._x) == 1

def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQid']:
"""Returns qubits that are potential neighbors to this LineQubit

Args:
qids: optional Iterable of qubits to constrain neighbors to.
"""
neighbors = set()
for q in [self - 1, self + 1]:
if qids is None or q in qids:
neighbors.add(q)
return neighbors
return {q for q in [self - 1, self + 1] if qids is None or q in qids}

@abc.abstractmethod
def _with_x(self, x: int) -> Self:
"""Returns a qubit with the same type but a different value of `x`."""

def __add__(self, other: Union[int, Self]) -> Self:
if isinstance(other, _BaseLineQid):
if self.dimension != other.dimension:
if self._dimension != other._dimension:
raise TypeError(
"Can only add LineQids with identical dimension. "
f"Got {self.dimension} and {other.dimension}"
f"Got {self._dimension} and {other._dimension}"
)
return self._with_x(x=self.x + other.x)
return self._with_x(x=self._x + other._x)
if not isinstance(other, int):
raise TypeError(f"Can only add ints and {type(self).__name__}. Instead was {other}")
return self._with_x(self.x + other)
return self._with_x(self._x + other)

def __sub__(self, other: Union[int, Self]) -> Self:
if isinstance(other, _BaseLineQid):
if self.dimension != other.dimension:
if self._dimension != other._dimension:
raise TypeError(
"Can only subtract LineQids with identical dimension. "
f"Got {self.dimension} and {other.dimension}"
f"Got {self._dimension} and {other._dimension}"
)
return self._with_x(x=self.x - other.x)
return self._with_x(x=self._x - other._x)
if not isinstance(other, int):
raise TypeError(
f"Can only subtract ints and {type(self).__name__}. Instead was {other}"
)
return self._with_x(self.x - other)
return self._with_x(self._x - other)

def __radd__(self, other: int) -> Self:
return self + other
Expand All @@ -100,16 +125,16 @@ def __rsub__(self, other: int) -> Self:
return -self + other

def __neg__(self) -> Self:
return self._with_x(-self.x)
return self._with_x(-self._x)

def __complex__(self) -> complex:
return complex(self.x)
return complex(self._x)

def __float__(self) -> float:
return float(self.x)
return float(self._x)

def __int__(self) -> int:
return int(self.x)
return int(self._x)


class LineQid(_BaseLineQid):
Expand Down Expand Up @@ -137,16 +162,12 @@ def __init__(self, x: int, dimension: int) -> None:
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
super().__init__(x)
self._dimension = dimension
self.validate_dimension(dimension)

@property
def dimension(self):
return self._dimension
self._x = x
self._dimension = dimension

def _with_x(self, x: int) -> 'LineQid':
return LineQid(x, dimension=self.dimension)
return LineQid(x, dimension=self._dimension)

@staticmethod
def range(*range_args, dimension: int) -> List['LineQid']:
Expand Down Expand Up @@ -192,15 +213,15 @@ def for_gate(val: Any, start: int = 0, step: int = 1) -> List['LineQid']:
return LineQid.for_qid_shape(qid_shape(val), start=start, step=step)

def __repr__(self) -> str:
return f"cirq.LineQid({self.x}, dimension={self.dimension})"
return f"cirq.LineQid({self._x}, dimension={self._dimension})"

def __str__(self) -> str:
return f"q({self.x}) (d={self.dimension})"
return f"q({self._x}) (d={self._dimension})"

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x} (d={self.dimension})",))
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x} (d={self._dimension})",))

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['x', 'dimension'])
Expand All @@ -223,9 +244,15 @@ class LineQubit(_BaseLineQid):

"""

@property
def dimension(self) -> int:
return 2
_dimension = 2

def __init__(self, x: int) -> None:
"""Initializes a line qubit at the given x coordinate.

Args:
x: The x coordinate.
"""
self._x = x

def _with_x(self, x: int) -> 'LineQubit':
return LineQubit(x)
Expand All @@ -234,7 +261,7 @@ def _cmp_tuple(self):
cls = LineQid if type(self) is LineQubit else type(self)
# Must be the same as Qid._cmp_tuple but with cls in place of
# type(self).
return (cls.__name__, repr(cls), self._comparison_key(), self.dimension)
return (cls.__name__, repr(cls), self._comparison_key(), self._dimension)

@staticmethod
def range(*range_args) -> List['LineQubit']:
Expand All @@ -249,15 +276,15 @@ def range(*range_args) -> List['LineQubit']:
return [LineQubit(i) for i in range(*range_args)]

def __repr__(self) -> str:
return f"cirq.LineQubit({self.x})"
return f"cirq.LineQubit({self._x})"

def __str__(self) -> str:
return f"q({self.x})"
return f"q({self._x})"

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x}",))
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x}",))

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['x'])
Loading