Skip to content

Commit 3531d29

Browse files
authored
Remove k_endog & k_exog parameters in SSM (#599)
* removed k_endog argument from BayesianETS making endog_names required and updated tests accordingly * removed k_exog from SARIMAX preferring exog_state_names for defining exogenous variables and updated tests accordingly * Updated VARMAX model by removing k_endog & k_exog arguments making endog_names required and exog_state_names required for exogenous variables and updated tests accordingly * updated DFM by removing k_endog & k_exog args and made endog_names required and exog_names required if exogenous variables are requested and updated tests accordingly * removed k_exog from STS regression component and updated tests accordingly * updated docstrings of VARMAX and DFM to relect changes in removal of k_endog and k_exog parameters * moved endog_names validation into stand alone utility * removed commented code and tests, updated validate names test to be reused in both endog and some exog cases * updated docstring in regression component * updated validate_names to always return None, removed _handle_input_data, reverted test_SARIMA_with_exogenous to use stationary initialization
1 parent 04a6259 commit 3531d29

File tree

12 files changed

+86
-341
lines changed

12 files changed

+86
-341
lines changed

pymc_extras/statespace/models/DFM.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytensor.tensor as pt
66

77
from pymc_extras.statespace.core.statespace import PyMCStateSpace
8-
from pymc_extras.statespace.models.utilities import make_default_coords
8+
from pymc_extras.statespace.models.utilities import make_default_coords, validate_names
99
from pymc_extras.statespace.utils.constants import (
1010
ALL_STATE_AUX_DIM,
1111
ALL_STATE_DIM,
@@ -224,9 +224,7 @@ def __init__(
224224
self,
225225
k_factors: int,
226226
factor_order: int,
227-
k_endog: int | None = None,
228227
endog_names: Sequence[str] | None = None,
229-
k_exog: int | None = None,
230228
exog_names: Sequence[str] | None = None,
231229
shared_exog_states: bool = False,
232230
exog_innovations: bool = False,
@@ -249,19 +247,11 @@ def __init__(
249247
and are modeled as a white noise process, i.e., :math:`f_t = \varepsilon_{f,t}`.
250248
Therefore, the state vector will include one state per factor and "factor_ar" will not exist.
251249
252-
k_endog : int, optional
253-
Number of observed time series. If not provided, the number of observed series will be inferred from `endog_names`.
254-
At least one of `k_endog` or `endog_names` must be provided.
255-
256250
endog_names : list of str, optional
257-
Names of the observed time series. If not provided, default names will be generated as `endog_1`, `endog_2`, ..., `endog_k` based on `k_endog`.
258-
At least one of `k_endog` or `endog_names` must be provided.
259-
260-
k_exog : int, optional
261-
Number of exogenous variables. If not provided, the model will not have exogenous variables.
251+
Names of the observed time series.
262252
263253
exog_names : Sequence[str], optional
264-
Names of the exogenous variables. If not provided, but `k_exog` is specified, default names will be generated as `exog_1`, `exog_2`, ..., `exog_k`.
254+
Names of the exogenous variables.
265255
266256
shared_exog_states: bool, optional
267257
Whether exogenous latent states are shared across the observed states. If True, there will be only one set of exogenous latent
@@ -289,13 +279,8 @@ def __init__(
289279
290280
"""
291281

292-
if k_endog is None and endog_names is None:
293-
raise ValueError("Either k_endog or endog_names must be provided.")
294-
if k_endog is None:
295-
k_endog = len(endog_names)
296-
if endog_names is None:
297-
endog_names = [f"endog_{i}" for i in range(k_endog)]
298-
282+
validate_names(endog_names, var_name="endog_names", optional=False)
283+
k_endog = len(endog_names)
299284
self.endog_names = endog_names
300285
self.k_endog = k_endog
301286
self.k_factors = k_factors
@@ -304,17 +289,17 @@ def __init__(
304289
self.error_var = error_var
305290
self.error_cov_type = error_cov_type
306291

307-
if k_exog is None and exog_names is None:
308-
self.k_exog = 0
309-
else:
292+
if exog_names is not None:
310293
self.shared_exog_states = shared_exog_states
311294
self.exog_innovations = exog_innovations
312-
if k_exog is None:
313-
k_exog = len(exog_names) if exog_names is not None else 0
314-
elif exog_names is None:
315-
exog_names = [f"exog_{i}" for i in range(k_exog)] if k_exog > 0 else None
295+
validate_names(
296+
exog_names, var_name="exog_names", optional=True
297+
) # Not sure if this adds anything
298+
k_exog = len(exog_names)
316299
self.k_exog = k_exog
317300
self.exog_names = exog_names
301+
else:
302+
self.k_exog = 0
318303

319304
self.k_exog_states = self.k_exog * self.k_endog if not shared_exog_states else self.k_exog
320305
self.exog_flag = self.k_exog > 0

pymc_extras/statespace/models/ETS.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.tensor.slinalg import solve_discrete_lyapunov
1010

1111
from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
12-
from pymc_extras.statespace.models.utilities import make_default_coords
12+
from pymc_extras.statespace.models.utilities import make_default_coords, validate_names
1313
from pymc_extras.statespace.utils.constants import (
1414
ALL_STATE_AUX_DIM,
1515
ALL_STATE_DIM,
@@ -138,12 +138,9 @@ class BayesianETS(PyMCStateSpace):
138138
or 'N'.
139139
If provided, the model will be initialized from the given order, and the `trend`, `damped_trend`, and `seasonal`
140140
arguments will be ignored.
141-
endog_names: str or list of str, Optional
141+
endog_names: str or list of str
142142
Names associated with observed states. If a list, the length should be equal to the number of time series
143143
to be estimated.
144-
k_endog: int, Optional
145-
Number of time series to estimate. If endog_names are provided, this is ignored and len(endog_names) is
146-
used instead.
147144
trend: bool
148145
Whether to include a trend component. Setting ``trend=True`` is equivalent to ``order[1] == 'A'``.
149146
damped_trend: bool
@@ -213,7 +210,6 @@ def __init__(
213210
self,
214211
order: tuple[str, str, str] | None = None,
215212
endog_names: str | list[str] | None = None,
216-
k_endog: int = 1,
217213
trend: bool = True,
218214
damped_trend: bool = False,
219215
seasonal: bool = False,
@@ -265,13 +261,9 @@ def __init__(
265261
if self.seasonal and self.seasonal_periods is None:
266262
raise ValueError("If seasonal is True, seasonal_periods must be provided.")
267263

268-
if endog_names is not None:
269-
endog_names = list(endog_names)
270-
k_endog = len(endog_names)
271-
else:
272-
endog_names = [f"data_{i}" for i in range(k_endog)] if k_endog > 1 else ["data"]
273-
274-
self.endog_names = endog_names
264+
validate_names(endog_names, var_name="endog_names", optional=False)
265+
k_endog = len(endog_names)
266+
self.endog_names = list(endog_names)
275267

276268
if dense_innovation_covariance and k_endog == 1:
277269
dense_innovation_covariance = False

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
make_default_coords,
1313
make_harvey_state_names,
1414
make_SARIMA_transition_matrix,
15+
validate_names,
1516
)
1617
from pymc_extras.statespace.utils.constants import (
1718
ALL_STATE_AUX_DIM,
@@ -132,7 +133,6 @@ def __init__(
132133
order: tuple[int, int, int],
133134
seasonal_order: tuple[int, int, int, int] | None = None,
134135
exog_state_names: list[str] | None = None,
135-
k_exog: int | None = None,
136136
stationary_initialization: bool = True,
137137
filter_type: str = "standard",
138138
state_structure: str = "fast",
@@ -166,10 +166,6 @@ def __init__(
166166
exog_state_names : list[str], optional
167167
Names of the exogenous state variables.
168168
169-
k_exog : int, optional
170-
Number of exogenous variables. If provided, must match the length of
171-
`exog_state_names`.
172-
173169
stationary_initialization : bool, default True
174170
If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
175171
state values will be used.
@@ -212,18 +208,10 @@ def __init__(
212208
if seasonal_order is None:
213209
seasonal_order = (0, 0, 0, 0)
214210

215-
if exog_state_names is None and k_exog is not None:
216-
exog_state_names = [f"exogenous_{i}" for i in range(k_exog)]
217-
elif exog_state_names is not None and k_exog is None:
218-
k_exog = len(exog_state_names)
219-
elif exog_state_names is not None and k_exog is not None:
220-
if len(exog_state_names) != k_exog:
221-
raise ValueError(
222-
f"Based on provided inputs, expected exog_state_names to have {k_exog} elements, but "
223-
f"found {len(exog_state_names)}"
224-
)
225-
else:
226-
k_exog = 0
211+
validate_names(
212+
exog_state_names, var_name="exog_state_names", optional=True
213+
) # Not sure if this adds anything
214+
k_exog = len(exog_state_names) if exog_state_names is not None else 0
227215

228216
self.exog_state_names = exog_state_names
229217
self.k_exog = k_exog

pymc_extras/statespace/models/VARMAX.py

Lines changed: 8 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.tensor.slinalg import solve_discrete_lyapunov
1010

1111
from pymc_extras.statespace.core.statespace import PyMCStateSpace
12-
from pymc_extras.statespace.models.utilities import make_default_coords
12+
from pymc_extras.statespace.models.utilities import make_default_coords, validate_names
1313
from pymc_extras.statespace.utils.constants import (
1414
ALL_STATE_AUX_DIM,
1515
ALL_STATE_DIM,
@@ -99,9 +99,7 @@ def __init__(
9999
self,
100100
order: tuple[int, int],
101101
endog_names: list[str] | None = None,
102-
k_endog: int | None = None,
103102
exog_state_names: list[str] | dict[str, list[str]] | None = None,
104-
k_exog: int | dict[str, int] | None = None,
105103
stationary_initialization: bool = False,
106104
filter_type: str = "standard",
107105
measurement_error: bool = False,
@@ -118,23 +116,14 @@ def __init__(
118116
specified order are included. For restricted models, set zeros directly on the priors.
119117
120118
endog_names: list of str, optional
121-
Names of the endogenous variables being modeled. Used to generate names for the state and shock coords. If
122-
None, the state names will simply be numbered.
123-
124-
Exactly one of either ``endog_names`` or ``k_endog`` must be specified.
119+
Names of the endogenous variables being modeled. Used to generate names for the state and shock coords.
125120
126121
exog_state_names : list[str] or dict[str, list[str]], optional
127122
Names of the exogenous state variables. If a list, all endogenous variables will share the same exogenous
128123
variables. If a dict, keys should be the names of the endogenous variables, and values should be lists of the
129124
exogenous variable names for that endogenous variable. Endogenous variables not included in the dict will
130125
be assumed to have no exogenous variables. If None, no exogenous variables will be included.
131126
132-
k_exog : int or dict[str, int], optional
133-
Number of exogenous variables. If an int, all endogenous variables will share the same number of exogenous
134-
variables. If a dict, keys should be the names of the endogenous variables, and values should be the number of
135-
exogenous variables for that endogenous variable. Endogenous variables not included in the dict will be
136-
assumed to have no exogenous variables. If None, no exogenous variables will be included.
137-
138127
stationary_initialization: bool, default False
139128
If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
140129
state values will be used. If False, the user is responsible for setting priors on the initial state and
@@ -162,62 +151,23 @@ def __init__(
162151
to all sampling methods.
163152
164153
"""
165-
if (endog_names is None) and (k_endog is None):
166-
raise ValueError("Must specify either endog_names or k_endog")
167-
if (endog_names is not None) and (k_endog is None):
168-
k_endog = len(endog_names)
169-
if (endog_names is None) and (k_endog is not None):
170-
endog_names = [f"observed_{i}" for i in range(k_endog)]
171-
if (endog_names is not None) and (k_endog is not None):
172-
if len(endog_names) != k_endog:
173-
raise ValueError("Length of provided endog_names does not match provided k_endog")
154+
155+
validate_names(endog_names, var_name="endog_names", optional=False)
156+
k_endog = len(endog_names)
174157

175158
needs_exog_data = False
176159

177-
if k_exog is not None and not isinstance(k_exog, int | dict):
178-
raise ValueError("If not None, k_exog must be either an int or a dict")
179160
if exog_state_names is not None and not isinstance(exog_state_names, list | dict):
180161
raise ValueError("If not None, exog_state_names must be either a list or a dict")
181162

182-
if k_exog is not None and exog_state_names is not None:
183-
if isinstance(k_exog, int) and isinstance(exog_state_names, list):
184-
if len(exog_state_names) != k_exog:
185-
raise ValueError("Length of exog_state_names does not match provided k_exog")
186-
elif isinstance(k_exog, int) and isinstance(exog_state_names, dict):
187-
raise ValueError(
188-
"If k_exog is an int, exog_state_names must be a list of the same length (or None)"
189-
)
190-
elif isinstance(k_exog, dict) and isinstance(exog_state_names, list):
191-
raise ValueError(
192-
"If k_exog is a dict, exog_state_names must be a dict as well (or None)"
193-
)
194-
elif isinstance(k_exog, dict) and isinstance(exog_state_names, dict):
195-
if set(k_exog.keys()) != set(exog_state_names.keys()):
196-
raise ValueError("Keys of k_exog and exog_state_names dicts must match")
197-
if not all(
198-
len(names) == k for names, k in zip(exog_state_names.values(), k_exog.values())
199-
):
200-
raise ValueError(
201-
"If both k_endog and exog_state_names are provided, lengths of exog_state_names "
202-
"lists must match corresponding values in k_exog"
203-
)
204-
needs_exog_data = True
205-
206-
if k_exog is not None and exog_state_names is None:
207-
if isinstance(k_exog, int):
208-
exog_state_names = [f"exogenous_{i}" for i in range(k_exog)]
209-
elif isinstance(k_exog, dict):
210-
exog_state_names = {
211-
name: [f"{name}_exogenous_{i}" for i in range(k)] for name, k in k_exog.items()
212-
}
213-
needs_exog_data = True
214-
215-
if k_exog is None and exog_state_names is not None:
163+
if exog_state_names is not None:
216164
if isinstance(exog_state_names, list):
217165
k_exog = len(exog_state_names)
218166
elif isinstance(exog_state_names, dict):
219167
k_exog = {name: len(names) for name, names in exog_state_names.items()}
220168
needs_exog_data = True
169+
else:
170+
k_exog = None
221171

222172
# If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same
223173
# then we can drop back to the list case.

pymc_extras/statespace/models/structural/components/regression.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pytensor import tensor as pt
44

55
from pymc_extras.statespace.models.structural.core import Component
6+
from pymc_extras.statespace.models.utilities import validate_names
67
from pymc_extras.statespace.utils.constants import TIME_DIM
78

89

@@ -12,10 +13,6 @@ class RegressionComponent(Component):
1213
1314
Parameters
1415
----------
15-
k_exog : int | None, default None
16-
Number of exogenous variables to include in the regression. Must be specified if
17-
state_names is not provided.
18-
1916
name : str | None, default "regression"
2017
A name for this regression component. Used to label dimensions and coordinates.
2118
@@ -107,7 +104,6 @@ class RegressionComponent(Component):
107104

108105
def __init__(
109106
self,
110-
k_exog: int | None = None,
111107
name: str | None = "regression",
112108
state_names: list[str] | None = None,
113109
observed_state_names: list[str] | None = None,
@@ -120,7 +116,9 @@ def __init__(
120116
observed_state_names = ["data"]
121117

122118
self.innovations = innovations
123-
k_exog = self._handle_input_data(k_exog, state_names, name)
119+
validate_names(state_names, var_name="state_names", optional=False)
120+
k_exog = len(state_names)
121+
self.state_names = state_names
124122

125123
k_states = k_exog
126124
k_endog = len(observed_state_names)
@@ -140,26 +138,6 @@ def __init__(
140138
obs_state_idxs=np.ones(k_states),
141139
)
142140

143-
@staticmethod
144-
def _get_state_names(k_exog: int | None, state_names: list[str] | None, name: str):
145-
if k_exog is None and state_names is None:
146-
raise ValueError("Must specify at least one of k_exog or state_names")
147-
if state_names is not None and k_exog is not None:
148-
if len(state_names) != k_exog:
149-
raise ValueError(f"Expected {k_exog} state names, found {len(state_names)}")
150-
elif k_exog is None:
151-
k_exog = len(state_names)
152-
else:
153-
state_names = [f"{name}_{i + 1}" for i in range(k_exog)]
154-
155-
return k_exog, state_names
156-
157-
def _handle_input_data(self, k_exog: int, state_names: list[str] | None, name) -> int:
158-
k_exog, state_names = self._get_state_names(k_exog, state_names, name)
159-
self.state_names = state_names
160-
161-
return k_exog
162-
163141
def make_symbolic_graph(self) -> None:
164142
k_endog = self.k_endog
165143
k_endog_effective = 1 if self.share_states else k_endog

pymc_extras/statespace/models/utilities.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,3 +670,10 @@ def get_exog_dims_from_idata(exog_name, idata):
670670
exog_dims = None
671671

672672
return exog_dims
673+
674+
675+
def validate_names(names: list[str], var_name: str, optional: bool = True) -> None:
676+
if names is None:
677+
if optional:
678+
return None
679+
raise ValueError(f"Must specify {var_name}")

tests/statespace/core/test_statespace.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,8 @@ def exog_ss_mod(exog_data):
182182
level_trend = st.LevelTrendComponent(name="trend", order=1, innovations_order=[0])
183183
exog = st.RegressionComponent(
184184
name="exog", # Name of this exogenous variable component
185-
k_exog=1, # Only one exogenous variable now
186185
innovations=False, # Typically fixed effect (no stochastic evolution)
187-
state_names=exog_data[["x1"]].columns.tolist(),
186+
state_names=exog_data[["x1"]].columns.tolist(), # Only one exogenous variable now
188187
)
189188

190189
combined_model = level_trend + exog
@@ -198,9 +197,8 @@ def exog_ss_mod_mv(exog_data_mv):
198197
)
199198
exog = st.RegressionComponent(
200199
name="exog", # Name of this exogenous variable component
201-
k_exog=1, # Only one exogenous variable now
202200
innovations=False, # Typically fixed effect (no stochastic evolution)
203-
state_names=exog_data_mv[["x1"]].columns.tolist(),
201+
state_names=exog_data_mv[["x1"]].columns.tolist(), # Only one exogenous variable now
204202
observed_state_names=["y1", "y2"],
205203
)
206204

tests/statespace/filters/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_lgss_distribution_with_dims(output_name, ss_mod_me, pymc_model_2):
193193
def test_lgss_with_time_varying_inputs(output_name, rng):
194194
X = rng.random(size=(10, 3), dtype=floatX)
195195
ss_mod = structural.LevelTrendComponent() + structural.RegressionComponent(
196-
name="exog", k_exog=3
196+
name="exog", state_names=["exog_0", "exog_1", "exog_2"]
197197
)
198198
mod = ss_mod.build("data", verbose=False)
199199

0 commit comments

Comments
 (0)