10
10
import pandas as pd
11
11
12
12
if TYPE_CHECKING :
13
- from neuralprophet .configure import ConfigEvents , ConfigLaggedRegressors
13
+ from neuralprophet .configure import ConfigEvents , ConfigLaggedRegressors , ConfigSeasonality
14
14
15
15
16
16
log = logging .getLogger ("NP.df_utils" )
@@ -141,6 +141,7 @@ def data_params_definition(
141
141
config_lagged_regressors : Optional [ConfigLaggedRegressors ] = None ,
142
142
config_regressors = None ,
143
143
config_events : Optional [ConfigEvents ] = None ,
144
+ config_seasonality : Optional [ConfigSeasonality ] = None ,
144
145
):
145
146
"""
146
147
Initialize data scaling values.
@@ -178,6 +179,8 @@ def data_params_definition(
178
179
extra regressors (with known future values) with sub_parameters normalize (bool)
179
180
config_events : configure.ConfigEvents
180
181
user specified events configs
182
+ config_seasonality : configure.ConfigSeasonality
183
+ user specified seasonality configs
181
184
182
185
Returns
183
186
-------
@@ -221,6 +224,13 @@ def data_params_definition(
221
224
if event not in df .columns :
222
225
raise ValueError (f"Event { event } not found in DataFrame." )
223
226
data_params [event ] = ShiftScale ()
227
+ if config_seasonality is not None :
228
+ for season in config_seasonality .periods :
229
+ condition_name = config_seasonality .periods [season ].condition_name
230
+ if condition_name is not None :
231
+ if condition_name not in df .columns :
232
+ raise ValueError (f"Seasonality condition { condition_name } not found in DataFrame." )
233
+ data_params [condition_name ] = ShiftScale ()
224
234
return data_params
225
235
226
236
@@ -230,6 +240,7 @@ def init_data_params(
230
240
config_lagged_regressors : Optional [ConfigLaggedRegressors ] = None ,
231
241
config_regressors = None ,
232
242
config_events : Optional [ConfigEvents ] = None ,
243
+ config_seasonality : Optional [ConfigSeasonality ] = None ,
233
244
global_normalization = False ,
234
245
global_time_normalization = False ,
235
246
):
@@ -265,6 +276,8 @@ def init_data_params(
265
276
extra regressors (with known future values)
266
277
config_events : configure.ConfigEvents
267
278
user specified events configs
279
+ config_seasonality : configure.ConfigSeasonality
280
+ user specified seasonality configs
268
281
global_normalization : bool
269
282
270
283
``True``: sets global modeling training with global normalization
@@ -289,7 +302,7 @@ def init_data_params(
289
302
df , _ , _ , _ = prep_or_copy_df (df )
290
303
df_merged = df .copy (deep = True ).drop ("ID" , axis = 1 )
291
304
global_data_params = data_params_definition (
292
- df_merged , normalize , config_lagged_regressors , config_regressors , config_events
305
+ df_merged , normalize , config_lagged_regressors , config_regressors , config_events , config_seasonality
293
306
)
294
307
if global_normalization :
295
308
log .debug (
@@ -300,7 +313,7 @@ def init_data_params(
300
313
for df_name , df_i in df .groupby ("ID" ):
301
314
df_i .drop ("ID" , axis = 1 , inplace = True )
302
315
local_data_params [df_name ] = data_params_definition (
303
- df_i , normalize , config_lagged_regressors , config_regressors , config_events
316
+ df_i , normalize , config_lagged_regressors , config_regressors , config_events , config_seasonality
304
317
)
305
318
if global_time_normalization :
306
319
# Overwrite local time normalization data_params with global values (pointer)
@@ -387,7 +400,7 @@ def normalize(df, data_params):
387
400
return df
388
401
389
402
390
- def check_single_dataframe (df , check_y , covariates , regressors , events ):
403
+ def check_single_dataframe (df , check_y , covariates , regressors , events , seasonalities ):
391
404
"""Performs basic data sanity checks and ordering
392
405
as well as prepare dataframe for fitting or predicting.
393
406
@@ -403,6 +416,8 @@ def check_single_dataframe(df, check_y, covariates, regressors, events):
403
416
regressor column names
404
417
events : list or dict
405
418
event column names
419
+ seasonalities : list or dict
420
+ seasonalities column names
406
421
407
422
Returns
408
423
-------
@@ -451,6 +466,13 @@ def check_single_dataframe(df, check_y, covariates, regressors, events):
451
466
columns .extend (events )
452
467
else : # treat as dict
453
468
columns .extend (events .keys ())
469
+ if seasonalities is not None :
470
+ for season in seasonalities .periods :
471
+ condition_name = seasonalities .periods [season ].condition_name
472
+ if condition_name is not None :
473
+ if not df [condition_name ].isin ([True , False ]).all () and not df [condition_name ].between (0 , 1 ).all ():
474
+ raise ValueError (f"Condition column { condition_name } must be boolean or numeric between 0 and 1." )
475
+ columns .append (condition_name )
454
476
for name in columns :
455
477
if name not in df :
456
478
raise ValueError (f"Column { name !r} missing from dataframe" )
@@ -470,7 +492,7 @@ def check_single_dataframe(df, check_y, covariates, regressors, events):
470
492
return df , regressors_to_remove
471
493
472
494
473
- def check_dataframe (df , check_y = True , covariates = None , regressors = None , events = None ):
495
+ def check_dataframe (df , check_y = True , covariates = None , regressors = None , events = None , seasonalities = None ):
474
496
"""Performs basic data sanity checks and ordering,
475
497
as well as prepare dataframe for fitting or predicting.
476
498
@@ -487,6 +509,8 @@ def check_dataframe(df, check_y=True, covariates=None, regressors=None, events=N
487
509
regressor column names
488
510
events : list or dict
489
511
event column names
512
+ seasonalities : list or dict
513
+ seasonalities column names
490
514
491
515
Returns
492
516
-------
@@ -497,7 +521,7 @@ def check_dataframe(df, check_y=True, covariates=None, regressors=None, events=N
497
521
checked_df = pd .DataFrame ()
498
522
regressors_to_remove = []
499
523
for df_name , df_i in df .groupby ("ID" ):
500
- df_aux , reg = check_single_dataframe (df_i , check_y , covariates , regressors , events )
524
+ df_aux , reg = check_single_dataframe (df_i , check_y , covariates , regressors , events , seasonalities )
501
525
df_aux = df_aux .copy (deep = True )
502
526
if len (reg ) > 0 :
503
527
regressors_to_remove .append (* reg )
0 commit comments