99from pytensor .tensor .slinalg import solve_discrete_lyapunov
1010
1111from pymc_extras .statespace .core .statespace import PyMCStateSpace
12- from pymc_extras .statespace .models .utilities import make_default_coords
12+ from pymc_extras .statespace .models .utilities import make_default_coords , validate_names
1313from pymc_extras .statespace .utils .constants import (
1414 ALL_STATE_AUX_DIM ,
1515 ALL_STATE_DIM ,
@@ -99,9 +99,7 @@ def __init__(
9999 self ,
100100 order : tuple [int , int ],
101101 endog_names : list [str ] | None = None ,
102- k_endog : int | None = None ,
103102 exog_state_names : list [str ] | dict [str , list [str ]] | None = None ,
104- k_exog : int | dict [str , int ] | None = None ,
105103 stationary_initialization : bool = False ,
106104 filter_type : str = "standard" ,
107105 measurement_error : bool = False ,
@@ -118,23 +116,14 @@ def __init__(
118116 specified order are included. For restricted models, set zeros directly on the priors.
119117
120118 endog_names: list of str, optional
121- Names of the endogenous variables being modeled. Used to generate names for the state and shock coords. If
122- None, the state names will simply be numbered.
123-
124- Exactly one of either ``endog_names`` or ``k_endog`` must be specified.
119+ Names of the endogenous variables being modeled. Used to generate names for the state and shock coords.
125120
126121 exog_state_names : list[str] or dict[str, list[str]], optional
127122 Names of the exogenous state variables. If a list, all endogenous variables will share the same exogenous
128123 variables. If a dict, keys should be the names of the endogenous variables, and values should be lists of the
129124 exogenous variable names for that endogenous variable. Endogenous variables not included in the dict will
130125 be assumed to have no exogenous variables. If None, no exogenous variables will be included.
131126
132- k_exog : int or dict[str, int], optional
133- Number of exogenous variables. If an int, all endogenous variables will share the same number of exogenous
134- variables. If a dict, keys should be the names of the endogenous variables, and values should be the number of
135- exogenous variables for that endogenous variable. Endogenous variables not included in the dict will be
136- assumed to have no exogenous variables. If None, no exogenous variables will be included.
137-
138127 stationary_initialization: bool, default False
139128 If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
140129 state values will be used. If False, the user is responsible for setting priors on the initial state and
@@ -162,62 +151,23 @@ def __init__(
162151 to all sampling methods.
163152
164153 """
165- if (endog_names is None ) and (k_endog is None ):
166- raise ValueError ("Must specify either endog_names or k_endog" )
167- if (endog_names is not None ) and (k_endog is None ):
168- k_endog = len (endog_names )
169- if (endog_names is None ) and (k_endog is not None ):
170- endog_names = [f"observed_{ i } " for i in range (k_endog )]
171- if (endog_names is not None ) and (k_endog is not None ):
172- if len (endog_names ) != k_endog :
173- raise ValueError ("Length of provided endog_names does not match provided k_endog" )
154+
155+ validate_names (endog_names , var_name = "endog_names" , optional = False )
156+ k_endog = len (endog_names )
174157
175158 needs_exog_data = False
176159
177- if k_exog is not None and not isinstance (k_exog , int | dict ):
178- raise ValueError ("If not None, k_exog must be either an int or a dict" )
179160 if exog_state_names is not None and not isinstance (exog_state_names , list | dict ):
180161 raise ValueError ("If not None, exog_state_names must be either a list or a dict" )
181162
182- if k_exog is not None and exog_state_names is not None :
183- if isinstance (k_exog , int ) and isinstance (exog_state_names , list ):
184- if len (exog_state_names ) != k_exog :
185- raise ValueError ("Length of exog_state_names does not match provided k_exog" )
186- elif isinstance (k_exog , int ) and isinstance (exog_state_names , dict ):
187- raise ValueError (
188- "If k_exog is an int, exog_state_names must be a list of the same length (or None)"
189- )
190- elif isinstance (k_exog , dict ) and isinstance (exog_state_names , list ):
191- raise ValueError (
192- "If k_exog is a dict, exog_state_names must be a dict as well (or None)"
193- )
194- elif isinstance (k_exog , dict ) and isinstance (exog_state_names , dict ):
195- if set (k_exog .keys ()) != set (exog_state_names .keys ()):
196- raise ValueError ("Keys of k_exog and exog_state_names dicts must match" )
197- if not all (
198- len (names ) == k for names , k in zip (exog_state_names .values (), k_exog .values ())
199- ):
200- raise ValueError (
201- "If both k_endog and exog_state_names are provided, lengths of exog_state_names "
202- "lists must match corresponding values in k_exog"
203- )
204- needs_exog_data = True
205-
206- if k_exog is not None and exog_state_names is None :
207- if isinstance (k_exog , int ):
208- exog_state_names = [f"exogenous_{ i } " for i in range (k_exog )]
209- elif isinstance (k_exog , dict ):
210- exog_state_names = {
211- name : [f"{ name } _exogenous_{ i } " for i in range (k )] for name , k in k_exog .items ()
212- }
213- needs_exog_data = True
214-
215- if k_exog is None and exog_state_names is not None :
163+ if exog_state_names is not None :
216164 if isinstance (exog_state_names , list ):
217165 k_exog = len (exog_state_names )
218166 elif isinstance (exog_state_names , dict ):
219167 k_exog = {name : len (names ) for name , names in exog_state_names .items ()}
220168 needs_exog_data = True
169+ else :
170+ k_exog = None
221171
222172 # If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same
223173 # then we can drop back to the list case.
0 commit comments