1616
1717from __future__ import annotations
1818
19- from typing import TYPE_CHECKING
19+ from typing import Callable , cast , Hashable , TYPE_CHECKING
2020
2121from cirq import circuits , ops , protocols
22- from cirq .transformers import merge_k_qubit_gates , transformer_api , transformer_primitives
22+ from cirq .study .resolver import ParamResolver
23+ from cirq .study .sweeps import dict_to_zip_sweep , ListSweep , ProductOrZipSweepLike , Sweep , Zip
24+ from cirq .transformers import (
25+ align ,
26+ merge_k_qubit_gates ,
27+ symbolize ,
28+ tag_transformers ,
29+ transformer_api ,
30+ transformer_primitives ,
31+ )
2332from cirq .transformers .analytical_decompositions import single_qubit_decompositions
2433
2534if TYPE_CHECKING :
35+ import sympy
36+
2637 import cirq
2738
2839
@@ -67,6 +78,7 @@ def merge_single_qubit_gates_to_phxz(
6778 circuit : cirq .AbstractCircuit ,
6879 * ,
6980 context : cirq .TransformerContext | None = None ,
81+ merge_tags_fn : Callable [[cirq .CircuitOperation ], list [Hashable ]] | None = None ,
7082 atol : float = 1e-8 ,
7183) -> cirq .Circuit :
7284 """Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
@@ -77,19 +89,21 @@ def merge_single_qubit_gates_to_phxz(
7789 Args:
7890 circuit: Input circuit to transform. It will not be modified.
7991 context: `cirq.TransformerContext` storing common configurable options for transformers.
92+ merge_tags_fn: A callable returns the tags to be added to the merged operation.
8093 atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
8194 dropped, smaller values increase accuracy.
8295
8396 Returns:
8497 Copy of the transformed input circuit.
8598 """
8699
87- def rewriter (op : cirq .CircuitOperation ) -> cirq .OP_TREE :
88- u = protocols .unitary (op )
89- if protocols .num_qubits (op ) == 0 :
100+ def rewriter (circuit_op : cirq .CircuitOperation ) -> cirq .OP_TREE :
101+ u = protocols .unitary (circuit_op )
102+ if protocols .num_qubits (circuit_op ) == 0 :
90103 return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
91- gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol )
92- return gate (op .qubits [0 ]) if gate else []
104+ gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol ) or ops .I
105+ phxz_op = gate .on (circuit_op .qubits [0 ])
106+ return phxz_op .with_tags (* merge_tags_fn (circuit_op )) if merge_tags_fn else phxz_op
93107
94108 return merge_k_qubit_gates .merge_k_qubit_unitaries (
95109 circuit , k = 1 , context = context , rewriter = rewriter
@@ -158,3 +172,160 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
158172 deep = context .deep if context else False ,
159173 tags_to_ignore = tuple (tags_to_ignore ),
160174 ).unfreeze (copy = False )
175+
176+
177+ def _sweep_on_symbols (sweep : Sweep , symbols : set [sympy .Symbol ]) -> Sweep :
178+ new_resolvers : list [cirq .ParamResolver ] = []
179+ for resolver in sweep :
180+ param_dict : cirq .ParamMappingType = {s : resolver .value_of (s ) for s in symbols }
181+ new_resolvers .append (ParamResolver (param_dict ))
182+ return ListSweep (new_resolvers )
183+
184+
185+ def _calc_phxz_sweeps (
186+ symbolized_circuit : cirq .Circuit , resolved_circuits : list [cirq .Circuit ]
187+ ) -> Sweep :
188+ """Return the phxz sweep of the symbolized_circuit on resolved_circuits.
189+
190+ Raises:
191+ ValueError: Structural mismatch: A `resolved_circuit` contains an unexpected gate type.
192+ Expected a `PhasedXZGate` or `IdentityGate` at a position corresponding to a
193+ symbolic `PhasedXZGate` in the `symbolized_circuit`.
194+ """
195+
196+ def _extract_axz (op : ops .Operation | None ) -> tuple [float , float , float ]:
197+ if not op or not op .gate or not isinstance (op .gate , ops .IdentityGate | ops .PhasedXZGate ):
198+ raise ValueError (f"Expect a PhasedXZGate or IdentityGate on op { op } ." )
199+ if isinstance (op .gate , ops .IdentityGate ):
200+ return 0.0 , 0.0 , 0.0 # Identity gate's a, x, z in PhasedXZ
201+ return op .gate .axis_phase_exponent , op .gate .x_exponent , op .gate .z_exponent
202+
203+ values_by_params : dict [sympy .Symbol , tuple [float , ...]] = {}
204+ for mid , moment in enumerate (symbolized_circuit ):
205+ for op in moment .operations :
206+ if op .gate and isinstance (op .gate , ops .PhasedXZGate ) and protocols .is_parameterized (op ):
207+ sa , sx , sz = op .gate .axis_phase_exponent , op .gate .x_exponent , op .gate .z_exponent
208+ values_by_params [sa ], values_by_params [sx ], values_by_params [sz ] = zip (
209+ * [_extract_axz (c [mid ].operation_at (op .qubits [0 ])) for c in resolved_circuits ]
210+ )
211+
212+ return dict_to_zip_sweep (cast (ProductOrZipSweepLike , values_by_params ))
213+
214+
215+ def merge_single_qubit_gates_to_phxz_symbolized (
216+ circuit : cirq .AbstractCircuit ,
217+ * ,
218+ context : cirq .TransformerContext | None = None ,
219+ sweep : Sweep ,
220+ atol : float = 1e-8 ,
221+ ) -> tuple [cirq .Circuit , Sweep ]:
222+ """Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
223+ the consecutive gates is symbolized.
224+
225+ Example:
226+ >>> q0, q1 = cirq.LineQubit.range(2)
227+ >>> c = cirq.Circuit(\
228+ cirq.X(q0),\
229+ cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
230+ cirq.Y(q0)**sympy.Symbol("y_exp"),\
231+ cirq.X(q0))
232+ >>> print(c)
233+ 0: ───X───@──────────Y^y_exp───X───
234+ │
235+ 1: ───────@^cz_exp─────────────────
236+ >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
237+ c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
238+ cirq.Points(key="y_exp", points=[0, 1])))
239+ >>> print(new_circuit)
240+ 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
241+ │
242+ 1: ────────────────────────@^cz_exp──────────────────────────
243+ >>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
244+ >>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
245+
246+ Args:
247+ circuit: Input circuit to transform. It will not be modified.
248+ context: `cirq.TransformerContext` storing common configurable options for transformers.
249+ sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
250+ based on the transformation.
251+ atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
252+ dropped, smaller values increase accuracy.
253+
254+ Returns:
255+ Copy of the transformed input circuit.
256+ """
257+ deep = context .deep if context else False
258+
259+ # Tag symbolized single-qubit op.
260+ symbolized_single_tag = "_tmp_symbolize_tag"
261+
262+ circuit_tagged = transformer_primitives .map_operations (
263+ circuit ,
264+ lambda op , _ : (
265+ op .with_tags (symbolized_single_tag )
266+ if protocols .is_parameterized (op ) and len (op .qubits ) == 1
267+ else op
268+ ),
269+ deep = deep ,
270+ )
271+
272+ # Step 0, isolate single qubit symbols and resolve the circuit on them.
273+ single_qubit_gate_symbols : set [sympy .Symbol ] = set ().union (
274+ * [
275+ protocols .parameter_symbols (op ) if symbolized_single_tag in op .tags else set ()
276+ for op in circuit_tagged .all_operations ()
277+ ]
278+ )
279+ # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
280+ remaining_symbols : set [sympy .Symbol ] = set (
281+ protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
282+ )
283+ # If all single qubit gates are not parameterized, call the nonparamerized version of
284+ # the transformer.
285+ if not single_qubit_gate_symbols :
286+ return (merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol ), sweep )
287+ sweep_of_single : Sweep = _sweep_on_symbols (sweep , single_qubit_gate_symbols )
288+ # Get all resolved circuits from all sets of resolvers in sweep_of_single.
289+ resolved_circuits = [
290+ protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
291+ ]
292+
293+ # Step 1, merge single qubit gates per resolved circuit, preserving
294+ # the symbolized_single_tag to indicate the operator is a merged one.
295+ merged_circuits : list [cirq .Circuit ] = [
296+ merge_single_qubit_gates_to_phxz (
297+ c ,
298+ context = context ,
299+ merge_tags_fn = lambda circuit_op : (
300+ [symbolized_single_tag ]
301+ if any (
302+ symbolized_single_tag in set (op .tags )
303+ for op in circuit_op .circuit .all_operations ()
304+ )
305+ else []
306+ ),
307+ atol = atol ,
308+ )
309+ for c in resolved_circuits
310+ ]
311+
312+ # Step 2, get the new symbolized circuit by symbolizing on indexed symbolized_single_tag.
313+ new_circuit = tag_transformers .remove_tags ( # remove the temp tags used to track merges
314+ symbolize .symbolize_single_qubit_gates_by_indexed_tags (
315+ tag_transformers .index_tags ( # index all 1-qubit-ops merged from ops with symbols
316+ merged_circuits [0 ],
317+ context = transformer_api .TransformerContext (deep = deep ),
318+ target_tags = {symbolized_single_tag },
319+ ),
320+ symbolize_tag = symbolize .SymbolizeTag (prefix = symbolized_single_tag ),
321+ ),
322+ remove_if = lambda tag : str (tag ).startswith (symbolized_single_tag ),
323+ )
324+
325+ # Step 3, get N sets of parameterizations as new_sweep.
326+ new_sweep = Zip (
327+ _calc_phxz_sweeps (new_circuit , merged_circuits ), # phxz sweeps
328+ _sweep_on_symbols (sweep , remaining_symbols ), # remaining sweeps
329+ )
330+
331+ return align .align_right (new_circuit ), new_sweep
0 commit comments