Skip to content

Commit 89ebad8

Browse files
authored
Cache the result of CircuitOperation.has_unitary (#7483)
Otherwise every time `CircuitOperation.has_unitary` is called, `has_unitary_protocol` [decomposes the circuit into operations](https://github.com/quantumlib/Cirq/blob/609d93dbc51a6608a0d0c3f5d50d51325052e027/cirq-core/cirq/protocols/has_unitary_protocol.py#L98) and calls `has_unitary` on each. This is inefficient, especially if `repetitions > 1`, because `has_unitary_protocol` looks at the circuit with repetitions. To ensure this change is correct, I wrote the tests first, and then made the change.
1 parent 460dda0 commit 89ebad8

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import sympy
3030

3131
from cirq import circuits, ops, protocols, study, value
32-
from cirq._compat import proper_repr
32+
from cirq._compat import cached_method, proper_repr
3333

3434
if TYPE_CHECKING:
3535
import cirq
@@ -296,12 +296,13 @@ def _qid_shape_(self) -> tuple[int, ...]:
296296
def _is_measurement_(self) -> bool:
297297
return self.circuit._is_measurement_()
298298

299+
@cached_method
299300
def _has_unitary_(self) -> bool:
300301
# Return false if parameterized for early exit of has_unitary protocol.
301-
# Otherwise return NotImplemented instructing the protocol to try alternate strategies
302302
if self._is_parameterized_() or self.repeat_until:
303303
return False
304-
return NotImplemented
304+
operations = self._mapped_any_loop.all_operations()
305+
return all(protocols.has_unitary(op) for op in operations)
305306

306307
def _ensure_deterministic_loop_count(self):
307308
if self.repeat_until or isinstance(self.repetitions, sympy.Expr):

cirq-core/cirq/circuits/circuit_operation_test.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import cirq
2424
import cirq.circuits.circuit_operation as circuit_operation
25-
from cirq import _compat
25+
from cirq import _compat, protocols
2626
from cirq.circuits.circuit_operation import _full_join_string_lists
2727

2828
ALL_SIMULATORS = (cirq.Simulator(), cirq.DensityMatrixSimulator(), cirq.CliffordSimulator())
@@ -1297,3 +1297,30 @@ def test_inner_repeat_until_simulate() -> None:
12971297

12981298

12991299
# TODO: Operation has a "gate" property. What is this for a CircuitOperation?
1300+
1301+
1302+
def test_has_unitary_protocol_returns_true_if_all_common_gates() -> None:
1303+
q = cirq.LineQubit(0)
1304+
op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q), cirq.Y(q), cirq.Z(q)))
1305+
assert protocols.has_unitary(op)
1306+
1307+
1308+
def test_has_unitary_protocol_returns_false_if_measurement_gate() -> None:
1309+
q = cirq.LineQubit(0)
1310+
key = cirq.MeasurementKey('m')
1311+
op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)))
1312+
assert not protocols.has_unitary(op)
1313+
1314+
1315+
def test_has_unitary_protocol_returns_false_if_parametrized() -> None:
1316+
q = cirq.LineQubit(0)
1317+
exp = sympy.Symbol('exp')
1318+
op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q) ** exp))
1319+
assert not protocols.has_unitary(op)
1320+
1321+
1322+
def test_has_unitary_protocol_returns_true_if_all_params_resolve() -> None:
1323+
q = cirq.LineQubit(0)
1324+
exp = sympy.Symbol('exp')
1325+
op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q) ** exp), param_resolver={exp: 0.5})
1326+
assert protocols.has_unitary(op)

0 commit comments

Comments
 (0)