1313# limitations under the License.
1414
1515import abc
16- from typing import Dict , Iterator , List , Sequence , Tuple
16+ from typing import Callable , Dict , Iterator , List , Sequence , Tuple
1717from numpy .typing import NDArray
1818
1919import cirq
@@ -34,6 +34,7 @@ def _unary_iteration_segtree(
3434 r : int ,
3535 l_iter : int ,
3636 r_iter : int ,
37+ break_early : Callable [[int , int ], bool ],
3738) -> Iterator [Tuple [cirq .OP_TREE , cirq .Qid , int ]]:
3839 """Constructs a unary iteration circuit by iterating over nodes of an implicit Segment Tree.
3940
@@ -53,6 +54,11 @@ def _unary_iteration_segtree(
5354 r: Right index of the range represented by current node of the segment tree.
5455 l_iter: Left index of iteration range over which the segment tree should be constructed.
5556 r_iter: Right index of iteration range over which the segment tree should be constructed.
57+ break_early: For each internal node of the segment tree, `break_early(l, r)` is called to
58+ evaluate whether the unary iteration should terminate early and not recurse in the
59+ subtree of the node representing range `[l, r)`. If True, the internal node is
60+ considered equivalent to a leaf node and the method yields only one tuple
61+ `(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`.
5662
5763 Yields:
5864 One `Tuple[cirq.OP_TREE, cirq.Qid, int]` for each leaf node in the segment tree. The i'th
@@ -68,8 +74,8 @@ def _unary_iteration_segtree(
6874 if l >= r_iter or l_iter >= r :
6975 # Range corresponding to this node is completely outside of iteration range.
7076 return
71- if l == (r - 1 ):
72- # Reached a leaf node; yield the operations.
77+ if l_iter <= l < r <= r_iter and ( l == (r - 1 ) or break_early ( l , r ) ):
78+ # Reached a leaf node or a "special" internal node ; yield the operations.
7379 yield tuple (ops ), control , l
7480 ops .clear ()
7581 return
@@ -78,20 +84,24 @@ def _unary_iteration_segtree(
7884 if r_iter <= m :
7985 # Yield only left sub-tree.
8086 yield from _unary_iteration_segtree (
81- ops , control , selection , ancilla , sl + 1 , l , m , l_iter , r_iter
87+ ops , control , selection , ancilla , sl + 1 , l , m , l_iter , r_iter , break_early
8288 )
8389 return
8490 if l_iter >= m :
8591 # Yield only right sub-tree
8692 yield from _unary_iteration_segtree (
87- ops , control , selection , ancilla , sl + 1 , m , r , l_iter , r_iter
93+ ops , control , selection , ancilla , sl + 1 , m , r , l_iter , r_iter , break_early
8894 )
8995 return
9096 anc , sq = ancilla [sl ], selection [sl ]
9197 ops .append (and_gate .And ((1 , 0 )).on (control , sq , anc ))
92- yield from _unary_iteration_segtree (ops , anc , selection , ancilla , sl + 1 , l , m , l_iter , r_iter )
98+ yield from _unary_iteration_segtree (
99+ ops , anc , selection , ancilla , sl + 1 , l , m , l_iter , r_iter , break_early
100+ )
93101 ops .append (cirq .CNOT (control , anc ))
94- yield from _unary_iteration_segtree (ops , anc , selection , ancilla , sl + 1 , m , r , l_iter , r_iter )
102+ yield from _unary_iteration_segtree (
103+ ops , anc , selection , ancilla , sl + 1 , m , r , l_iter , r_iter , break_early
104+ )
95105 ops .append (and_gate .And (adjoint = True ).on (control , sq , anc ))
96106
97107
@@ -101,16 +111,17 @@ def _unary_iteration_zero_control(
101111 ancilla : Sequence [cirq .Qid ],
102112 l_iter : int ,
103113 r_iter : int ,
114+ break_early : Callable [[int , int ], bool ],
104115) -> Iterator [Tuple [cirq .OP_TREE , cirq .Qid , int ]]:
105116 sl , l , r = 0 , 0 , 2 ** len (selection )
106117 m = (l + r ) >> 1
107118 ops .append (cirq .X (selection [0 ]))
108119 yield from _unary_iteration_segtree (
109- ops , selection [0 ], selection [1 :], ancilla , sl , l , m , l_iter , r_iter
120+ ops , selection [0 ], selection [1 :], ancilla , sl , l , m , l_iter , r_iter , break_early
110121 )
111122 ops .append (cirq .X (selection [0 ]))
112123 yield from _unary_iteration_segtree (
113- ops , selection [0 ], selection [1 :], ancilla , sl , m , r , l_iter , r_iter
124+ ops , selection [0 ], selection [1 :], ancilla , sl , m , r , l_iter , r_iter , break_early
114125 )
115126
116127
@@ -121,9 +132,12 @@ def _unary_iteration_single_control(
121132 ancilla : Sequence [cirq .Qid ],
122133 l_iter : int ,
123134 r_iter : int ,
135+ break_early : Callable [[int , int ], bool ],
124136) -> Iterator [Tuple [cirq .OP_TREE , cirq .Qid , int ]]:
125137 sl , l , r = 0 , 0 , 2 ** len (selection )
126- yield from _unary_iteration_segtree (ops , control , selection , ancilla , sl , l , r , l_iter , r_iter )
138+ yield from _unary_iteration_segtree (
139+ ops , control , selection , ancilla , sl , l , r , l_iter , r_iter , break_early
140+ )
127141
128142
129143def _unary_iteration_multi_controls (
@@ -133,6 +147,7 @@ def _unary_iteration_multi_controls(
133147 ancilla : Sequence [cirq .Qid ],
134148 l_iter : int ,
135149 r_iter : int ,
150+ break_early : Callable [[int , int ], bool ],
136151) -> Iterator [Tuple [cirq .OP_TREE , cirq .Qid , int ]]:
137152 num_controls = len (controls )
138153 and_ancilla = ancilla [: num_controls - 2 ]
@@ -142,7 +157,7 @@ def _unary_iteration_multi_controls(
142157 )
143158 ops .append (multi_controlled_and )
144159 yield from _unary_iteration_single_control (
145- ops , and_target , selection , ancilla [num_controls - 1 :], l_iter , r_iter
160+ ops , and_target , selection , ancilla [num_controls - 1 :], l_iter , r_iter , break_early
146161 )
147162 ops .append (cirq .inverse (multi_controlled_and ))
148163
@@ -154,6 +169,7 @@ def unary_iteration(
154169 controls : Sequence [cirq .Qid ],
155170 selection : Sequence [cirq .Qid ],
156171 qubit_manager : cirq .QubitManager ,
172+ break_early : Callable [[int , int ], bool ] = lambda l , r : False ,
157173) -> Iterator [Tuple [cirq .OP_TREE , cirq .Qid , int ]]:
158174 """The method performs unary iteration on `selection` integer in `range(l_iter, r_iter)`.
159175
@@ -181,6 +197,9 @@ def unary_iteration(
181197 ... circuit.append(j_ops)
182198 >>> circuit.append(i_ops)
183199
200+ Note: Unary iteration circuits assume that the selection register stores integers only in the
201+ range `[l, r)` for which the corresponding unary iteration circuit should be built.
202+
184203 Args:
185204 l_iter: Starting index of the iteration range.
186205 r_iter: Ending index of the iteration range.
@@ -192,6 +211,11 @@ def unary_iteration(
192211 controls: Control register of qubits.
193212 selection: Selection register of qubits.
194213 qubit_manager: A `cirq.QubitManager` to allocate new qubits.
214+ break_early: For each internal node of the segment tree, `break_early(l, r)` is called to
215+ evaluate whether the unary iteration should terminate early and not recurse in the
216+ subtree of the node representing range `[l, r)`. If True, the internal node is
217+ considered equivalent to a leaf node and the method yields only one tuple
218+ `(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`.
195219
196220 Yields:
197221 (r_iter - l_iter) different tuples, each corresponding to an integer in range
@@ -207,14 +231,16 @@ def unary_iteration(
207231 assert len (selection ) > 0
208232 ancilla = qubit_manager .qalloc (max (0 , len (controls ) + len (selection ) - 1 ))
209233 if len (controls ) == 0 :
210- yield from _unary_iteration_zero_control (flanking_ops , selection , ancilla , l_iter , r_iter )
234+ yield from _unary_iteration_zero_control (
235+ flanking_ops , selection , ancilla , l_iter , r_iter , break_early
236+ )
211237 elif len (controls ) == 1 :
212238 yield from _unary_iteration_single_control (
213- flanking_ops , controls [0 ], selection , ancilla , l_iter , r_iter
239+ flanking_ops , controls [0 ], selection , ancilla , l_iter , r_iter , break_early
214240 )
215241 else :
216242 yield from _unary_iteration_multi_controls (
217- flanking_ops , controls , selection , ancilla , l_iter , r_iter
243+ flanking_ops , controls , selection , ancilla , l_iter , r_iter , break_early
218244 )
219245 qubit_manager .qfree (ancilla )
220246
@@ -231,6 +257,9 @@ class UnaryIterationGate(infra.GateWithRegisters):
231257 indexed operations on a target register depending on the index value stored in a selection
232258 register.
233259
260+ Note: Unary iteration circuits assume that the selection register stores integers only in the
261+ range `[l, r)` for which the corresponding unary iteration circuit should be built.
262+
234263 References:
235264 [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity]
236265 (https://arxiv.org/abs/1805.03662).
@@ -308,10 +337,38 @@ def decompose_zero_selection(
308337 """
309338 raise NotImplementedError ("Selection register must not be empty." )
310339
340+ def _break_early (self , selection_index_prefix : Tuple [int , ...], l : int , r : int ) -> bool :
341+ """Derived classes should override this method to specify an early termination condition.
342+
343+ For each internal node of the unary iteration segment tree, `break_early(l, r)` is called
344+ to evaluate whether the unary iteration should not recurse in the subtree of the node
345+ representing range `[l, r)`. If True, the internal node is considered equivalent to a leaf
346+ node and thus, `self.nth_operation` will be called for only integer `l` in the range [l, r).
347+
348+ When the `UnaryIteration` class is constructed using multiple selection registers, i.e. we
349+ wish to perform nested coherent for-loops, a unary iteration segment tree is constructed
350+ corresponding to each nested coherent for-loop. For every such unary iteration segment tree,
351+ the `_break_early` condition is checked by passing the `selection_index_prefix` tuple.
352+
353+ Args:
354+ selection_index_prefix: To evaluate the early breaking condition for the i'th nested
355+ for-loop, the `selection_index_prefix` contains `i-1` integers corresponding to
356+ the loop variable values for the first `i-1` nested loops.
357+ l: Beginning of range `[l, r)` for internal node of unary iteration segment tree.
358+ r: End (exclusive) of range `[l, r)` for internal node of unary iteration segment tree.
359+
360+ Returns:
361+ True of the `len(selection_index_prefix)`'th unary iteration should terminate early for
362+ the given parameters.
363+ """
364+ return False
365+
311366 def decompose_from_registers (
312367 self , * , context : cirq .DecompositionContext , ** quregs : NDArray [cirq .Qid ]
313368 ) -> cirq .OP_TREE :
314- if self .selection_registers .total_bits () == 0 :
369+ if self .selection_registers .total_bits () == 0 or self ._break_early (
370+ (), 0 , self .selection_registers [0 ].iteration_length
371+ ):
315372 return self .decompose_zero_selection (context = context , ** quregs )
316373
317374 num_loops = len (self .selection_registers )
@@ -354,20 +411,23 @@ def unary_iteration_loops(
354411 return
355412 # Use recursion to write `num_loops` nested loops using unary_iteration().
356413 ops : List [cirq .Operation ] = []
414+ selection_index_prefix = tuple (selection_reg_name_to_val .values ())
357415 ith_for_loop = unary_iteration (
358416 l_iter = 0 ,
359417 r_iter = self .selection_registers [nested_depth ].iteration_length ,
360418 flanking_ops = ops ,
361419 controls = controls ,
362420 selection = [* quregs [self .selection_registers [nested_depth ].name ]],
363421 qubit_manager = context .qubit_manager ,
422+ break_early = lambda l , r : self ._break_early (selection_index_prefix , l , r ),
364423 )
365424 for op_tree , control_qid , n in ith_for_loop :
366425 yield op_tree
367426 selection_reg_name_to_val [self .selection_registers [nested_depth ].name ] = n
368427 yield from unary_iteration_loops (
369428 nested_depth + 1 , selection_reg_name_to_val , (control_qid ,)
370429 )
430+ selection_reg_name_to_val .pop (self .selection_registers [nested_depth ].name )
371431 yield ops
372432
373433 return unary_iteration_loops (0 , {}, self .control_registers .merge_qubits (** quregs ))
0 commit comments