Skip to content

Commit ce31720

Browse files
Fix moment commutation detection for group-commuting operations (#6659) (#7082)
* Fix Moment.commutes_ to handle multi-qubit unitaries (#6659) Add pairwise commuting checks and, when inconclusive, fall back to full unitary commutator. Includes new tests showing Z⊗Z commutes with RXX. Fixes #6659. * Put back pre-existing commutation tests * prune duplicate or redundant test cases * reuse cirq.AmplitudeDampingChannel to test with non-unitary * move commute-related test functions together * Refactor `Moment._commutes_` to convert to CircuitOperation if necessary Also check `measurement_key_objs` and `control_keys` active for checked Moments as in the Operation commutes protocol. * Test commutation of moments with measurement keys and controls Adapt `cirq.ops.classically_controlled_operation_test.test_commute` * Small tweak of docstring * Add shared internal function _operations_commutes_impl Generalize `Operation._commutes_` in a shared internal function `_operations_commutes_impl` * Sync module variable name with circuit module name No change in code function. * Use shared _operations_commutes_impl for Moment._commutes_ Also adjust `_operations_commutes_impl` to assume equal Moment-s and Operation-s commute. * Few optimizations in _operations_commutes_impl - frozenset().union(*tuples) is faster than iterating over all items in every tuple to construct a frozenset. - base class Operation does not have `__hash__` method, therefore we cannot use `set(operations)`. --------- Co-authored-by: Pavol Juhas <[email protected]>
1 parent 0d3c972 commit ce31720

File tree

3 files changed

+184
-59
lines changed

3 files changed

+184
-59
lines changed

cirq-core/cirq/circuits/moment.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
import cirq
5050

5151
# Lazy imports to break circular dependencies.
52-
circuits = LazyLoader("circuits", globals(), "cirq.circuits.circuit")
52+
circuit = LazyLoader("circuit", globals(), "cirq.circuits.circuit")
5353
op_tree = LazyLoader("op_tree", globals(), "cirq.ops.op_tree")
5454
text_diagram_drawer = LazyLoader(
5555
"text_diagram_drawer", globals(), "cirq.circuits.text_diagram_drawer"
@@ -525,7 +525,7 @@ def _from_json_dict_(cls, operations, **kwargs):
525525
return cls.from_ops(*operations)
526526

527527
def __add__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
528-
if isinstance(other, circuits.AbstractCircuit):
528+
if isinstance(other, circuit.AbstractCircuit):
529529
return NotImplemented # Delegate to Circuit.__radd__.
530530
return self.with_operations(other)
531531

@@ -659,37 +659,26 @@ def cleanup_key(key: Any) -> Any:
659659
return diagram.render()
660660

661661
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
662-
"""Determines whether Moment commutes with the Operation.
662+
"""Determines whether Moment commutes with the other Moment or Operation.
663663
664664
Args:
665-
other: An Operation object. Other types are not implemented yet.
666-
In case a different type is specified, NotImplemented is
667-
returned.
665+
other: An Operation or Moment object to test for commutativity.
668666
atol: Absolute error tolerance. If all entries in v1@v2 - v2@v1
669667
have a magnitude less than this tolerance, v1 and v2 can be
670668
reported as commuting. Defaults to 1e-8.
671669
672670
Returns:
673-
True: The Moment and Operation commute OR they don't have shared
674-
quibits.
671+
True: The Moment commutes with Moment or Operation OR they don't
672+
have shared qubits.
675673
False: The two values do not commute.
676674
NotImplemented: In case we don't know how to check this, e.g.
677-
the parameter type is not supported yet.
675+
the parameter type is not supported or commutativity cannot be
676+
determined.
678677
"""
679-
if not isinstance(other, ops.Operation):
678+
if not isinstance(other, (ops.Operation, Moment)):
680679
return NotImplemented
681-
682-
other_qubits = set(other.qubits)
683-
for op in self.operations:
684-
if not other_qubits.intersection(set(op.qubits)):
685-
continue
686-
687-
commutes = protocols.commutes(op, other, atol=atol, default=NotImplemented)
688-
689-
if not commutes or commutes is NotImplemented:
690-
return commutes
691-
692-
return True
680+
other_operations = other.operations if isinstance(other, Moment) else (other,)
681+
return raw_types._operations_commutes_impl(self.operations, other_operations, atol=atol)
693682

694683

695684
class _SortByValFallbackToType:

cirq-core/cirq/circuits/moment_test.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,15 +680,15 @@ def test_text_diagram_does_not_depend_on_insertion_order():
680680
assert str(m1) == str(m2)
681681

682682

683-
def test_commutes():
683+
def test_commutes_moment_and_operation():
684684
a = cirq.NamedQubit('a')
685685
b = cirq.NamedQubit('b')
686686
c = cirq.NamedQubit('c')
687687
d = cirq.NamedQubit('d')
688688

689689
moment = cirq.Moment([cirq.X(a), cirq.Y(b), cirq.H(c)])
690690

691-
assert NotImplemented == cirq.commutes(moment, a, default=NotImplemented)
691+
assert cirq.commutes(moment, a, default=None) is None
692692

693693
assert cirq.commutes(moment, cirq.X(a))
694694
assert cirq.commutes(moment, cirq.Y(b))
@@ -700,6 +700,101 @@ def test_commutes():
700700
assert not cirq.commutes(moment, cirq.H(b))
701701
assert not cirq.commutes(moment, cirq.X(c))
702702

703+
# Empty moment commutes with everything
704+
moment = cirq.Moment()
705+
assert cirq.commutes(moment, cirq.X(a))
706+
assert cirq.commutes(moment, cirq.measure(b))
707+
708+
# Two qubit operation
709+
moment = cirq.Moment(cirq.Z(a), cirq.Z(b))
710+
assert cirq.commutes(moment, cirq.XX(a, b))
711+
712+
713+
def test_commutes_moment_and_moment():
714+
a = cirq.NamedQubit('a')
715+
b = cirq.NamedQubit('b')
716+
c = cirq.NamedQubit('c')
717+
718+
# Test cases where individual operations don't commute but moments do
719+
# Two Z gates (Z⊗Z) commutes with RXX even though individual Z's don't
720+
assert not cirq.commutes(cirq.Moment(cirq.Z(a)), cirq.Moment(cirq.XX(a, b)))
721+
assert cirq.commutes(cirq.Moment(cirq.Z(a), cirq.Z(b)), cirq.Moment(cirq.XX(a, b)))
722+
723+
# Moments that do not commute if acting on same qubits
724+
assert cirq.commutes(cirq.Moment(cirq.X(a)), cirq.Moment(cirq.Y(b)))
725+
assert not cirq.commutes(cirq.Moment(cirq.X(a)), cirq.Moment(cirq.Y(a)))
726+
727+
# Moments commute with themselves
728+
assert cirq.commutes(
729+
cirq.Moment([cirq.X(a), cirq.Y(b), cirq.H(c)]),
730+
cirq.Moment([cirq.X(a), cirq.Y(b), cirq.H(c)]),
731+
)
732+
733+
734+
def test_commutes_moment_with_controls():
735+
a, b = cirq.LineQubit.range(2)
736+
assert cirq.commutes(
737+
cirq.Moment(cirq.measure(a, key='k0')), cirq.Moment(cirq.X(b).with_classical_controls('k1'))
738+
)
739+
assert cirq.commutes(
740+
cirq.Moment(cirq.X(b).with_classical_controls('k1')), cirq.Moment(cirq.measure(a, key='k0'))
741+
)
742+
assert cirq.commutes(
743+
cirq.Moment(cirq.X(a).with_classical_controls('k0')),
744+
cirq.Moment(cirq.H(b).with_classical_controls('k0')),
745+
)
746+
assert cirq.commutes(
747+
cirq.Moment(cirq.X(a).with_classical_controls('k0')),
748+
cirq.Moment(cirq.X(a).with_classical_controls('k0')),
749+
)
750+
assert not cirq.commutes(
751+
cirq.Moment(cirq.measure(a, key='k0')), cirq.Moment(cirq.X(b).with_classical_controls('k0'))
752+
)
753+
assert not cirq.commutes(
754+
cirq.Moment(cirq.X(b).with_classical_controls('k0')), cirq.Moment(cirq.measure(a, key='k0'))
755+
)
756+
assert not cirq.commutes(
757+
cirq.Moment(cirq.X(a).with_classical_controls('k0')),
758+
cirq.Moment(cirq.H(a).with_classical_controls('k0')),
759+
)
760+
761+
762+
def test_commutes_moment_and_moment_comprehensive():
763+
a, b, c, d = cirq.LineQubit.range(4)
764+
765+
# Basic Z⊗Z commuting with XX at different angles
766+
m1 = cirq.Moment([cirq.Z(a), cirq.Z(b)])
767+
m2 = cirq.Moment([cirq.XXPowGate(exponent=0.5)(a, b)])
768+
assert cirq.commutes(m1, m2)
769+
770+
# Disjoint qubit sets
771+
m1 = cirq.Moment([cirq.X(a), cirq.Y(b)])
772+
m2 = cirq.Moment([cirq.Z(c), cirq.H(d)])
773+
assert cirq.commutes(m1, m2)
774+
775+
# Mixed case - some commute individually, some as group
776+
m1 = cirq.Moment([cirq.Z(a), cirq.Z(b), cirq.X(c)])
777+
m2 = cirq.Moment([cirq.XXPowGate(exponent=0.5)(a, b), cirq.X(c)])
778+
assert cirq.commutes(m1, m2)
779+
780+
# Non-commuting case: X on first qubit, Z on second with XX gate
781+
m1 = cirq.Moment([cirq.X(a), cirq.Z(b)])
782+
m2 = cirq.Moment([cirq.XX(a, b)])
783+
assert not cirq.commutes(m1, m2)
784+
785+
# Complex case requiring unitary calculation - non-commuting case
786+
m1 = cirq.Moment([cirq.Z(a), cirq.Z(b), cirq.Z(c)])
787+
m2 = cirq.Moment([cirq.XXPowGate(exponent=0.5)(a, b), cirq.X(c)])
788+
assert not cirq.commutes(m1, m2) # Z⊗Z⊗Z doesn't commute with XX⊗X
789+
790+
791+
def test_commutes_handles_non_unitary_operation():
792+
a = cirq.NamedQubit('a')
793+
op_damp_a = cirq.AmplitudeDampingChannel(gamma=0.1).on(a)
794+
assert cirq.commutes(cirq.Moment(cirq.X(a)), op_damp_a, default=None) is None
795+
assert cirq.commutes(cirq.Moment(cirq.X(a)), cirq.Moment(op_damp_a), default=None) is None
796+
assert cirq.commutes(cirq.Moment(op_damp_a), cirq.Moment(op_damp_a))
797+
703798

704799
def test_transform_qubits():
705800
a, b = cirq.LineQubit.range(2)

cirq-core/cirq/ops/raw_types.py

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -688,41 +688,7 @@ def _commutes_(
688688
"""Determine if this Operation commutes with the object"""
689689
if not isinstance(other, Operation):
690690
return NotImplemented
691-
692-
self_keys = protocols.measurement_key_objs(self)
693-
other_keys = protocols.measurement_key_objs(other)
694-
if (
695-
not self_keys.isdisjoint(other_keys)
696-
or not protocols.control_keys(self).isdisjoint(other_keys)
697-
or not protocols.control_keys(other).isdisjoint(self_keys)
698-
):
699-
return False
700-
701-
if hasattr(other, 'qubits') and set(self.qubits).isdisjoint(other.qubits):
702-
return True
703-
704-
from cirq import circuits
705-
706-
# Remove the classical controls to validate the quantum commutativity. This can be done
707-
# because during execution, the two operations will either both be run, in which case they
708-
# behave like the suboperations, so if the suboperations commute then these commute. Or
709-
# one of them is cold in which case it behaves like the identity, which always commutes.
710-
self_raw = self.without_classical_controls()
711-
other_raw = other.without_classical_controls()
712-
circuit12 = circuits.Circuit(self_raw, other_raw)
713-
circuit21 = circuits.Circuit(other_raw, self_raw)
714-
715-
# Don't create gigantic matrices.
716-
shape = protocols.qid_shape_protocol.qid_shape(circuit12)
717-
if np.prod(shape, dtype=np.int64) > 2**10:
718-
return NotImplemented # pragma: no cover
719-
720-
m12 = protocols.unitary_protocol.unitary(circuit12, default=None)
721-
m21 = protocols.unitary_protocol.unitary(circuit21, default=None)
722-
if m12 is None or m21 is None:
723-
return NotImplemented
724-
725-
return np.allclose(m12, m21, atol=atol)
691+
return _operations_commutes_impl([self], [other], atol=atol)
726692

727693
@property
728694
def classical_controls(self) -> FrozenSet['cirq.Condition']:
@@ -1112,3 +1078,78 @@ def _validate_qid_shape(val: Any, qubits: Sequence['cirq.Qid']) -> None:
11121078
raise ValueError(
11131079
f'Duplicate qids for <{val!r}>. Expected unique qids but got <{qubits!r}>.'
11141080
)
1081+
1082+
1083+
def _operations_commutes_impl(
1084+
ops1: Collection[Operation], ops2: Collection[Operation], *, atol: float
1085+
) -> Union[bool, NotImplementedType]:
1086+
"""Determine if two collections of non-overlapping Operations commute.
1087+
1088+
This function implements the commutes protocol for the Operation and Moment classes
1089+
and is not intended for other use.
1090+
1091+
Args:
1092+
ops1: The first collection of operations. It is assumed each operation
1093+
acts on different qubits, i.e., the operations can form a Moment.
1094+
ops2: The second collection of operations to be checked for commutation
1095+
with `ops1`. It is assumed each operation acts on different qubits,
1096+
i.e., the operations can form a Moment.
1097+
atol: Absolute error tolerance. If all entries in ops1@ops2 - ops2@ops1
1098+
have a magnitude less than this tolerance, ops1 and ops2 are considered
1099+
to commute.
1100+
1101+
Returns:
1102+
True: `ops1` and `ops2` commute (or approximately commute).
1103+
False: `ops1` and `ops2` do not commute.
1104+
NotImplemented: The commutativity cannot be determined here.
1105+
"""
1106+
ops1_keys = frozenset(k for op in ops1 for k in protocols.measurement_key_objs(op))
1107+
ops2_keys = frozenset(k for op in ops2 for k in protocols.measurement_key_objs(op))
1108+
ops1_control_keys = frozenset(k for op in ops1 for k in protocols.control_keys(op))
1109+
ops2_control_keys = frozenset(k for op in ops2 for k in protocols.control_keys(op))
1110+
if (
1111+
not ops1_keys.isdisjoint(ops2_keys)
1112+
or not ops1_control_keys.isdisjoint(ops2_keys)
1113+
or not ops2_control_keys.isdisjoint(ops1_keys)
1114+
):
1115+
return False
1116+
1117+
ops1_qubits = frozenset().union(*(op.qubits for op in ops1))
1118+
ops2_qubits = frozenset().union(*(op.qubits for op in ops2))
1119+
if ops1_qubits.isdisjoint(ops2_qubits):
1120+
return True
1121+
1122+
from cirq import circuits
1123+
1124+
# Remove the classical controls to validate the quantum commutativity. This can be done
1125+
# because during execution, the two operations will either both be run, in which case they
1126+
# behave like the suboperations, so if the suboperations commute then these commute. Or
1127+
# one of them is cold in which case it behaves like the identity, which always commutes.
1128+
shared_qubits = ops1_qubits.intersection(ops2_qubits)
1129+
ops1_raw = [
1130+
op.without_classical_controls() for op in ops1 if not shared_qubits.isdisjoint(op.qubits)
1131+
]
1132+
ops2_raw = [
1133+
op.without_classical_controls() for op in ops2 if not shared_qubits.isdisjoint(op.qubits)
1134+
]
1135+
moment1 = circuits.Moment(ops1_raw)
1136+
moment2 = circuits.Moment(ops2_raw)
1137+
1138+
# shortcut if we have equal moments
1139+
if moment1 == moment2:
1140+
return True
1141+
1142+
circuit12 = circuits.Circuit(moment1, moment2)
1143+
circuit21 = circuits.Circuit(moment2, moment1)
1144+
1145+
# Don't create gigantic matrices.
1146+
shape = protocols.qid_shape_protocol.qid_shape(circuit12)
1147+
if np.prod(shape, dtype=np.int64) > 2**10:
1148+
return NotImplemented # pragma: no cover
1149+
1150+
m12 = protocols.unitary_protocol.unitary(circuit12, default=None)
1151+
m21 = protocols.unitary_protocol.unitary(circuit21, default=None)
1152+
if m12 is None or m21 is None:
1153+
return NotImplemented
1154+
1155+
return np.allclose(m12, m21, atol=atol)

0 commit comments

Comments
 (0)