Skip to content

Commit 58d9619

Browse files
authored
Flip back to default use_repetition_ids=True in CircuitOperation (#7237)
* Revert "CircuitOperation: change use_repetition_ids default to False (#6910)" Put back the default `use_repetition_ids=True` so we do not make API change without deprecation warning. This reverts commit 5ffb3ad. * Add FutureWarning for upcoming change of use_repetition_ids default * Adjust unit tests for default `use_repetition_ids=True`
1 parent 5ee16b7 commit 58d9619

File tree

5 files changed

+127
-152
lines changed

5 files changed

+127
-152
lines changed

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import annotations
2323

2424
import math
25+
import warnings
2526
from functools import cached_property
2627
from typing import (
2728
Any,
@@ -48,7 +49,6 @@
4849
if TYPE_CHECKING:
4950
import cirq
5051

51-
5252
INT_CLASSES = (int, np.integer)
5353
INT_TYPE = Union[int, np.integer]
5454
IntParam = Union[INT_TYPE, sympy.Expr]
@@ -123,8 +123,9 @@ def __init__(
123123
use_repetition_ids: When True, any measurement key in the subcircuit
124124
will have its path prepended with the repetition id for each
125125
repetition. When False, this will not happen and the measurement
126-
key will be repeated. When None, default to False unless the caller
127-
passes `repetition_ids` explicitly.
126+
key will be repeated. The default is True, but it will be changed
127+
to False in the next release. Please pass an explicit argument
128+
``use_repetition_ids=True`` to preserve the current behavior.
128129
repeat_until: A condition that will be tested after each iteration of
129130
the subcircuit. The subcircuit will repeat until condition returns
130131
True, but will always run at least once, and the measurement key
@@ -161,8 +162,18 @@ def __init__(
161162
self._repetitions = repetitions
162163
self._repetition_ids = None if repetition_ids is None else list(repetition_ids)
163164
if use_repetition_ids is None:
164-
use_repetition_ids = repetition_ids is not None
165-
self._use_repetition_ids = use_repetition_ids
165+
if repetition_ids is None:
166+
msg = (
167+
"In cirq 1.6 the default value of `use_repetition_ids` will change to\n"
168+
"`use_repetition_ids=False`. To make this warning go away, please pass\n"
169+
"explicit `use_repetition_ids`, e.g., to preserve current behavior, use\n"
170+
"\n"
171+
" CircuitOperations(..., use_repetition_ids=True)"
172+
)
173+
warnings.warn(msg, FutureWarning)
174+
self._use_repetition_ids = True
175+
else:
176+
self._use_repetition_ids = use_repetition_ids
166177
if isinstance(self._repetitions, float):
167178
if math.isclose(self._repetitions, round(self._repetitions)):
168179
self._repetitions = round(self._repetitions)
@@ -270,9 +281,7 @@ def replace(self, **changes) -> cirq.CircuitOperation:
270281
'repetition_ids': self.repetition_ids,
271282
'parent_path': self.parent_path,
272283
'extern_keys': self._extern_keys,
273-
'use_repetition_ids': (
274-
True if changes.get('repetition_ids') is not None else self.use_repetition_ids
275-
),
284+
'use_repetition_ids': self.use_repetition_ids,
276285
'repeat_until': self.repeat_until,
277286
**changes,
278287
}
@@ -476,9 +485,11 @@ def __repr__(self):
476485
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
477486
if self.parent_path:
478487
args += f'parent_path={proper_repr(self.parent_path)},\n'
479-
if self.use_repetition_ids:
488+
if self.repetition_ids != self._default_repetition_ids():
480489
# Default repetition_ids need not be specified.
481490
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
491+
if not self.use_repetition_ids:
492+
args += 'use_repetition_ids=False,\n'
482493
if self.repeat_until:
483494
args += f'repeat_until={self.repeat_until!r},\n'
484495
indented_args = args.replace('\n', '\n ')
@@ -503,15 +514,14 @@ def dict_str(d: Mapping) -> str:
503514
args.append(f'params={self.param_resolver.param_dict}')
504515
if self.parent_path:
505516
args.append(f'parent_path={self.parent_path}')
506-
if self.use_repetition_ids:
507-
if self.repetition_ids != self._default_repetition_ids():
508-
args.append(f'repetition_ids={self.repetition_ids}')
509-
else:
510-
# Default repetition_ids need not be specified.
511-
args.append(f'loops={self.repetitions}, use_repetition_ids=True')
517+
if self.repetition_ids != self._default_repetition_ids():
518+
# Default repetition_ids need not be specified.
519+
args.append(f'repetition_ids={self.repetition_ids}')
512520
elif self.repetitions != 1:
513-
# Add loops if not using repetition_ids.
521+
# Only add loops if we haven't added repetition_ids.
514522
args.append(f'loops={self.repetitions}')
523+
if not self.use_repetition_ids:
524+
args.append('no_rep_ids')
515525
if self.repeat_until:
516526
args.append(f'until={self.repeat_until}')
517527
if not args:
@@ -556,9 +566,10 @@ def _json_dict_(self):
556566
'measurement_key_map': self.measurement_key_map,
557567
'param_resolver': self.param_resolver,
558568
'repetition_ids': self.repetition_ids,
559-
'use_repetition_ids': self.use_repetition_ids,
560569
'parent_path': self.parent_path,
561570
}
571+
if not self.use_repetition_ids:
572+
resp['use_repetition_ids'] = False
562573
if self.repeat_until:
563574
resp['repeat_until'] = self.repeat_until
564575
return resp
@@ -592,10 +603,7 @@ def _from_json_dict_(
592603
# Methods for constructing a similar object with one field modified.
593604

594605
def repeat(
595-
self,
596-
repetitions: Optional[IntParam] = None,
597-
repetition_ids: Optional[Sequence[str]] = None,
598-
use_repetition_ids: Optional[bool] = None,
606+
self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[Sequence[str]] = None
599607
) -> CircuitOperation:
600608
"""Returns a copy of this operation repeated 'repetitions' times.
601609
Each repetition instance will be identified by a single repetition_id.
@@ -606,10 +614,6 @@ def repeat(
606614
defaults to the length of `repetition_ids`.
607615
repetition_ids: List of IDs, one for each repetition. If unset,
608616
defaults to `default_repetition_ids(repetitions)`.
609-
use_repetition_ids: If given, this specifies the value for `use_repetition_ids`
610-
of the resulting circuit operation. If not given, we enable ids if
611-
`repetition_ids` is not None, and otherwise fall back to
612-
`self.use_repetition_ids`.
613617
614618
Returns:
615619
A copy of this operation repeated `repetitions` times with the
@@ -624,9 +628,6 @@ def repeat(
624628
ValueError: Unexpected length of `repetition_ids`.
625629
ValueError: Both `repetitions` and `repetition_ids` are None.
626630
"""
627-
if use_repetition_ids is None:
628-
use_repetition_ids = True if repetition_ids is not None else self.use_repetition_ids
629-
630631
if repetitions is None:
631632
if repetition_ids is None:
632633
raise ValueError('At least one of repetitions and repetition_ids must be set')
@@ -640,7 +641,7 @@ def repeat(
640641
expected_repetition_id_length: int = np.abs(repetitions)
641642

642643
if repetition_ids is None:
643-
if use_repetition_ids:
644+
if self.use_repetition_ids:
644645
repetition_ids = default_repetition_ids(expected_repetition_id_length)
645646
elif len(repetition_ids) != expected_repetition_id_length:
646647
raise ValueError(
@@ -653,11 +654,7 @@ def repeat(
653654

654655
# The eventual number of repetitions of the returned CircuitOperation.
655656
final_repetitions = protocols.mul(self.repetitions, repetitions)
656-
return self.replace(
657-
repetitions=final_repetitions,
658-
repetition_ids=repetition_ids,
659-
use_repetition_ids=use_repetition_ids,
660-
)
657+
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
661658

662659
def __pow__(self, power: IntParam) -> cirq.CircuitOperation:
663660
return self.repeat(power)

cirq-core/cirq/circuits/circuit_operation_test.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -294,15 +294,15 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
294294
op_with_reps: Optional[cirq.CircuitOperation] = None
295295
rep_ids = []
296296
if use_default_ids_for_initial_rep:
297+
op_with_reps = op_base.repeat(initial_repetitions)
297298
rep_ids = ['0', '1', '2']
298-
op_with_reps = op_base.repeat(initial_repetitions, use_repetition_ids=True)
299+
assert op_base**initial_repetitions == op_with_reps
299300
else:
300301
rep_ids = ['a', 'b', 'c']
301302
op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
302-
assert op_base**initial_repetitions != op_with_reps
303-
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
303+
assert op_base**initial_repetitions != op_with_reps
304+
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
304305
assert op_with_reps.repetitions == initial_repetitions
305-
assert op_with_reps.use_repetition_ids
306306
assert op_with_reps.repetition_ids == rep_ids
307307
assert op_with_reps.repeat(1) is op_with_reps
308308

@@ -332,6 +332,8 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
332332
assert op_base.repeat(2.99999999999).repetitions == 3
333333

334334

335+
# TODO: #7232 - enable and fix immediately after the 1.5.0 release
336+
@pytest.mark.xfail(reason='broken by rollback of use_repetition_ids for #7232')
335337
def test_replace_repetition_ids() -> None:
336338
a, b = cirq.LineQubit.range(2)
337339
circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b), cirq.M(b, key='mb'), cirq.M(a, key='ma'))
@@ -458,7 +460,6 @@ def test_parameterized_repeat_side_effects():
458460
op = cirq.CircuitOperation(
459461
cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')),
460462
repetitions=sympy.Symbol('a'),
461-
use_repetition_ids=True,
462463
)
463464

464465
# Control keys can be calculated because they only "lift" if there's a matching
@@ -712,6 +713,7 @@ def test_string_format():
712713
),
713714
),
714715
]),
716+
use_repetition_ids=False,
715717
)"""
716718
)
717719
op7 = cirq.CircuitOperation(
@@ -728,6 +730,7 @@ def test_string_format():
728730
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
729731
),
730732
]),
733+
use_repetition_ids=False,
731734
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
732735
)"""
733736
)
@@ -758,7 +761,6 @@ def test_json_dict():
758761
'param_resolver': op.param_resolver,
759762
'parent_path': op.parent_path,
760763
'repetition_ids': None,
761-
'use_repetition_ids': False,
762764
}
763765

764766

@@ -865,26 +867,6 @@ def test_decompose_loops_with_measurements():
865867
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
866868
base_op = cirq.CircuitOperation(circuit)
867869

868-
op = base_op.with_qubits(b, a).repeat(3)
869-
expected_circuit = cirq.Circuit(
870-
cirq.H(b),
871-
cirq.CX(b, a),
872-
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
873-
cirq.H(b),
874-
cirq.CX(b, a),
875-
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
876-
cirq.H(b),
877-
cirq.CX(b, a),
878-
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
879-
)
880-
assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit
881-
882-
883-
def test_decompose_loops_with_measurements_use_rep_ids():
884-
a, b = cirq.LineQubit.range(2)
885-
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
886-
base_op = cirq.CircuitOperation(circuit, use_repetition_ids=True)
887-
888870
op = base_op.with_qubits(b, a).repeat(3)
889871
expected_circuit = cirq.Circuit(
890872
cirq.H(b),
@@ -1041,9 +1023,7 @@ def test_keys_under_parent_path():
10411023
op3 = cirq.with_key_path_prefix(op2, ('C',))
10421024
assert cirq.measurement_key_names(op3) == {'C:B:A'}
10431025
op4 = op3.repeat(2)
1044-
assert cirq.measurement_key_names(op4) == {'C:B:A'}
1045-
op4_rep = op3.repeat(2).replace(use_repetition_ids=True)
1046-
assert cirq.measurement_key_names(op4_rep) == {'C:B:0:A', 'C:B:1:A'}
1026+
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}
10471027

10481028

10491029
def test_mapped_circuit_preserves_moments():
@@ -1121,8 +1101,12 @@ def test_mapped_circuit_allows_repeated_keys():
11211101
def test_simulate_no_repetition_ids_both_levels(sim):
11221102
q = cirq.LineQubit(0)
11231103
inner = cirq.Circuit(cirq.measure(q, key='a'))
1124-
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1125-
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
1104+
middle = cirq.Circuit(
1105+
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
1106+
)
1107+
outer_subcircuit = cirq.CircuitOperation(
1108+
middle.freeze(), repetitions=2, use_repetition_ids=False
1109+
)
11261110
circuit = cirq.Circuit(outer_subcircuit)
11271111
result = sim.run(circuit)
11281112
assert result.records['a'].shape == (1, 4, 1)
@@ -1132,10 +1116,10 @@ def test_simulate_no_repetition_ids_both_levels(sim):
11321116
def test_simulate_no_repetition_ids_outer(sim):
11331117
q = cirq.LineQubit(0)
11341118
inner = cirq.Circuit(cirq.measure(q, key='a'))
1135-
middle = cirq.Circuit(
1136-
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=True)
1119+
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1120+
outer_subcircuit = cirq.CircuitOperation(
1121+
middle.freeze(), repetitions=2, use_repetition_ids=False
11371122
)
1138-
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
11391123
circuit = cirq.Circuit(outer_subcircuit)
11401124
result = sim.run(circuit)
11411125
assert result.records['0:a'].shape == (1, 2, 1)
@@ -1146,10 +1130,10 @@ def test_simulate_no_repetition_ids_outer(sim):
11461130
def test_simulate_no_repetition_ids_inner(sim):
11471131
q = cirq.LineQubit(0)
11481132
inner = cirq.Circuit(cirq.measure(q, key='a'))
1149-
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1150-
outer_subcircuit = cirq.CircuitOperation(
1151-
middle.freeze(), repetitions=2, use_repetition_ids=True
1133+
middle = cirq.Circuit(
1134+
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
11521135
)
1136+
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
11531137
circuit = cirq.Circuit(outer_subcircuit)
11541138
result = sim.run(circuit)
11551139
assert result.records['0:a'].shape == (1, 2, 1)
@@ -1164,6 +1148,7 @@ def test_repeat_until(sim):
11641148
cirq.X(q),
11651149
cirq.CircuitOperation(
11661150
cirq.FrozenCircuit(cirq.X(q), cirq.measure(q, key=key)),
1151+
use_repetition_ids=False,
11671152
repeat_until=cirq.KeyCondition(key),
11681153
),
11691154
)
@@ -1178,6 +1163,7 @@ def test_repeat_until_sympy(sim):
11781163
q1, q2 = cirq.LineQubit.range(2)
11791164
circuitop = cirq.CircuitOperation(
11801165
cirq.FrozenCircuit(cirq.X(q2), cirq.measure(q2, key='b')),
1166+
use_repetition_ids=False,
11811167
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
11821168
)
11831169
c = cirq.Circuit(cirq.measure(q1, key='a'), circuitop)
@@ -1197,6 +1183,7 @@ def test_post_selection(sim):
11971183
c = cirq.Circuit(
11981184
cirq.CircuitOperation(
11991185
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
1186+
use_repetition_ids=False,
12001187
repeat_until=cirq.KeyCondition(key),
12011188
)
12021189
)
@@ -1212,13 +1199,14 @@ def test_repeat_until_diagram():
12121199
c = cirq.Circuit(
12131200
cirq.CircuitOperation(
12141201
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
1202+
use_repetition_ids=False,
12151203
repeat_until=cirq.KeyCondition(key),
12161204
)
12171205
)
12181206
cirq.testing.assert_has_diagram(
12191207
c,
12201208
"""
1221-
0: ───[ 0: ───X^0.2───M('m')─── ](until=m)───
1209+
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
12221210
""",
12231211
use_unicode_characters=True,
12241212
)
@@ -1235,6 +1223,7 @@ def test_repeat_until_error():
12351223
with pytest.raises(ValueError, match='Infinite loop'):
12361224
cirq.CircuitOperation(
12371225
cirq.FrozenCircuit(cirq.measure(q, key='m')),
1226+
use_repetition_ids=False,
12381227
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
12391228
)
12401229

@@ -1244,6 +1233,8 @@ def test_repeat_until_protocols():
12441233
op = cirq.CircuitOperation(
12451234
cirq.FrozenCircuit(cirq.H(q) ** sympy.Symbol('p'), cirq.measure(q, key='a')),
12461235
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), 0)),
1236+
# TODO: #7232 - remove immediately after the 1.5.0 release
1237+
use_repetition_ids=False,
12471238
)
12481239
scoped = cirq.with_rescoped_keys(op, ('0',))
12491240
# Ensure the _repeat_until has been mapped, the measurement has been mapped to the same key,
@@ -1276,6 +1267,8 @@ def test_inner_repeat_until_simulate():
12761267
inner_loop = cirq.CircuitOperation(
12771268
cirq.FrozenCircuit(cirq.H(q), cirq.measure(q, key="inner_loop")),
12781269
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol("inner_loop"), 0)),
1270+
# TODO: #7232 - remove immediately after the 1.5.0 release
1271+
use_repetition_ids=False,
12791272
)
12801273
outer_loop = cirq.Circuit(inner_loop, cirq.X(q), cirq.measure(q, key="outer_loop"))
12811274
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)