@@ -32,8 +32,9 @@ class Register:
3232 """
3333
3434 name : str
35+ bitsize : int
3536 shape : Tuple [int , ...] = attr .field (
36- converter = lambda v : (v ,) if isinstance (v , int ) else tuple (v )
37+ converter = lambda v : (v ,) if isinstance (v , int ) else tuple (v ), default = ()
3738 )
3839
3940 def all_idxs (self ) -> Iterable [Tuple [int , ...]]:
@@ -45,15 +46,14 @@ def total_bits(self) -> int:
4546
4647 This is the product of each of the dimensions in `shape`.
4748 """
48- return int (np .product (self .shape ))
49+ return self . bitsize * int (np .product (self .shape ))
4950
5051 def __repr__ (self ):
51- return f'cirq_ft.Register(name="{ self .name } ", shape={ self .shape } )'
52+ return f'cirq_ft.Register(name="{ self .name } ", bitsize= { self . bitsize } , shape={ self .shape } )'
5253
5354
5455def total_bits (registers : Iterable [Register ]) -> int :
5556 """Sum of `reg.total_bits()` for each register `reg` in input `registers`."""
56-
5757 return sum (reg .total_bits () for reg in registers )
5858
5959
@@ -65,7 +65,9 @@ def split_qubits(
6565 qubit_regs = {}
6666 base = 0
6767 for reg in registers :
68- qubit_regs [reg .name ] = np .array (qubits [base : base + reg .total_bits ()]).reshape (reg .shape )
68+ qubit_regs [reg .name ] = np .array (qubits [base : base + reg .total_bits ()]).reshape (
69+ reg .shape + (reg .bitsize ,)
70+ )
6971 base += reg .total_bits ()
7072 return qubit_regs
7173
@@ -82,9 +84,10 @@ def merge_qubits(
8284 raise ValueError (f"All qubit registers must be present. { reg .name } not in qubit_regs" )
8385 qubits = qubit_regs [reg .name ]
8486 qubits = np .array ([qubits ] if isinstance (qubits , cirq .Qid ) else qubits )
85- if qubits .shape != reg .shape :
87+ full_shape = reg .shape + (reg .bitsize ,)
88+ if qubits .shape != full_shape :
8689 raise ValueError (
87- f'{ reg .name } register must of shape { reg . shape } but is of shape { qubits .shape } '
90+ f'{ reg .name } register must of shape { full_shape } but is of shape { qubits .shape } '
8891 )
8992 ret += qubits .flatten ().tolist ()
9093 return ret
@@ -94,13 +97,16 @@ def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qi
9497 """Returns a dictionary of appropriately shaped named qubit registers for input `registers`."""
9598
9699 def _qubit_array (reg : Register ):
97- qubits = np .empty (reg .shape , dtype = object )
100+ qubits = np .empty (reg .shape + ( reg . bitsize ,) , dtype = object )
98101 for ii in reg .all_idxs ():
99- qubits [ii ] = cirq .NamedQubit (f'{ reg .name } [{ ", " .join (str (i ) for i in ii )} ]' )
102+ for j in range (reg .bitsize ):
103+ prefix = "" if not ii else f'[{ ", " .join (str (i ) for i in ii )} ]'
104+ suffix = "" if reg .bitsize == 1 else f"[{ j } ]"
105+ qubits [ii + (j ,)] = cirq .NamedQubit (reg .name + prefix + suffix )
100106 return qubits
101107
102108 def _qubits_for_reg (reg : Register ):
103- if len (reg .shape ) > 1 :
109+ if len (reg .shape ) > 0 :
104110 return _qubit_array (reg )
105111
106112 return np .array (
@@ -130,8 +136,8 @@ def __repr__(self):
130136 return f'cirq_ft.Registers({ self ._registers } )'
131137
132138 @classmethod
133- def build (cls , ** registers : Union [ int , Tuple [ int , ...]] ) -> 'Registers' :
134- return cls (Register (name = k , shape = v ) for k , v in registers .items ())
139+ def build (cls , ** registers : int ) -> 'Registers' :
140+ return cls (Register (name = k , bitsize = v ) for k , v in registers .items ())
135141
136142 @overload
137143 def __getitem__ (self , key : int ) -> Register :
@@ -216,23 +222,29 @@ class SelectionRegister(Register):
216222 >>> assert len(flat_indices) == N * M * L
217223 """
218224
225+ name : str
226+ bitsize : int
219227 iteration_length : int = attr .field ()
228+ shape : Tuple [int , ...] = attr .field (
229+ converter = lambda v : (v ,) if isinstance (v , int ) else tuple (v ), default = ()
230+ )
220231
221232 @iteration_length .default
222233 def _default_iteration_length (self ):
223- return 2 ** self .shape [ 0 ]
234+ return 2 ** self .bitsize
224235
225236 @iteration_length .validator
226237 def validate_iteration_length (self , attribute , value ):
227- if len (self .shape ) != 1 :
238+ if len (self .shape ) != 0 :
228239 raise ValueError (f'Selection register { self .name } should be flat. Found { self .shape = } ' )
229- if not (0 <= value <= 2 ** self .shape [ 0 ] ):
230- raise ValueError (f'iteration length must be in range [0, 2^{ self .shape [ 0 ] } ]' )
240+ if not (0 <= value <= 2 ** self .bitsize ):
241+ raise ValueError (f'iteration length must be in range [0, 2^{ self .bitsize } ]' )
231242
232243 def __repr__ (self ) -> str :
233244 return (
234245 f'cirq_ft.SelectionRegister('
235246 f'name="{ self .name } ", '
247+ f'bitsize={ self .bitsize } , '
236248 f'shape={ self .shape } , '
237249 f'iteration_length={ self .iteration_length } )'
238250 )
0 commit comments