3030import  numpy  as  np 
3131import  pennylane  as  qml 
3232from  pennylane .devices  import  DefaultExecutionConfig , ExecutionConfig 
33- from  pennylane .devices .default_qubit  import  adjoint_ops 
33+ from  pennylane .devices .capabilities  import  OperatorProperties 
3434from  pennylane .devices .modifiers  import  simulator_tracking , single_tape_support 
3535from  pennylane .devices .preprocess  import  (
3636    decompose ,
4343)
4444from  pennylane .measurements  import  MidMeasureMP 
4545from  pennylane .operation  import  DecompositionUndefinedError , Operator 
46- from  pennylane .ops  import  Prod , SProd , Sum 
46+ from  pennylane .ops  import  Conditional ,  PauliRot ,  Prod , SProd , Sum 
4747from  pennylane .tape  import  QuantumScript 
4848from  pennylane .transforms .core  import  TransformProgram 
4949from  pennylane .typing  import  Result 
7474from  ._mpi_handler  import  MPIHandler 
7575from  ._state_vector  import  LightningGPUStateVector 
7676
77- # The set of supported operations. 
78- _operations  =  frozenset (
79-     {
80-         "Identity" ,
81-         "QubitUnitary" ,
82-         "ControlledQubitUnitary" ,
83-         "MultiControlledX" ,
84-         "DiagonalQubitUnitary" ,
85-         "PauliX" ,
86-         "PauliY" ,
87-         "PauliZ" ,
88-         "MultiRZ" ,
89-         "GlobalPhase" ,
90-         "C(PauliX)" ,
91-         "C(PauliY)" ,
92-         "C(PauliZ)" ,
93-         "C(Hadamard)" ,
94-         "C(S)" ,
95-         "C(T)" ,
96-         "C(PhaseShift)" ,
97-         "C(RX)" ,
98-         "C(RY)" ,
99-         "C(RZ)" ,
100-         "C(Rot)" ,
101-         "C(SWAP)" ,
102-         "C(IsingXX)" ,
103-         "C(IsingXY)" ,
104-         "C(IsingYY)" ,
105-         "C(IsingZZ)" ,
106-         "C(SingleExcitation)" ,
107-         "C(SingleExcitationMinus)" ,
108-         "C(SingleExcitationPlus)" ,
109-         "C(DoubleExcitation)" ,
110-         "C(DoubleExcitationMinus)" ,
111-         "C(DoubleExcitationPlus)" ,
112-         "C(MultiRZ)" ,
113-         "C(GlobalPhase)" ,
114-         "Hadamard" ,
115-         "S" ,
116-         "Adjoint(S)" ,
117-         "T" ,
118-         "Adjoint(T)" ,
119-         "SX" ,
120-         "Adjoint(SX)" ,
121-         "CNOT" ,
122-         "SWAP" ,
123-         "ISWAP" ,
124-         "PSWAP" ,
125-         "Adjoint(ISWAP)" ,
126-         "SISWAP" ,
127-         "Adjoint(SISWAP)" ,
128-         "SQISW" ,
129-         "CSWAP" ,
130-         "Toffoli" ,
131-         "CY" ,
132-         "CZ" ,
133-         "PhaseShift" ,
134-         "ControlledPhaseShift" ,
135-         "RX" ,
136-         "RY" ,
137-         "RZ" ,
138-         "Rot" ,
139-         "CRX" ,
140-         "CRY" ,
141-         "CRZ" ,
142-         "CRot" ,
143-         "IsingXX" ,
144-         "IsingYY" ,
145-         "IsingZZ" ,
146-         "IsingXY" ,
147-         "SingleExcitation" ,
148-         "SingleExcitationPlus" ,
149-         "SingleExcitationMinus" ,
150-         "DoubleExcitation" ,
151-         "DoubleExcitationPlus" ,
152-         "DoubleExcitationMinus" ,
153-         "Adjoint(MultiRZ)" ,
154-         "Adjoint(GlobalPhase)" ,
155-         "Adjoint(PhaseShift)" ,
156-         "Adjoint(ControlledPhaseShift)" ,
157-         "Adjoint(RX)" ,
158-         "Adjoint(RY)" ,
159-         "Adjoint(RZ)" ,
160-         "Adjoint(CRX)" ,
161-         "Adjoint(CRY)" ,
162-         "Adjoint(CRZ)" ,
163-         "Adjoint(IsingXX)" ,
164-         "Adjoint(IsingYY)" ,
165-         "Adjoint(IsingZZ)" ,
166-         "Adjoint(IsingXY)" ,
167-         "Adjoint(SingleExcitation)" ,
168-         "Adjoint(SingleExcitationPlus)" ,
169-         "Adjoint(SingleExcitationMinus)" ,
170-         "Adjoint(DoubleExcitation)" ,
171-         "Adjoint(DoubleExcitationPlus)" ,
172-         "Adjoint(DoubleExcitationMinus)" ,
173-         "QubitCarry" ,
174-         "QubitSum" ,
175-         "OrbitalRotation" ,
176-         "ECR" ,
177-         "BlockEncode" ,
178-         "C(BlockEncode)" ,
179-     }
180- )
181- # End the set of supported operations. 
182- 
183- # The set of supported observables. 
184- _observables  =  frozenset (
185-     {
186-         "PauliX" ,
187-         "PauliY" ,
188-         "PauliZ" ,
189-         "Hadamard" ,
190-         "SparseHamiltonian" ,
191-         "LinearCombination" ,
192-         "Hermitian" ,
193-         "Identity" ,
194-         "Projector" ,
195-         "Sum" ,
196-         "Prod" ,
197-         "SProd" ,
198-         "Exp" ,
199-     }
200- )
77+ _to_matrix_ops  =  {
78+     "BlockEncode" : OperatorProperties (controllable = True ),
79+     "ControlledQubitUnitary" : OperatorProperties (),
80+     "ECR" : OperatorProperties (),
81+     "SX" : OperatorProperties (),
82+     "ISWAP" : OperatorProperties (),
83+     "PSWAP" : OperatorProperties (),
84+     "SISWAP" : OperatorProperties (),
85+     "SQISW" : OperatorProperties (),
86+     "OrbitalRotation" : OperatorProperties (),
87+     "QubitCarry" : OperatorProperties (),
88+     "QubitSum" : OperatorProperties (),
89+     "DiagonalQubitUnitary" : OperatorProperties (),
90+ }
20191
20292
20393def  stopping_condition (op : Operator ) ->  bool :
20494    """A function that determines whether or not an operation is supported by ``lightning.gpu``.""" 
205-     return  op .name   in   _operations 
95+     return  _supports_operation ( op .name ) 
20696
20797
20898def  stopping_condition_shots (op : Operator ) ->  bool :
@@ -213,7 +103,7 @@ def stopping_condition_shots(op: Operator) -> bool:
213103
214104def  accepted_observables (obs : Operator ) ->  bool :
215105    """A function that determines whether or not an observable is supported by ``lightning.gpu``.""" 
216-     return  obs .name   in   _observables 
106+     return  _supports_observable ( obs .name ) 
217107
218108
219109def  adjoint_observables (obs : Operator ) ->  bool :
@@ -228,7 +118,7 @@ def adjoint_observables(obs: Operator) -> bool:
228118    if  isinstance (obs , (Sum , Prod )):
229119        return  all (adjoint_observables (o ) for  o  in  obs )
230120
231-     return  obs .name   in   _observables 
121+     return  _supports_observable ( obs .name ) 
232122
233123
234124def  adjoint_measurements (mp : qml .measurements .MeasurementProcess ) ->  bool :
@@ -252,7 +142,10 @@ def _supports_adjoint(circuit):
252142
253143def  _adjoint_ops (op : qml .operation .Operator ) ->  bool :
254144    """Specify whether or not an Operator is supported by adjoint differentiation.""" 
255-     return  adjoint_ops (op ) and  not  isinstance (op , qml .PauliRot )
145+ 
146+     return  not  isinstance (op , (Conditional , MidMeasureMP , PauliRot )) and  (
147+         not  qml .operation .is_trainable (op ) or  (op .num_params  ==  1  and  op .has_generator )
148+     )
256149
257150
258151def  _add_adjoint_transforms (program : TransformProgram ) ->  None :
@@ -333,15 +226,13 @@ class LightningGPU(LightningBase):
333226    _CPP_BINARY_AVAILABLE  =  LGPU_CPP_BINARY_AVAILABLE 
334227    _backend_info  =  backend_info  if  LGPU_CPP_BINARY_AVAILABLE  else  None 
335228
336-     # This `config` is used in Catalyst-Frontend 
337-     config  =  Path (__file__ ).parent  /  "lightning_gpu.toml" 
229+     # TODO: This is to communicate to Catalyst in qjit-compiled workflows that these operations 
230+     #       should be converted to QubitUnitary instead of their original decompositions. Remove 
231+     #       this when customizable multiple decomposition pathways are implemented 
232+     _to_matrix_ops  =  _to_matrix_ops 
338233
339-     # TODO: Move supported ops/obs to TOML file 
340-     operations  =  _operations 
341-     # The names of the supported operations. 
342- 
343-     observables  =  _observables 
344-     # The names of the supported observables. 
234+     # This configuration file declares capabilities of the device 
235+     config_filepath  =  Path (__file__ ).parent  /  "lightning_gpu.toml" 
345236
346237    def  __init__ (  # pylint: disable=too-many-arguments 
347238        self ,
@@ -607,3 +498,7 @@ def get_c_interface():
607498                return  "LightningGPUSimulator" , lib_location 
608499
609500        raise  RuntimeError ("'LightningGPUSimulator' shared library not found" )  # pragma: no cover 
501+ 
502+ 
503+ _supports_operation  =  LightningGPU .capabilities .supports_operation 
504+ _supports_observable  =  LightningGPU .capabilities .supports_observable 
0 commit comments