Skip to content

Commit 0a326d0

Browse files
codrut3mhuckapavoljuhas
authored
Assume PhasedXPowGate is different from XPowGate and YPowGate (#7070)
* Assume PhasedXPowGate is different from XPowGate and YPowGate for value_equality protocol. A similar change was done for PhasedISwapPowGate. This addresses issue #6528. * Remove _value_equality_values_cls_ and update tests. * Explicitly test that Pauli gates differ from equal-effect PhasedXPowGate * Change engine tests to compare circuit unitaries Adjust for unequality between X and equal-effect PhasedXPowGate. * Fix unintentionally empty Moment in the test * eject_phased_paulis - replace PhasedXPowGate with equivalent X or Y This mostly restores the eject_phased_paulis unit tests to their initial form with a few `cirq.Y` replacements sparkled around. * phase_by for X and Y - replace PhasedXPowGate with equivalent X or Y This mostly restores `eject_z_test.py` to its initial form. --------- Co-authored-by: Michael Hucka <[email protected]> Co-authored-by: Pavol Juhas <[email protected]>
1 parent d39112e commit 0a326d0

File tree

9 files changed

+77
-58
lines changed

9 files changed

+77
-58
lines changed

cirq-core/cirq/ops/common_gates.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,7 @@ def phase_exponent(self):
289289

290290
def _phase_by_(self, phase_turns, qubit_index):
291291
"""See `cirq.SupportsPhase`."""
292-
return cirq.ops.phased_x_gate.PhasedXPowGate(
293-
exponent=self._exponent, phase_exponent=phase_turns * 2
294-
)
292+
return _phased_x_or_pauli_gate(exponent=self._exponent, phase_exponent=phase_turns * 2)
295293

296294
def _has_stabilizer_effect_(self) -> Optional[bool]:
297295
if self._is_parameterized_() or self._dimension != 2:
@@ -484,7 +482,7 @@ def phase_exponent(self):
484482

485483
def _phase_by_(self, phase_turns, qubit_index):
486484
"""See `cirq.SupportsPhase`."""
487-
return cirq.ops.phased_x_gate.PhasedXPowGate(
485+
return _phased_x_or_pauli_gate(
488486
exponent=self._exponent, phase_exponent=0.5 + phase_turns * 2
489487
)
490488

@@ -1542,3 +1540,17 @@ def cphase(rads: value.TParamVal) -> CZPowGate:
15421540
$$
15431541
""",
15441542
)
1543+
1544+
1545+
def _phased_x_or_pauli_gate(
1546+
exponent: Union[float, sympy.Expr], phase_exponent: Union[float, sympy.Expr]
1547+
) -> Union['cirq.PhasedXPowGate', 'cirq.XPowGate', 'cirq.YPowGate']:
1548+
"""Return PhasedXPowGate or X or Y gate if equivalent at the given phase_exponent."""
1549+
if not isinstance(phase_exponent, sympy.Expr) or phase_exponent.is_constant():
1550+
half_turns = value.canonicalize_half_turns(float(phase_exponent))
1551+
match half_turns:
1552+
case 0.0:
1553+
return XPowGate(exponent=exponent)
1554+
case 0.5:
1555+
return YPowGate(exponent=exponent)
1556+
return cirq.ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)

cirq-core/cirq/ops/phased_x_gate.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
import cirq
2525
from cirq import protocols, value
2626
from cirq._compat import proper_repr
27-
from cirq.ops import common_gates, raw_types
27+
from cirq.ops import raw_types
2828

2929

30-
@value.value_equality(manual_cls=True, approximate=True)
30+
@value.value_equality(approximate=True)
3131
class PhasedXPowGate(raw_types.Gate):
3232
r"""A gate equivalent to $Z^{-p} X^t Z^{p}$ (in time order).
3333
@@ -241,22 +241,7 @@ def _canonical_exponent(self):
241241

242242
return self._exponent % period
243243

244-
def _value_equality_values_cls_(self):
245-
if self.phase_exponent == 0:
246-
return common_gates.XPowGate
247-
if self.phase_exponent == 0.5:
248-
return common_gates.YPowGate
249-
return PhasedXPowGate
250-
251244
def _value_equality_values_(self):
252-
if self.phase_exponent == 0:
253-
return common_gates.XPowGate(
254-
exponent=self._exponent, global_shift=self._global_shift
255-
)._value_equality_values_()
256-
if self.phase_exponent == 0.5:
257-
return common_gates.YPowGate(
258-
exponent=self._exponent, global_shift=self._global_shift
259-
)._value_equality_values_()
260245
return self.phase_exponent, self._canonical_exponent, self._global_shift
261246

262247
def _json_dict_(self) -> Dict[str, Any]:

cirq-core/cirq/ops/phased_x_gate_test.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,17 @@ def test_eq():
7979
cirq.PhasedXPowGate(exponent=1, phase_exponent=0),
8080
cirq.PhasedXPowGate(exponent=1, phase_exponent=2),
8181
cirq.PhasedXPowGate(exponent=1, phase_exponent=-2),
82-
cirq.X,
8382
)
83+
eq.add_equality_group(cirq.X)
8484
eq.add_equality_group(cirq.PhasedXPowGate(exponent=1, phase_exponent=2, global_shift=0.1))
8585

8686
eq.add_equality_group(
8787
cirq.PhasedXPowGate(phase_exponent=0.5, exponent=1),
8888
cirq.PhasedXPowGate(phase_exponent=2.5, exponent=3),
89-
cirq.Y,
9089
)
91-
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25), cirq.Y**0.25)
90+
eq.add_equality_group(cirq.Y)
91+
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25))
92+
eq.add_equality_group(cirq.Y**0.25)
9293

9394
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.25, exponent=0.25, global_shift=0.1))
9495
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=2.25, exponent=0.25, global_shift=0.2))
@@ -266,3 +267,18 @@ def test_exponent_consistency(exponent, phase_exponent):
266267
u = cirq.protocols.unitary(g)
267268
u2 = cirq.protocols.unitary(g2)
268269
assert np.all(u == u2)
270+
271+
272+
def test_approx_eq_for_close_phase_exponents():
273+
gate1 = cirq.PhasedXPowGate(phase_exponent=0)
274+
gate2 = cirq.PhasedXPowGate(phase_exponent=1e-12)
275+
gate3 = cirq.PhasedXPowGate(phase_exponent=2e-12)
276+
gate4 = cirq.PhasedXPowGate(phase_exponent=0.345)
277+
278+
assert cirq.approx_eq(gate2, gate3)
279+
assert cirq.approx_eq(gate2, gate1)
280+
assert not cirq.approx_eq(gate2, gate4)
281+
282+
assert cirq.equal_up_to_global_phase(gate2, gate3)
283+
assert cirq.equal_up_to_global_phase(gate2, gate1)
284+
assert not cirq.equal_up_to_global_phase(gate2, gate4)

cirq-core/cirq/transformers/eject_phased_paulis.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Transformer pass that pushes 180° rotations around axes in the XY plane later in the circuit."""
1616

17-
from typing import cast, Dict, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING
17+
from typing import cast, Dict, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union
1818

1919
import numpy as np
2020
import sympy
@@ -63,15 +63,15 @@ def eject_phased_paulis(
6363
def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE':
6464
# Dump if `op` marked with a no compile tag.
6565
if set(op.tags) & tags_to_ignore:
66-
return [_dump_held(op.qubits, held_w_phases), op]
66+
return [_dump_held(op.qubits, held_w_phases, atol), op]
6767

6868
# Collect, phase, and merge Ws.
6969
w = _try_get_known_phased_pauli(op, no_symbolic=not eject_parameterized)
7070
if w is not None:
7171
return (
7272
_potential_cross_whole_w(op, atol, held_w_phases)
7373
if single_qubit_decompositions.is_negligible_turn((w[0] - 1) / 2, atol)
74-
else _potential_cross_partial_w(op, held_w_phases)
74+
else _potential_cross_partial_w(op, held_w_phases, atol)
7575
)
7676

7777
affected = [q for q in op.qubits if q in held_w_phases]
@@ -96,12 +96,12 @@ def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE':
9696
)
9797

9898
# Don't know how to handle this situation. Dump the gates.
99-
return [_dump_held(op.qubits, held_w_phases), op]
99+
return [_dump_held(op.qubits, held_w_phases, atol), op]
100100

101101
# Map operations and put anything that's still held at the end of the circuit.
102102
return circuits.Circuit(
103103
transformer_primitives.map_operations_and_unroll(circuit, map_func),
104-
_dump_held(held_w_phases.keys(), held_w_phases),
104+
_dump_held(held_w_phases.keys(), held_w_phases, atol),
105105
)
106106

107107

@@ -127,14 +127,14 @@ def _absorb_z_into_w(
127127

128128

129129
def _dump_held(
130-
qubits: Iterable[ops.Qid], held_w_phases: Dict[ops.Qid, value.TParamVal]
130+
qubits: Iterable[ops.Qid], held_w_phases: Dict[ops.Qid, value.TParamVal], atol: float
131131
) -> Iterator['cirq.OP_TREE']:
132132
# Note: sorting is to avoid non-determinism in the insertion order.
133133
for q in sorted(qubits):
134134
p = held_w_phases.get(q)
135135
if p is not None:
136-
dump_op = ops.PhasedXPowGate(phase_exponent=p).on(q)
137-
yield dump_op
136+
gate = _phased_x_or_pauli_gate(exponent=1.0, phase_exponent=p, atol=atol)
137+
yield gate.on(q)
138138
held_w_phases.pop(q, None)
139139

140140

@@ -184,7 +184,7 @@ def _potential_cross_whole_w(
184184

185185

186186
def _potential_cross_partial_w(
187-
op: ops.Operation, held_w_phases: Dict[ops.Qid, value.TParamVal]
187+
op: ops.Operation, held_w_phases: Dict[ops.Qid, value.TParamVal], atol: float
188188
) -> 'cirq.OP_TREE':
189189
"""Cross the held W over a partial W gate.
190190
@@ -204,10 +204,10 @@ def _potential_cross_partial_w(
204204
exponent, phase_exponent = cast(
205205
Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op)
206206
)
207-
new_op = ops.PhasedXPowGate(exponent=exponent, phase_exponent=2 * a - phase_exponent).on(
208-
op.qubits[0]
207+
gate = _phased_x_or_pauli_gate(
208+
exponent=exponent, phase_exponent=2 * a - phase_exponent, atol=atol
209209
)
210-
return new_op
210+
return gate.on(op.qubits[0])
211211

212212

213213
def _single_cross_over_cz(op: ops.Operation, qubit_with_w: 'cirq.Qid') -> 'cirq.OP_TREE':
@@ -351,3 +351,16 @@ def _try_get_known_z_half_turns(
351351
if no_symbolic and isinstance(h, sympy.Basic):
352352
return None
353353
return h
354+
355+
356+
def _phased_x_or_pauli_gate(
357+
exponent: Union[float, sympy.Expr], phase_exponent: Union[float, sympy.Expr], atol: float
358+
) -> Union['cirq.PhasedXPowGate', 'cirq.XPowGate', 'cirq.YPowGate']:
359+
"""Return PhasedXPowGate or X or Y gate if equivalent within atol in z-axis turns."""
360+
if not isinstance(phase_exponent, sympy.Expr) or phase_exponent.is_constant():
361+
half_turns = value.canonicalize_half_turns(float(phase_exponent))
362+
if abs(half_turns / 2) <= atol:
363+
return ops.XPowGate(exponent=exponent)
364+
if abs((half_turns - 0.5) / 2) <= atol:
365+
return ops.YPowGate(exponent=exponent)
366+
return ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)

cirq-core/cirq/transformers/eject_phased_paulis_test.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,7 @@ def test_crosses_czs():
212212
[cirq.CZ(a, b) ** 0.25],
213213
),
214214
expected=quick_circuit(
215-
[cirq.CZ(a, b) ** 0.25],
216-
[
217-
cirq.PhasedXPowGate(phase_exponent=0.5).on(b),
218-
cirq.PhasedXPowGate(phase_exponent=0.25).on(a),
219-
],
215+
[cirq.CZ(a, b) ** 0.25], [cirq.Y(b), cirq.PhasedXPowGate(phase_exponent=0.25).on(a)]
220216
),
221217
)
222218
assert_optimizes(
@@ -387,8 +383,7 @@ def test_phases_partial_ws():
387383
assert_optimizes(
388384
before=quick_circuit([cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], [cirq.X(q) ** 0.5]),
389385
expected=quick_circuit(
390-
[cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.5).on(q)],
391-
[cirq.PhasedXPowGate(phase_exponent=0.25).on(q)],
386+
[cirq.Y(q) ** 0.5], [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)]
392387
),
393388
)
394389

cirq-core/cirq/transformers/eject_z_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ def assert_optimizes(
5151
cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(3).with_tags("preserve_tag")),
5252
)
5353
c_expected = cirq.Circuit(
54-
cirq.PhasedXPowGate(phase_exponent=0, exponent=0.25).on_each(*q),
54+
(cirq.X**0.25).on_each(*q),
5555
(cirq.Z**0.5).on_each(*q),
5656
cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")),
57-
cirq.PhasedXPowGate(phase_exponent=0, exponent=0.25).on_each(*q),
57+
(cirq.X**0.25).on_each(*q),
5858
(cirq.Z**0.5).on_each(*q),
5959
cirq.Moment(cirq.CircuitOperation(expected.freeze()).repeat(3).with_tags("preserve_tag")),
6060
)

cirq-core/cirq/transformers/merge_single_qubit_gates_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z():
4545
optimized=cirq.merge_single_qubit_gates_to_phased_x_and_z(c),
4646
expected=cirq.Circuit(
4747
cirq.PhasedXPowGate(phase_exponent=1)(a),
48-
cirq.Y(b) ** 0.5,
48+
cirq.PhasedXPowGate(phase_exponent=0.5)(b) ** 0.5,
4949
cirq.CZ(a, b),
5050
(cirq.PhasedXPowGate(phase_exponent=-0.5)(a)) ** 0.5,
5151
cirq.measure(b, key="m"),

cirq-google/cirq_google/api/v1/programs_test.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,27 @@
1515
import pytest
1616
import sympy
1717

18-
import cirq
18+
import cirq.testing
1919
import cirq_google.api.v1.programs as programs
2020
from cirq_google.api.v1 import operations_pb2
2121

2222

2323
def assert_proto_dict_convert(gate: cirq.Gate, proto: operations_pb2.Operation, *qubits: cirq.Qid):
2424
assert programs.gate_to_proto(gate, qubits, delay=0) == proto
25-
assert programs.xmon_op_from_proto(proto) == gate(*qubits)
25+
xmon_op = programs.xmon_op_from_proto(proto)
26+
assert xmon_op.qubits == qubits
27+
assert xmon_op.gate == gate or np.allclose(cirq.unitary(xmon_op.gate), cirq.unitary(gate))
2628

2729

2830
def test_protobuf_round_trip():
2931
qubits = cirq.GridQubit.rect(1, 5)
3032
circuit = cirq.Circuit(
31-
[cirq.X(q) ** 0.5 for q in qubits],
32-
[
33-
cirq.CZ(q, q2)
34-
for q in [cirq.GridQubit(0, 0)]
35-
for q, q2 in zip(qubits, qubits)
36-
if q != q2
37-
],
33+
[cirq.X(q) ** 0.5 for q in qubits], [cirq.CZ(qubits[0], q1) for q1 in qubits[1:]]
3834
)
3935

4036
protos = list(programs.circuit_as_schedule_to_protos(circuit))
4137
s2 = programs.circuit_from_schedule_from_protos(protos)
42-
assert s2 == circuit
38+
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(s2, circuit)
4339

4440

4541
def make_bytes(s: str) -> bytes:

cirq-google/cirq_google/engine/engine_program_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from google.protobuf import any_pb2, timestamp_pb2
2121
from google.protobuf.text_format import Merge
2222

23-
import cirq
23+
import cirq.testing
2424
import cirq_google as cg
2525
from cirq_google.api import v1, v2
2626
from cirq_google.cloud import quantum
@@ -304,7 +304,9 @@ def test_get_circuit_v2(get_program_async):
304304

305305
program = cg.EngineProgram('a', 'b', EngineContext())
306306
get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2)
307-
assert program.get_circuit() == circuit
307+
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
308+
program.get_circuit(), circuit
309+
)
308310
get_program_async.assert_called_once_with('a', 'b', True)
309311

310312

0 commit comments

Comments
 (0)