Skip to content

Commit 8f7ced7

Browse files
Make with_measurement_key_mapping live up to its name (#5610)
Fixes #5552. Since this change only* affects type annotations, it is non-breaking. Any breakages due to `ParamResolver.param_dict` becoming immutable should instead refer to #5548, where this change was applied. *Okay, it also removes some `dict` calls, but those were added in #5548.
1 parent ecd4c81 commit 8f7ced7

13 files changed

+49
-24
lines changed

cirq-core/cirq/circuits/circuit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
AbstractSet,
3030
Any,
3131
Callable,
32+
Mapping,
3233
cast,
3334
Dict,
3435
FrozenSet,
@@ -944,7 +945,7 @@ def all_measurement_key_names(self) -> FrozenSet[str]:
944945
def _measurement_key_names_(self) -> FrozenSet[str]:
945946
return self.all_measurement_key_names()
946947

947-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
948+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
948949
return self._with_sliced_moments(
949950
[protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments]
950951
)

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def _measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
322322
key.with_key_path_prefix(*self.parent_path) for key in circuit_keys
323323
)
324324
return frozenset(
325-
protocols.with_measurement_key_mapping(key, dict(self.measurement_key_map))
325+
protocols.with_measurement_key_mapping(key, self.measurement_key_map)
326326
for key in circuit_keys
327327
)
328328

@@ -368,9 +368,7 @@ def _mapped_any_loop(self) -> 'cirq.Circuit':
368368
if isinstance(self.repetitions, INT_CLASSES) and self.repetitions < 0:
369369
circuit = circuit**-1
370370
if self.measurement_key_map:
371-
circuit = protocols.with_measurement_key_mapping(
372-
circuit, dict(self.measurement_key_map)
373-
)
371+
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
374372
if self.param_resolver:
375373
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
376374
return circuit.unfreeze(copy=False)

cirq-core/cirq/circuits/moment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
FrozenSet,
2323
Iterable,
2424
Iterator,
25+
Mapping,
2526
overload,
2627
Optional,
2728
Sequence,
@@ -229,7 +230,7 @@ def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Mom
229230
if qubits.isdisjoint(frozenset(operation.qubits))
230231
)
231232

232-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
233+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
233234
return Moment(
234235
protocols.with_measurement_key_mapping(op, key_map)
235236
if protocols.measurement_keys_touched(op)

cirq-core/cirq/ops/classically_controlled_operation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import (
1515
AbstractSet,
1616
Any,
17+
Mapping,
1718
cast,
1819
Dict,
1920
FrozenSet,
@@ -178,7 +179,7 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> bool:
178179
return True
179180

180181
def _with_measurement_key_mapping_(
181-
self, key_map: Dict[str, str]
182+
self, key_map: Mapping[str, str]
182183
) -> 'ClassicallyControlledOperation':
183184
conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions]
184185
sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map)

cirq-core/cirq/ops/gate_operation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import (
1919
AbstractSet,
2020
Any,
21+
Mapping,
2122
cast,
2223
Collection,
2324
Dict,
@@ -81,7 +82,7 @@ def with_gate(self, new_gate: 'cirq.Gate') -> 'cirq.Operation':
8182
return self
8283
return new_gate.on(*self.qubits)
8384

84-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
85+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
8586
new_gate = protocols.with_measurement_key_mapping(self.gate, key_map)
8687
if new_gate is NotImplemented:
8788
return NotImplemented

cirq-core/cirq/ops/kraus_channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=wrong-or-nonexistent-copyright-notice
2-
from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union
2+
from typing import Any, Dict, FrozenSet, Iterable, Mapping, Tuple, TYPE_CHECKING, Union
33
import numpy as np
44

55
from cirq import linalg, protocols, value
@@ -84,7 +84,7 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
8484
return NotImplemented
8585
return self._key
8686

87-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
87+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
8888
if self._key is None:
8989
return NotImplemented
9090
if self._key not in key_map:

cirq-core/cirq/ops/measurement_gate.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, FrozenSet, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING, Union
15+
from typing import (
16+
Any,
17+
Dict,
18+
FrozenSet,
19+
Iterable,
20+
Mapping,
21+
Optional,
22+
Tuple,
23+
Sequence,
24+
TYPE_CHECKING,
25+
Union,
26+
)
1627

1728
import numpy as np
1829

@@ -129,7 +140,7 @@ def _with_rescoped_keys_(
129140
):
130141
return self.with_key(protocols.with_rescoped_keys(self.mkey, path, bindable_keys))
131142

132-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
143+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
133144
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))
134145

135146
def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':

cirq-core/cirq/ops/mixed_unitary_channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=wrong-or-nonexistent-copyright-notice
2-
from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union
2+
from typing import Any, Dict, FrozenSet, Iterable, Mapping, Tuple, TYPE_CHECKING, Union
33
import numpy as np
44

55
from cirq import linalg, protocols, value
@@ -87,7 +87,7 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
8787
return NotImplemented
8888
return self._key
8989

90-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
90+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
9191
if self._key is None:
9292
return NotImplemented
9393
if self._key not in key_map:

cirq-core/cirq/ops/pauli_measurement_gate.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union, cast
15+
from typing import (
16+
Any,
17+
Dict,
18+
FrozenSet,
19+
Iterable,
20+
Mapping,
21+
Tuple,
22+
Sequence,
23+
TYPE_CHECKING,
24+
Union,
25+
cast,
26+
)
1627

1728
from cirq import protocols, value
1829
from cirq.ops import (
@@ -103,7 +114,7 @@ def _with_rescoped_keys_(
103114
) -> 'PauliMeasurementGate':
104115
return self.with_key(protocols.with_rescoped_keys(self.mkey, path, bindable_keys))
105116

106-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate':
117+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]) -> 'PauliMeasurementGate':
107118
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))
108119

109120
def with_observable(

cirq-core/cirq/ops/raw_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Hashable,
2727
Iterable,
2828
List,
29+
Mapping,
2930
Optional,
3031
Sequence,
3132
Tuple,
@@ -733,7 +734,7 @@ def gate(self) -> Optional['cirq.Gate']:
733734
def with_qubits(self, *new_qubits: 'cirq.Qid'):
734735
return TaggedOperation(self.sub_operation.with_qubits(*new_qubits), *self._tags)
735736

736-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
737+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
737738
sub_op = protocols.with_measurement_key_mapping(self.sub_operation, key_map)
738739
if sub_op is NotImplemented:
739740
return NotImplemented

0 commit comments

Comments
 (0)