Skip to content

Commit 1b44634

Browse files
authored
[enhancement] Add conditional seasonality as a new feature (#1067)
* Extended config_seasonality with condition_name
  * Indicates whether conditional seasonality is applied or not - Type: boolean
* Integrated condition_name into the workflow
  * Mask out all the seasonality features that do not belong to the conditioned seasonality
  * Ensured conditional seasonality is part of the model's data_params
  * Integrated a check in check_dataframe to rule out false column names for conditional seasonality
* Added conditional seasonality to test_custom_seasons() in test_integration.py
* Fixed linter issues
* Fixed linter issues
* Allow binary input and floats between 0 and 1 in conditional columns
* Allow binary input and floats between 0 and 1 in conditional columns
1 parent 3d46b1e commit 1b44634

File tree

5 files changed

+70
-19
lines changed

5 files changed

+70
-19
lines changed

neuralprophet/configure.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def init_data_params(
4040
config_lagged_regressors: Optional[ConfigLaggedRegressors] = None,
4141
config_regressors=None,
4242
config_events: Optional[ConfigEvents] = None,
43+
config_seasonality: Optional[ConfigSeasonality] = None,
4344
):
4445
if len(df["ID"].unique()) == 1:
4546
if not self.global_normalization:
@@ -51,6 +52,7 @@ def init_data_params(
5152
config_lagged_regressors=config_lagged_regressors,
5253
config_regressors=config_regressors,
5354
config_events=config_events,
55+
config_seasonality=config_seasonality,
5456
global_normalization=self.global_normalization,
5557
global_time_normalization=self.global_normalization,
5658
)
@@ -303,6 +305,7 @@ class Season:
303305
resolution: int
304306
period: float
305307
arg: np_types.SeasonalityArgument
308+
condition_name: Optional[str]
306309

307310

308311
@dataclass
@@ -315,16 +318,17 @@ class ConfigSeasonality:
315318
daily_arg: np_types.SeasonalityArgument = "auto"
316319
periods: OrderedDict = field(init=False) # contains SeasonConfig objects
317320
global_local: np_types.SeasonGlobalLocalMode = "local"
321+
condition_name: Optional[str] = None
318322

319323
def __post_init__(self):
320324
if self.reg_lambda > 0 and self.computation == "fourier":
321325
log.info("Note: Fourier-based seasonality regularization is experimental.")
322326
self.reg_lambda = 0.001 * self.reg_lambda
323327
self.periods = OrderedDict(
324328
{
325-
"yearly": Season(resolution=6, period=365.25, arg=self.yearly_arg),
326-
"weekly": Season(resolution=3, period=7, arg=self.weekly_arg),
327-
"daily": Season(resolution=6, period=1, arg=self.daily_arg),
329+
"yearly": Season(resolution=6, period=365.25, arg=self.yearly_arg, condition_name=None),
330+
"weekly": Season(resolution=3, period=7, arg=self.weekly_arg, condition_name=None),
331+
"daily": Season(resolution=6, period=1, arg=self.daily_arg, condition_name=None),
328332
}
329333
)
330334

@@ -333,8 +337,8 @@ def __post_init__(self):
333337
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
334338
self.global_local = "global"
335339

336-
def append(self, name, period, resolution, arg):
337-
self.periods[name] = Season(resolution=resolution, period=period, arg=arg)
340+
def append(self, name, period, resolution, arg, condition_name):
341+
self.periods[name] = Season(resolution=resolution, period=period, arg=arg, condition_name=condition_name)
338342

339343

340344
@dataclass

neuralprophet/df_utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pandas as pd
1111

1212
if TYPE_CHECKING:
13-
from neuralprophet.configure import ConfigEvents, ConfigLaggedRegressors
13+
from neuralprophet.configure import ConfigEvents, ConfigLaggedRegressors, ConfigSeasonality
1414

1515

1616
log = logging.getLogger("NP.df_utils")
@@ -141,6 +141,7 @@ def data_params_definition(
141141
config_lagged_regressors: Optional[ConfigLaggedRegressors] = None,
142142
config_regressors=None,
143143
config_events: Optional[ConfigEvents] = None,
144+
config_seasonality: Optional[ConfigSeasonality] = None,
144145
):
145146
"""
146147
Initialize data scaling values.
@@ -178,6 +179,8 @@ def data_params_definition(
178179
extra regressors (with known future values) with sub_parameters normalize (bool)
179180
config_events : configure.ConfigEvents
180181
user specified events configs
182+
config_seasonality : configure.ConfigSeasonality
183+
user specified seasonality configs
181184
182185
Returns
183186
-------
@@ -221,6 +224,13 @@ def data_params_definition(
221224
if event not in df.columns:
222225
raise ValueError(f"Event {event} not found in DataFrame.")
223226
data_params[event] = ShiftScale()
227+
if config_seasonality is not None:
228+
for season in config_seasonality.periods:
229+
condition_name = config_seasonality.periods[season].condition_name
230+
if condition_name is not None:
231+
if condition_name not in df.columns:
232+
raise ValueError(f"Seasonality condition {condition_name} not found in DataFrame.")
233+
data_params[condition_name] = ShiftScale()
224234
return data_params
225235

226236

@@ -230,6 +240,7 @@ def init_data_params(
230240
config_lagged_regressors: Optional[ConfigLaggedRegressors] = None,
231241
config_regressors=None,
232242
config_events: Optional[ConfigEvents] = None,
243+
config_seasonality: Optional[ConfigSeasonality] = None,
233244
global_normalization=False,
234245
global_time_normalization=False,
235246
):
@@ -265,6 +276,8 @@ def init_data_params(
265276
extra regressors (with known future values)
266277
config_events : configure.ConfigEvents
267278
user specified events configs
279+
config_seasonality : configure.ConfigSeasonality
280+
user specified seasonality configs
268281
global_normalization : bool
269282
270283
``True``: sets global modeling training with global normalization
@@ -289,7 +302,7 @@ def init_data_params(
289302
df, _, _, _ = prep_or_copy_df(df)
290303
df_merged = df.copy(deep=True).drop("ID", axis=1)
291304
global_data_params = data_params_definition(
292-
df_merged, normalize, config_lagged_regressors, config_regressors, config_events
305+
df_merged, normalize, config_lagged_regressors, config_regressors, config_events, config_seasonality
293306
)
294307
if global_normalization:
295308
log.debug(
@@ -300,7 +313,7 @@ def init_data_params(
300313
for df_name, df_i in df.groupby("ID"):
301314
df_i.drop("ID", axis=1, inplace=True)
302315
local_data_params[df_name] = data_params_definition(
303-
df_i, normalize, config_lagged_regressors, config_regressors, config_events
316+
df_i, normalize, config_lagged_regressors, config_regressors, config_events, config_seasonality
304317
)
305318
if global_time_normalization:
306319
# Overwrite local time normalization data_params with global values (pointer)
@@ -387,7 +400,7 @@ def normalize(df, data_params):
387400
return df
388401

389402

390-
def check_single_dataframe(df, check_y, covariates, regressors, events):
403+
def check_single_dataframe(df, check_y, covariates, regressors, events, seasonalities):
391404
"""Performs basic data sanity checks and ordering
392405
as well as prepare dataframe for fitting or predicting.
393406
@@ -403,6 +416,8 @@ def check_single_dataframe(df, check_y, covariates, regressors, events):
403416
regressor column names
404417
events : list or dict
405418
event column names
419+
seasonalities : list or dict
420+
seasonalities column names
406421
407422
Returns
408423
-------
@@ -451,6 +466,13 @@ def check_single_dataframe(df, check_y, covariates, regressors, events):
451466
columns.extend(events)
452467
else: # treat as dict
453468
columns.extend(events.keys())
469+
if seasonalities is not None:
470+
for season in seasonalities.periods:
471+
condition_name = seasonalities.periods[season].condition_name
472+
if condition_name is not None:
473+
if not df[condition_name].isin([True, False]).all() and not df[condition_name].between(0, 1).all():
474+
raise ValueError(f"Condition column {condition_name} must be boolean or numeric between 0 and 1.")
475+
columns.append(condition_name)
454476
for name in columns:
455477
if name not in df:
456478
raise ValueError(f"Column {name!r} missing from dataframe")
@@ -470,7 +492,7 @@ def check_single_dataframe(df, check_y, covariates, regressors, events):
470492
return df, regressors_to_remove
471493

472494

473-
def check_dataframe(df, check_y=True, covariates=None, regressors=None, events=None):
495+
def check_dataframe(df, check_y=True, covariates=None, regressors=None, events=None, seasonalities=None):
474496
"""Performs basic data sanity checks and ordering,
475497
as well as prepare dataframe for fitting or predicting.
476498
@@ -487,6 +509,8 @@ def check_dataframe(df, check_y=True, covariates=None, regressors=None, events=N
487509
regressor column names
488510
events : list or dict
489511
event column names
512+
seasonalities : list or dict
513+
seasonalities column names
490514
491515
Returns
492516
-------
@@ -497,7 +521,7 @@ def check_dataframe(df, check_y=True, covariates=None, regressors=None, events=N
497521
checked_df = pd.DataFrame()
498522
regressors_to_remove = []
499523
for df_name, df_i in df.groupby("ID"):
500-
df_aux, reg = check_single_dataframe(df_i, check_y, covariates, regressors, events)
524+
df_aux, reg = check_single_dataframe(df_i, check_y, covariates, regressors, events, seasonalities)
501525
df_aux = df_aux.copy(deep=True)
502526
if len(reg) > 0:
503527
regressors_to_remove.append(*reg)

neuralprophet/forecaster.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ def __init__(
409409
weekly_arg=weekly_seasonality,
410410
daily_arg=daily_seasonality,
411411
global_local=season_global_local,
412+
condition_name=None,
412413
)
413414

414415
# Events
@@ -625,13 +626,18 @@ def add_country_holidays(self, country_name, lower_window=0, upper_window=0, reg
625626
self.config_country_holidays.init_holidays()
626627
return self
627628

628-
def add_seasonality(self, name, period, fourier_order):
629+
def add_seasonality(self, name, period, fourier_order, condition_name=None):
629630
"""Add a seasonal component with specified period, number of Fourier components, and regularization.
630631
631632
Increasing the number of Fourier components allows the seasonality to change more quickly
632633
(at risk of overfitting).
633634
Note: regularization and mode (additive/multiplicative) are set in the main init.
634635
636+
If condition_name is provided, the dataframe passed to `fit` and
637+
`predict` should have a column with the specified condition_name
638+
containing only zeros and ones, deciding when to apply seasonality.
639+
Floats between 0 and 1 can be used to apply seasonality partially.
640+
635641
Parameters
636642
----------
637643
name : string
@@ -640,17 +646,22 @@ def add_seasonality(self, name, period, fourier_order):
640646
number of days in one period.
641647
fourier_order : int
642648
number of Fourier components to use.
643-
649+
condition_name : string
650+
string name of the seasonality condition.
644651
"""
645652
if self.fitted:
646653
raise Exception("Seasonality must be added prior to model fitting.")
647654
if name in ["daily", "weekly", "yearly"]:
648655
log.error("Please use inbuilt daily, weekly, or yearly seasonality or set another name.")
649656
# Do not Allow overwriting built-in seasonalities
650657
self._validate_column_name(name, seasons=True)
658+
if condition_name is not None:
659+
self._validate_column_name(condition_name)
651660
if fourier_order <= 0:
652661
raise ValueError("Fourier Order must be > 0")
653-
self.config_seasonality.append(name=name, period=period, resolution=fourier_order, arg="custom")
662+
self.config_seasonality.append(
663+
name=name, period=period, resolution=fourier_order, condition_name=condition_name, arg="custom"
664+
)
654665
return self
655666

656667
def fit(
@@ -2404,6 +2415,7 @@ def _check_dataframe(self, df, check_y=True, exogenous=True):
24042415
covariates=self.config_lagged_regressors if exogenous else None,
24052416
regressors=self.config_regressors if exogenous else None,
24062417
events=self.config_events if exogenous else None,
2418+
seasonalities=self.config_seasonality if exogenous else None,
24072419
)
24082420
for reg in regressors_to_remove:
24092421
log.warning(f"Removing regressor {reg} because it is not present in the data.")
@@ -2507,6 +2519,7 @@ def _init_train_loader(self, df, num_workers=0):
25072519
config_lagged_regressors=self.config_lagged_regressors,
25082520
config_regressors=self.config_regressors,
25092521
config_events=self.config_events,
2522+
config_seasonality=self.config_seasonality,
25102523
)
25112524

25122525
df = self._normalize(df)

neuralprophet/time_dataset.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def _stride_time_features_for_forecasts(x):
288288
inputs["time"] = time
289289

290290
if config_seasonality is not None:
291-
seasonalities = seasonal_features_from_dates(df["ds"], config_seasonality)
291+
seasonalities = seasonal_features_from_dates(df, config_seasonality)
292292
for name, features in seasonalities.items():
293293
if max_lags == 0:
294294
seasonalities[name] = np.expand_dims(features, axis=1)
@@ -608,15 +608,15 @@ def make_regressors_features(df, config_regressors):
608608
return additive_regressors, multiplicative_regressors
609609

610610

611-
def seasonal_features_from_dates(dates, config_seasonality: configure.ConfigSeasonality):
611+
def seasonal_features_from_dates(df, config_seasonality: configure.ConfigSeasonality):
612612
"""Dataframe with seasonality features.
613613
614614
Includes seasonality features, holiday features, and added regressors.
615615
616616
Parameters
617617
----------
618-
dates : pd.Series
619-
With dates for computing seasonality features
618+
df : pd.DataFrame
619+
Dataframe with all values
620620
config_seasonality : configure.ConfigSeasonality
621621
Configuration for seasonalities
622622
@@ -626,6 +626,7 @@ def seasonal_features_from_dates(dates, config_seasonality: configure.ConfigSeas
626626
Dictionary with keys for each period name containing an np.array
627627
with the respective regression features. each with dims: (len(dates), 2*fourier_order)
628628
"""
629+
dates = df["ds"]
629630
assert len(dates.shape) == 1
630631
seasonalities = OrderedDict({})
631632
# Seasonality features
@@ -639,5 +640,7 @@ def seasonal_features_from_dates(dates, config_seasonality: configure.ConfigSeas
639640
)
640641
else:
641642
raise NotImplementedError
643+
if period.condition_name is not None:
644+
features = features * df[period.condition_name].values[:, np.newaxis]
642645
seasonalities[name] = features
643646
return seasonalities

tests/test_integration.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,17 @@ def test_custom_seasons():
236236
batch_size=BATCH_SIZE,
237237
learning_rate=LR,
238238
)
239-
m = m.add_seasonality(name="quarterly", period=90, fourier_order=5)
239+
# conditional seasonality
240+
df["ds"] = pd.to_datetime(df["ds"])
241+
df["on_season"] = df["ds"].apply(lambda x: x.month in [9, 10, 11, 12, 1])
242+
df["off_season"] = df["ds"].apply(lambda x: x.month not in [9, 10, 11, 12, 1])
243+
m.add_seasonality(name="on_season", period=7, fourier_order=3, condition_name="on_season")
244+
m.add_seasonality(name="off_season", period=7, fourier_order=3, condition_name="off_season")
240245
log.debug(f"seasonalities: {m.config_seasonality.periods}")
241246
metrics_df = m.fit(df, freq="D")
242247
future = m.make_future_dataframe(df, n_historic_predictions=365, periods=365)
248+
future["on_season"] = future["ds"].apply(lambda x: x.month in [9, 10, 11, 12, 1])
249+
future["off_season"] = future["ds"].apply(lambda x: x.month not in [9, 10, 11, 12, 1])
243250
forecast = m.predict(df=future)
244251
log.debug(f"season params: {m.model.season_params.items()}")
245252
if PLOT:

0 commit comments

Comments
 (0)