@@ -104,6 +104,134 @@ def map_moments(
104104 )
105105
106106
107+ def _map_operations_impl (
108+ circuit : CIRCUIT_TYPE ,
109+ map_func : Callable [[ops .Operation , int ], ops .OP_TREE ],
110+ * ,
111+ deep : bool = False ,
112+ raise_if_add_qubits = True ,
113+ tags_to_ignore : Sequence [Hashable ] = (),
114+ wrap_in_circuit_op : bool = True ,
115+ ) -> CIRCUIT_TYPE :
116+ """Applies local transformations, by calling `map_func(op, moment_index)` for each operation.
117+
118+ This method provides a fast, iterative implementation for the two `map_operations_*` variants
119+ exposed as public transformer primitives. The high level idea for the iterative implementation
120+ is to
121+ 1) For each operation `op`, find the corresponding mapped operation(s) `mapped_ops`. The
122+ set of mapped operations can be either wrapped in a circuit operation or not, depending
123+ on the value of flag `wrap_in_circuit_op` and whether the mapped operations will end up
124+ occupying more than one moment or not.
125+ 2) Use the `get_earliest_accommodating_moment_index` infrastructure built for `cirq.Circuit`
126+ construction to determine the index at which the mapped operations should be inserted.
127+ This step takes care of the nuances that arise due to (a) preserving moment structure
128+ and (b) mapped operations spanning across multiple moments (these both are trivial when
129+ `op` is mapped to a single `mapped_op` that acts on the same set of qubits).
130+
131+ By default, the function assumes `issubset(qubit_set(map_func(op, moment_index)), op.qubits)` is
132+ True.
133+
134+ Args:
135+ circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
136+ map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the
137+ resulting optree spans more than 1 moment, it's either wrapped in a tagged circuit
138+ operation and inserted in-place in the same moment (if `wrap_in_circuit_op` is True)
139+ OR the mapped operations are inserted directly in the circuit, preserving moment
140+ strucutre. The effect is equivalent to (but much faster) a two-step approach of first
141+ wrapping the operations in a circuit operation and then calling `cirq.unroll_circuit_op`
142+ to unroll the corresponding circuit ops.
143+ deep: If true, `map_func` will be recursively applied to circuits wrapped inside
144+ any circuit operations contained within `circuit`.
145+ raise_if_add_qubits: Set to True by default. If True, raises ValueError if
146+ `map_func(op, idx)` adds operations on qubits outside of `op.qubits`.
147+ tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on
148+ tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that
149+ satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
150+ wrap_in_circuit_op: If True, the mapped operations will be wrapped in a tagged circuit
151+ operation and inserted in-place if they occupy more than one moment.
152+
153+ Raises:
154+ ValueError if `issubset(qubit_set(map_func(op, idx)), op.qubits) is False` and
155+ `raise_if_add_qubits is True`.
156+
157+ Returns:
158+ Copy of input circuit with mapped operations.
159+ """
160+ tags_to_ignore_set = set (tags_to_ignore )
161+
162+ def apply_map_func (op : 'cirq.Operation' , idx : int ) -> List ['cirq.Operation' ]:
163+ if tags_to_ignore_set .intersection (op .tags ):
164+ return [op ]
165+ if deep and isinstance (op .untagged , circuits .CircuitOperation ):
166+ op = op .untagged .replace (
167+ circuit = _map_operations_impl (
168+ op .untagged .circuit ,
169+ map_func ,
170+ deep = deep ,
171+ raise_if_add_qubits = raise_if_add_qubits ,
172+ tags_to_ignore = tags_to_ignore ,
173+ wrap_in_circuit_op = wrap_in_circuit_op ,
174+ )
175+ ).with_tags (* op .tags )
176+ mapped_ops = [* ops .flatten_to_ops (map_func (op , idx ))]
177+ op_qubits = set (op .qubits )
178+ mapped_ops_qubits : Set ['cirq.Qid' ] = set ()
179+ has_overlapping_ops = False
180+ for mapped_op in mapped_ops :
181+ if raise_if_add_qubits and not op_qubits .issuperset (mapped_op .qubits ):
182+ raise ValueError (
183+ f"Mapped operations { mapped_ops } should act on a subset "
184+ f"of qubits of the original operation { op } "
185+ )
186+ if mapped_ops_qubits .intersection (mapped_op .qubits ):
187+ has_overlapping_ops = True
188+ mapped_ops_qubits = mapped_ops_qubits .union (mapped_op .qubits )
189+ if wrap_in_circuit_op and has_overlapping_ops :
190+ # Mapped operations should be wrapped in a `CircuitOperation` only iff they occupy more
191+ # than one moment, i.e. there are at least two operations that share a qubit.
192+ mapped_ops = [
193+ circuits .CircuitOperation (circuits .FrozenCircuit (mapped_ops )).with_tags (
194+ MAPPED_CIRCUIT_OP_TAG
195+ )
196+ ]
197+ return mapped_ops
198+
199+ new_moments : List [List ['cirq.Operation' ]] = []
200+
201+ # Keep track of the latest time index for each qubit, measurement key, and control key.
202+ qubit_time_index : Dict ['cirq.Qid' , int ] = {}
203+ measurement_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
204+ control_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
205+
206+ # New mapped operations in the current moment should be inserted after `last_moment_time_index`.
207+ last_moment_time_index = - 1
208+
209+ for idx , moment in enumerate (circuit ):
210+ if wrap_in_circuit_op :
211+ new_moments .append ([])
212+ for op in moment :
213+ mapped_ops = apply_map_func (op , idx )
214+
215+ for mapped_op in mapped_ops :
216+ # Identify the earliest moment that can accommodate this op.
217+ placement_index = circuits .circuit .get_earliest_accommodating_moment_index (
218+ mapped_op , qubit_time_index , measurement_time_index , control_time_index
219+ )
220+ placement_index = max (placement_index , last_moment_time_index + 1 )
221+ new_moments .extend ([[] for _ in range (placement_index - len (new_moments ) + 1 )])
222+ new_moments [placement_index ].append (mapped_op )
223+ for qubit in mapped_op .qubits :
224+ qubit_time_index [qubit ] = placement_index
225+ for key in protocols .measurement_key_objs (mapped_op ):
226+ measurement_time_index [key ] = placement_index
227+ for key in protocols .control_keys (mapped_op ):
228+ control_time_index [key ] = placement_index
229+
230+ last_moment_time_index = len (new_moments ) - 1
231+
232+ return _create_target_circuit_type ([circuits .Moment (moment ) for moment in new_moments ], circuit )
233+
234+
107235def map_operations (
108236 circuit : CIRCUIT_TYPE ,
109237 map_func : Callable [[ops .Operation , int ], ops .OP_TREE ],
@@ -139,29 +267,13 @@ def map_operations(
139267 Returns:
140268 Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation).
141269 """
142-
143- def apply_map (op : ops .Operation , idx : int ) -> ops .OP_TREE :
144- if not set (op .tags ).isdisjoint (tags_to_ignore ):
145- return op
146- c = circuits .FrozenCircuit (map_func (op , idx ))
147- if raise_if_add_qubits and not c .all_qubits ().issubset (op .qubits ):
148- raise ValueError (
149- f"Mapped operations { c .all_operations ()} should act on a subset "
150- f"of qubits of the original operation { op } "
151- )
152- if len (c ) <= 1 :
153- # Either empty circuit or all operations act in the same moment;
154- # So, we don't need to wrap them in a circuit_op.
155- return c [0 ].operations if c else []
156- circuit_op = circuits .CircuitOperation (c ).with_tags (MAPPED_CIRCUIT_OP_TAG )
157- return circuit_op
158-
159- return map_moments (
270+ return _map_operations_impl (
160271 circuit ,
161- lambda m , i : circuits .Circuit (apply_map (op , i ) for op in m .operations ).moments
162- or [circuits .Moment ()],
272+ map_func ,
163273 deep = deep ,
274+ raise_if_add_qubits = raise_if_add_qubits ,
164275 tags_to_ignore = tags_to_ignore ,
276+ wrap_in_circuit_op = True ,
165277 )
166278
167279
@@ -191,15 +303,13 @@ def map_operations_and_unroll(
191303 Returns:
192304 Copy of input circuit with mapped operations, unrolled in a moment preserving way.
193305 """
194- return unroll_circuit_op (
195- map_operations (
196- circuit ,
197- map_func ,
198- deep = deep ,
199- raise_if_add_qubits = raise_if_add_qubits ,
200- tags_to_ignore = tags_to_ignore ,
201- ),
306+ return _map_operations_impl (
307+ circuit ,
308+ map_func ,
202309 deep = deep ,
310+ raise_if_add_qubits = raise_if_add_qubits ,
311+ tags_to_ignore = tags_to_ignore ,
312+ wrap_in_circuit_op = False ,
203313 )
204314
205315
0 commit comments