3636 ParamResolverOrSimilarType , """Something that can be used to turn parameters into values."""
3737)
3838
39+ # Used to mark values that are not found in a dict.
40+ _NOT_FOUND = object ()
41+
3942# Used to mark values that are being resolved recursively to detect loops.
40- _RecursionFlag = object ()
43+ _RECURSION_FLAG = object ()
4144
4245
4346def _is_param_resolver_or_similar_type (obj : Any ):
@@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
7275
7376 self ._param_hash : Optional [int ] = None
7477 self ._param_dict = cast (ParamDictType , {} if param_dict is None else param_dict )
75- for key in self .param_dict :
78+ for key in self ._param_dict :
7679 if isinstance (key , sympy .Expr ) and not isinstance (key , sympy .Symbol ):
7780 raise TypeError (f'ParamResolver keys cannot be (non-symbol) formulas ({ key } )' )
7881 self ._deep_eval_map : ParamDictType = {}
@@ -120,32 +123,30 @@ def value_of(
120123 if v is not NotImplemented :
121124 return v
122125
123- # Handles 2 cases:
124- # Input is a string and maps to a number in the dictionary
125- # Input is a symbol and maps to a number in the dictionary
126- # In both cases, return it directly.
127- if value in self .param_dict :
128- # Note: if the value is in the dictionary, it will be a key type
129- # Add a cast to make mypy happy.
130- param_value = self .param_dict [cast ('cirq.TParamKey' , value )]
126+ # Handle string or symbol
127+ if isinstance (value , (str , sympy .Symbol )):
128+ string = value if isinstance (value , str ) else value .name
129+ symbol = value if isinstance (value , sympy .Symbol ) else sympy .Symbol (value )
130+ param_value = self ._param_dict .get (string , _NOT_FOUND )
131+ if param_value is _NOT_FOUND :
132+ param_value = self ._param_dict .get (symbol , _NOT_FOUND )
133+ if param_value is _NOT_FOUND :
134+ # Symbol or string cannot be resolved if not in param dict; return as symbol.
135+ return symbol
131136 v = _resolve_value (param_value )
132137 if v is not NotImplemented :
133138 return v
139+ if isinstance (param_value , str ):
140+ param_value = sympy .Symbol (param_value )
141+ elif not isinstance (param_value , sympy .Basic ):
142+ return value # type: ignore[return-value]
143+ if recursive :
144+ param_value = self ._value_of_recursive (value )
145+ return param_value # type: ignore[return-value]
134146
135- # Input is a string and is not in the dictionary.
136- # Treat it as a symbol instead.
137- if isinstance (value , str ):
138- # If the string is in the param_dict as a value, return it.
139- # Otherwise, try using the symbol instead.
140- return self .value_of (sympy .Symbol (value ), recursive )
141-
142- # Input is a symbol (sympy.Symbol('a')) and its string maps to a number
143- # in the dictionary ({'a': 1.0}). Return it.
144- if isinstance (value , sympy .Symbol ) and value .name in self .param_dict :
145- param_value = self .param_dict [value .name ]
146- v = _resolve_value (param_value )
147- if v is not NotImplemented :
148- return v
147+ if not isinstance (value , sympy .Basic ):
148+ # No known way to resolve this variable, return unchanged.
149+ return value
149150
150151 # The following resolves common sympy expressions
151152 # If sympy did its job and wasn't slower than molasses,
@@ -171,10 +172,6 @@ def value_of(
171172 return np .float_power (cast (complex , base ), cast (complex , exponent ))
172173 return np .power (cast (complex , base ), cast (complex , exponent ))
173174
174- if not isinstance (value , sympy .Basic ):
175- # No known way to resolve this variable, return unchanged.
176- return value
177-
178175 # Input is either a sympy formula or the dictionary maps to a
179176 # formula. Use sympy to resolve the value.
180177 # Note that sympy.subs() is slow, so we want to avoid this and
@@ -186,7 +183,7 @@ def value_of(
186183 # Note that a sympy.SympifyError here likely means
187184 # that one of the expressions was not parsable by sympy
188185 # (such as a function returning NotImplemented)
189- v = value .subs (self .param_dict , simultaneous = True )
186+ v = value .subs (self ._param_dict , simultaneous = True )
190187
191188 if v .free_symbols :
192189 return v
@@ -197,23 +194,26 @@ def value_of(
197194 else :
198195 return float (v )
199196
197+ return self ._value_of_recursive (value )
198+
199+ def _value_of_recursive (self , value : 'cirq.TParamKey' ) -> 'cirq.TParamValComplex' :
200200 # Recursive parameter resolution. We can safely assume that value is a
201201 # single symbol, since combinations are handled earlier in the method.
202202 if value in self ._deep_eval_map :
203203 v = self ._deep_eval_map [value ]
204- if v is not _RecursionFlag :
205- return v
206- raise RecursionError ( 'Evaluation of {value} indirectly contains itself.' )
204+ if v is _RECURSION_FLAG :
205+ raise RecursionError ( 'Evaluation of {value} indirectly contains itself.' )
206+ return v
207207
208208 # There isn't a full evaluation for 'value' yet. Until it's ready,
209209 # map value to None to identify loops in component evaluation.
210- self ._deep_eval_map [value ] = _RecursionFlag # type: ignore
210+ self ._deep_eval_map [value ] = _RECURSION_FLAG # type: ignore
211211
212212 v = self .value_of (value , recursive = False )
213213 if v == value :
214214 self ._deep_eval_map [value ] = v
215215 else :
216- self ._deep_eval_map [value ] = self .value_of (v , recursive )
216+ self ._deep_eval_map [value ] = self .value_of (v , recursive = True )
217217 return self ._deep_eval_map [value ]
218218
219219 def _resolve_parameters_ (self , resolver : 'ParamResolver' , recursive : bool ) -> 'ParamResolver' :
@@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P
224224 new_dict .update (
225225 {k : resolver .value_of (v , recursive ) for k , v in new_dict .items ()} # type: ignore[misc]
226226 )
227- if recursive and self .param_dict :
227+ if recursive and self ._param_dict :
228228 new_resolver = ParamResolver (cast (ParamDictType , new_dict ))
229229 # Resolve down to single-step mappings.
230230 return ParamResolver ()._resolve_parameters_ (new_resolver , recursive = True )
231231 return ParamResolver (cast (ParamDictType , new_dict ))
232232
233233 def __iter__ (self ) -> Iterator [Union [str , sympy .Expr ]]:
234- return iter (self .param_dict )
234+ return iter (self ._param_dict )
235235
236236 def __bool__ (self ) -> bool :
237- return bool (self .param_dict )
237+ return bool (self ._param_dict )
238238
239239 def __getitem__ (
240240 self , key : Union ['cirq.TParamKey' , 'cirq.TParamValComplex' ]
@@ -243,29 +243,29 @@ def __getitem__(
243243
244244 def __hash__ (self ) -> int :
245245 if self ._param_hash is None :
246- self ._param_hash = hash (frozenset (self .param_dict .items ()))
246+ self ._param_hash = hash (frozenset (self ._param_dict .items ()))
247247 return self ._param_hash
248248
249249 def __eq__ (self , other ):
250250 if not isinstance (other , ParamResolver ):
251251 return NotImplemented
252- return self .param_dict == other .param_dict
252+ return self ._param_dict == other ._param_dict
253253
254254 def __ne__ (self , other ):
255255 return not self == other
256256
257257 def __repr__ (self ) -> str :
258258 param_dict_repr = (
259259 '{'
260- + ', ' .join ([ f'{ proper_repr (k )} : { proper_repr (v )} ' for k , v in self .param_dict .items ()] )
260+ + ', ' .join (f'{ proper_repr (k )} : { proper_repr (v )} ' for k , v in self ._param_dict .items ())
261261 + '}'
262262 )
263263 return f'cirq.ParamResolver({ param_dict_repr } )'
264264
265265 def _json_dict_ (self ) -> Dict [str , Any ]:
266266 return {
267267 # JSON requires mappings to have keys of basic types.
268- 'param_dict' : list (self .param_dict .items ())
268+ 'param_dict' : list (self ._param_dict .items ())
269269 }
270270
271271 @classmethod
0 commit comments