Skip to content

Commit 346c95b

Browse files
andbe91pavoljuhas
andauthored
Promote FrozenCircuit.tags to AbstractCircuit (#7476)
First step towards fixing #7454. This PR just adds the `tags` attribute to `AbstractCircuit`, along with some of the logic (e.g. `__eq__`, `_json_dict_`, etc). After this we have to add support in the proto serialization itself. Partially implements #7454 --------- Co-authored-by: Pavol Juhas <[email protected]>
1 parent e42eac2 commit 346c95b

File tree

4 files changed

+183
-68
lines changed

4 files changed

+183
-68
lines changed

cirq-core/cirq/circuits/circuit.py

Lines changed: 103 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Any,
3434
Callable,
3535
cast,
36+
Hashable,
3637
Iterable,
3738
Iterator,
3839
Mapping,
@@ -141,7 +142,9 @@ class AbstractCircuit(abc.ABC):
141142
"""
142143

143144
@classmethod
144-
def from_moments(cls: type[CIRCUIT_TYPE], *moments: cirq.OP_TREE | None) -> CIRCUIT_TYPE:
145+
def from_moments(
146+
cls: type[CIRCUIT_TYPE], *moments: cirq.OP_TREE | None, tags: Sequence[Hashable] = ()
147+
) -> CIRCUIT_TYPE:
145148
"""Create a circuit from moment op trees.
146149
147150
Args:
@@ -155,8 +158,12 @@ def from_moments(cls: type[CIRCUIT_TYPE], *moments: cirq.OP_TREE | None) -> CIRC
155158
which is then included in the new circuit. Note that in this
156159
case we have the normal restriction that operations in a
157160
moment must be applied to disjoint sets of qubits.
161+
tags: A sequence of any type of object that is useful to attach metadata
162+
to this circuit as long as the type is hashable. If you wish the
163+
resulting circuit to be eventually serialized into JSON, you should
164+
also restrict the tags to be JSON serializable.
158165
"""
159-
return cls._from_moments(cls._make_moments(moments))
166+
return cls._from_moments(cls._make_moments(moments), tags=tags)
160167

161168
@staticmethod
162169
def _make_moments(moments: Iterable[cirq.OP_TREE | None]) -> Iterator[cirq.Moment]:
@@ -170,7 +177,9 @@ def _make_moments(moments: Iterable[cirq.OP_TREE | None]) -> Iterator[cirq.Momen
170177

171178
@classmethod
172179
@abc.abstractmethod
173-
def _from_moments(cls: type[CIRCUIT_TYPE], moments: Iterable[cirq.Moment]) -> CIRCUIT_TYPE:
180+
def _from_moments(
181+
cls: type[CIRCUIT_TYPE], moments: Iterable[cirq.Moment], tags: Sequence[Hashable]
182+
) -> CIRCUIT_TYPE:
174183
"""Create a circuit from moments.
175184
176185
This must be implemented by subclasses. It provides a more efficient way
@@ -201,6 +210,20 @@ def unfreeze(self, copy: bool = True) -> cirq.Circuit:
201210
copy: If True and 'self' is a Circuit, returns a copy that circuit.
202211
"""
203212

213+
@property
214+
@abc.abstractmethod
215+
def tags(self) -> tuple[Hashable, ...]:
216+
"""Returns a tuple of the Circuit's tags."""
217+
218+
@abc.abstractmethod
219+
def with_tags(self, *new_tags: Hashable) -> Self:
220+
"""Creates a new tagged Circuit with `self.tags` and `new_tags` combined."""
221+
222+
@property
223+
def untagged(self) -> Self:
224+
"""Returns the underlying Circuit without any tags."""
225+
return self._from_moments(self.moments, tags=()) if self.tags else self
226+
204227
def __bool__(self) -> bool:
205228
return bool(self.moments)
206229

@@ -210,14 +233,16 @@ def __eq__(self, other) -> bool:
210233
return other is self or (
211234
len(self.moments) == len(other.moments)
212235
and all(m0 == m1 for m0, m1 in zip(self.moments, other.moments))
236+
and self.tags == other.tags
213237
)
214238

215239
def _approx_eq_(self, other: Any, atol: float) -> bool:
216240
"""See `cirq.protocols.SupportsApproximateEquality`."""
217241
if not isinstance(other, AbstractCircuit):
218242
return NotImplemented
219-
return other is self or cirq.protocols.approx_eq(
220-
tuple(self.moments), tuple(other.moments), atol=atol
243+
return other is self or (
244+
self.tags == other.tags
245+
and cirq.protocols.approx_eq(tuple(self.moments), tuple(other.moments), atol=atol)
221246
)
222247

223248
def __ne__(self, other) -> bool:
@@ -259,7 +284,7 @@ def __getitem__(self, key: tuple[slice, Iterable[cirq.Qid]]) -> Self:
259284

260285
def __getitem__(self, key):
261286
if isinstance(key, slice):
262-
return self._from_moments(self.moments[key])
287+
return self._from_moments(self.moments[key], tags=self.tags)
263288
if hasattr(key, '__index__'):
264289
return self.moments[key]
265290
if isinstance(key, tuple):
@@ -272,7 +297,9 @@ def __getitem__(self, key):
272297
return selected_moments[qubit_idx]
273298
if isinstance(qubit_idx, ops.Qid):
274299
qubit_idx = [qubit_idx]
275-
return self._from_moments(moment[qubit_idx] for moment in selected_moments)
300+
return self._from_moments(
301+
(moment[qubit_idx] for moment in selected_moments), tags=self.tags
302+
)
276303

277304
raise TypeError('__getitem__ called with key not of type slice, int, or tuple.')
278305

@@ -283,7 +310,9 @@ def _repr_args(self) -> str:
283310
args = []
284311
if self.moments:
285312
args.append(_list_repr_with_indented_item_lines(self.moments))
286-
return f'{", ".join(args)}'
313+
moments_repr = f'{", ".join(args)}'
314+
tag_repr = ','.join(_compat.proper_repr(t) for t in self.tags)
315+
return f'{moments_repr}, tags=[{tag_repr}]' if self.tags else moments_repr
287316

288317
def __repr__(self) -> str:
289318
cls_name = self.__class__.__name__
@@ -942,7 +971,9 @@ def map_moment(moment: cirq.Moment) -> cirq.Circuit:
942971
"""Apply func to expand each op into a circuit, then zip up the circuits."""
943972
return Circuit.zip(*[Circuit(func(op)) for op in moment])
944973

945-
return self._from_moments(m for moment in self for m in map_moment(moment))
974+
return self._from_moments(
975+
(m for moment in self for m in map_moment(moment)), tags=self.tags
976+
)
946977

947978
def qid_shape(
948979
self, qubit_order: cirq.QubitOrderOrList = ops.QubitOrder.DEFAULT
@@ -983,15 +1014,19 @@ def _measurement_key_names_(self) -> frozenset[str]:
9831014

9841015
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
9851016
return self._from_moments(
986-
protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments
1017+
(protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments),
1018+
tags=self.tags,
9871019
)
9881020

9891021
def _with_key_path_(self, path: tuple[str, ...]):
990-
return self._from_moments(protocols.with_key_path(moment, path) for moment in self.moments)
1022+
return self._from_moments(
1023+
(protocols.with_key_path(moment, path) for moment in self.moments), tags=self.tags
1024+
)
9911025

9921026
def _with_key_path_prefix_(self, prefix: tuple[str, ...]):
9931027
return self._from_moments(
994-
protocols.with_key_path_prefix(moment, prefix) for moment in self.moments
1028+
(protocols.with_key_path_prefix(moment, prefix) for moment in self.moments),
1029+
tags=self.tags,
9951030
)
9961031

9971032
def _with_rescoped_keys_(
@@ -1002,7 +1037,7 @@ def _with_rescoped_keys_(
10021037
new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys)
10031038
moments.append(new_moment)
10041039
bindable_keys |= protocols.measurement_key_objs(new_moment)
1005-
return self._from_moments(moments)
1040+
return self._from_moments(moments, tags=self.tags)
10061041

10071042
def _qid_shape_(self) -> tuple[int, ...]:
10081043
return self.qid_shape()
@@ -1300,22 +1335,33 @@ def default_namer(label_entity):
13001335
return diagram
13011336

13021337
def _is_parameterized_(self) -> bool:
1303-
return any(protocols.is_parameterized(op) for op in self.all_operations())
1338+
return any(protocols.is_parameterized(op) for op in self.all_operations()) or any(
1339+
protocols.is_parameterized(tag) for tag in self.tags
1340+
)
13041341

13051342
def _parameter_names_(self) -> AbstractSet[str]:
1306-
return {name for op in self.all_operations() for name in protocols.parameter_names(op)}
1343+
op_params = {name for op in self.all_operations() for name in protocols.parameter_names(op)}
1344+
tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)}
1345+
return op_params | tag_params
13071346

13081347
def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> Self:
13091348
changed = False
13101349
resolved_moments: list[cirq.Moment] = []
1350+
resolved_tags: list[Hashable] = []
13111351
for moment in self:
13121352
resolved_moment = protocols.resolve_parameters(moment, resolver, recursive)
13131353
if resolved_moment is not moment:
13141354
changed = True
13151355
resolved_moments.append(resolved_moment)
1316-
if not changed:
1356+
for tag in self.tags:
1357+
resolved_tag = protocols.resolve_parameters(tag, resolver, recursive)
1358+
if resolved_tag is not tag:
1359+
changed = True
1360+
resolved_tags.append(resolved_tag)
1361+
if changed:
1362+
return self._from_moments(resolved_moments, tags=resolved_tags)
1363+
else:
13171364
return self # pragma: no cover
1318-
return self._from_moments(resolved_moments)
13191365

13201366
def _qasm_(self, args: cirq.QasmArgs | None = None) -> str:
13211367
if args is None:
@@ -1394,11 +1440,13 @@ def save_qasm(
13941440
self._to_qasm_output(header, precision, qubit_order).save(file_path)
13951441

13961442
def _json_dict_(self):
1397-
return protocols.obj_to_dict_helper(self, ['moments'])
1443+
attribute_names = ['moments', 'tags'] if self.tags else ['moments']
1444+
ret = protocols.obj_to_dict_helper(self, attribute_names)
1445+
return ret
13981446

13991447
@classmethod
1400-
def _from_json_dict_(cls, moments, **kwargs):
1401-
return cls(moments, strategy=InsertStrategy.EARLIEST)
1448+
def _from_json_dict_(cls, moments, tags=(), **kwargs):
1449+
return cls(moments, tags=tags, strategy=InsertStrategy.EARLIEST)
14021450

14031451
def zip(
14041452
*circuits: cirq.AbstractCircuit, align: cirq.Alignment | str = Alignment.LEFT
@@ -1462,7 +1510,7 @@ def zip(
14621510
if isinstance(align, str):
14631511
align = Alignment[align.upper()]
14641512

1465-
result = cirq.Circuit()
1513+
result = cirq.Circuit(tags=circuits[0].tags if circuits else ())
14661514
for k in range(n):
14671515
try:
14681516
if align == Alignment.LEFT:
@@ -1531,7 +1579,7 @@ def concat_ragged(
15311579
for k in range(1, len(circuits)):
15321580
offset, n_acc = _concat_ragged_helper(offset, n_acc, buffer, circuits[k].moments, align)
15331581

1534-
return cirq.Circuit(buffer[offset : offset + n_acc])
1582+
return cirq.Circuit(buffer[offset : offset + n_acc], tags=circuits[0].tags)
15351583

15361584
def get_independent_qubit_sets(self) -> list[set[cirq.Qid]]:
15371585
"""Divide circuit's qubits into independent qubit sets.
@@ -1610,7 +1658,10 @@ def factorize(self) -> Iterable[Self]:
16101658
# the qubits from one factor belong to a specific independent qubit set.
16111659
# This makes it possible to create independent circuits based on these
16121660
# moments.
1613-
return (self._from_moments(m[qubits] for m in self.moments) for qubits in qubit_factors)
1661+
return (
1662+
self._from_moments([m[qubits] for m in self.moments], tags=self.tags)
1663+
for qubits in qubit_factors
1664+
)
16141665

16151666
def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
16161667
controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op))
@@ -1753,7 +1804,10 @@ class Circuit(AbstractCircuit):
17531804
"""
17541805

17551806
def __init__(
1756-
self, *contents: cirq.OP_TREE, strategy: cirq.InsertStrategy = InsertStrategy.EARLIEST
1807+
self,
1808+
*contents: cirq.OP_TREE,
1809+
strategy: cirq.InsertStrategy = InsertStrategy.EARLIEST,
1810+
tags: Sequence[Hashable] = (),
17571811
) -> None:
17581812
"""Initializes a circuit.
17591813
@@ -1767,9 +1821,14 @@ def __init__(
17671821
from `contents`, this determines how the operations are packed
17681822
together. This option does not affect later insertions into the
17691823
circuit.
1824+
tags: A sequence of any type of object that is useful to attach metadata
1825+
to this circuit as long as the type is hashable. If you wish the
1826+
resulting circuit to be eventually serialized into JSON, you should
1827+
also restrict the tags to be JSON serializable.
17701828
"""
17711829
self._placement_cache: _PlacementCache | None = _PlacementCache()
17721830
self._moments: list[cirq.Moment] = []
1831+
self._tags = tuple(tags)
17731832

17741833
# Implementation note: the following cached properties are set lazily and then
17751834
# invalidated and reset to None in `self._mutated()`, which is called any time
@@ -1803,10 +1862,11 @@ def _mutated(self, *, preserve_placement_cache=False) -> None:
18031862
self._placement_cache = None
18041863

18051864
@classmethod
1806-
def _from_moments(cls, moments: Iterable[cirq.Moment]) -> Circuit:
1865+
def _from_moments(cls, moments: Iterable[cirq.Moment], tags: Sequence[Hashable]) -> Circuit:
18071866
new_circuit = Circuit()
18081867
new_circuit._moments[:] = moments
18091868
new_circuit._placement_cache = None
1869+
new_circuit._tags = tuple(tags)
18101870
return new_circuit
18111871

18121872
def _load_contents_with_earliest_strategy(self, contents: cirq.OP_TREE):
@@ -1865,7 +1925,7 @@ def freeze(self) -> cirq.FrozenCircuit:
18651925
from cirq.circuits.frozen_circuit import FrozenCircuit
18661926

18671927
if self._frozen is None:
1868-
self._frozen = FrozenCircuit._from_moments(self._moments)
1928+
self._frozen = FrozenCircuit._from_moments(self._moments, tags=self.tags)
18691929
return self._frozen
18701930

18711931
def unfreeze(self, copy: bool = True) -> cirq.Circuit:
@@ -1894,8 +1954,9 @@ def _parameter_names_(self) -> AbstractSet[str]:
18941954
def copy(self) -> Circuit:
18951955
"""Return a copy of this circuit."""
18961956
copied_circuit = Circuit()
1897-
copied_circuit._moments = self._moments[:]
1957+
copied_circuit._moments[:] = self._moments
18981958
copied_circuit._placement_cache = None
1959+
copied_circuit._tags = self.tags
18991960
return copied_circuit
19001961

19011962
@overload
@@ -1955,7 +2016,7 @@ def __imul__(self, repetitions: _INT_TYPE):
19552016
def __mul__(self, repetitions: _INT_TYPE):
19562017
if not isinstance(repetitions, (int, np.integer)):
19572018
return NotImplemented
1958-
return Circuit(self._moments * int(repetitions))
2019+
return Circuit(self._moments * int(repetitions), tags=self.tags)
19592020

19602021
def __rmul__(self, repetitions: _INT_TYPE):
19612022
if not isinstance(repetitions, (int, np.integer)):
@@ -1981,7 +2042,7 @@ def __pow__(self, exponent: int) -> cirq.Circuit:
19812042
return NotImplemented
19822043
inv_moments.append(inv_moment)
19832044

1984-
return cirq.Circuit(inv_moments)
2045+
return cirq.Circuit(inv_moments, tags=self.tags)
19852046

19862047
__hash__ = None # type: ignore
19872048

@@ -2466,6 +2527,18 @@ def clear_operations_touching(
24662527
def moments(self) -> Sequence[cirq.Moment]:
24672528
return self._moments
24682529

2530+
@property
2531+
def tags(self) -> tuple[Hashable, ...]:
2532+
return self._tags
2533+
2534+
def with_tags(self, *new_tags: Hashable) -> cirq.Circuit:
2535+
"""Creates a new tagged `Circuit` with `self.tags` and `new_tags` combined."""
2536+
if not new_tags:
2537+
return self
2538+
new_circuit = Circuit(tags=self.tags + new_tags)
2539+
new_circuit._moments[:] = self._moments
2540+
return new_circuit
2541+
24692542
def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit:
24702543
"""Make a noisy version of the circuit.
24712544
@@ -2480,7 +2553,7 @@ def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit:
24802553
"""
24812554
noise_model = devices.NoiseModel.from_noise_model_like(noise)
24822555
qubits = sorted(self.all_qubits())
2483-
c_noisy = Circuit()
2556+
c_noisy = Circuit(tags=self.tags)
24842557
for op_tree in noise_model.noisy_moments(self, qubits):
24852558
# Keep moments aligned
24862559
c_noisy += Circuit(op_tree)

0 commit comments

Comments
 (0)