azrai99 commited on
Commit
6c651f4
·
verified ·
1 Parent(s): 676dd2d

Update src/nf.py

Browse files
Files changed (1) hide show
  1. src/nf.py +12 -36
src/nf.py CHANGED
@@ -8,11 +8,10 @@ import pandas as pd
8
  import pytorch_lightning as pl
9
  from datasetsforecast.utils import download_file
10
  from hyperopt import hp
 
11
  from neuralforecast.auto import NHITS as autoNHITS
12
- from neuralforecast.data.tsdataset import WindowsDataset
13
- from neuralforecast.data.tsloader import TimeSeriesLoader
14
- from neuralforecast.models.mqnhits.mqnhits import MQNHITS
15
- from neuralforecast.models.nhits.nhits import NHITS
16
 
17
  # GLOBAL PARAMETERS
18
  DEFAULT_HORIZON = 30
@@ -73,7 +72,6 @@ MODELS = {
73
  },
74
  }
75
 
76
-
77
  def download_models():
78
  for _, meta in MODELS.items():
79
  if not Path(f'./models/{meta["model"]}.ckpt').is_file():
@@ -82,19 +80,17 @@ def download_models():
82
  f'https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/{meta["model"]}.ckpt',
83
  )
84
 
85
-
86
  download_models()
87
 
88
-
89
  class StandardScaler:
90
  """This class helps to standardize a dataframe with multiple time series."""
91
-
92
  def __init__(self):
93
- self.norm: pd.DataFrame
94
 
95
  def fit(self, X: pd.DataFrame) -> "StandardScaler":
96
  self.norm = X.groupby("unique_id").agg({"y": [np.mean, np.std]})
97
  self.norm = self.norm.droplevel(0, 1).reset_index()
 
98
 
99
  def transform(self, X: pd.DataFrame) -> pd.DataFrame:
100
  transformed = X.merge(self.norm, how="left", on=["unique_id"])
@@ -109,7 +105,6 @@ class StandardScaler:
109
  )
110
  return transformed[["unique_id", "ds"] + cols]
111
 
112
-
113
  def compute_ds_future(Y_df, fh):
114
  if Y_df["unique_id"].nunique() == 1:
115
  ds_ = pd.to_datetime(Y_df["ds"].values)
@@ -130,10 +125,7 @@ def compute_ds_future(Y_df, fh):
130
  )
131
  return list(ds_future)
132
 
133
-
134
- def forecast_pretrained_model(
135
- Y_df: pd.DataFrame, model: str, fh: int, max_steps: int = 0
136
- ):
137
  if "unique_id" not in Y_df:
138
  Y_df.insert(0, "unique_id", "ts_1")
139
 
@@ -143,26 +135,12 @@ def forecast_pretrained_model(
143
 
144
  # Model
145
  file_ = f"./models/{model}.ckpt"
146
- mqnhits = MQNHITS.load_from_checkpoint(file_)
147
 
148
  # Fit
149
  if max_steps > 0:
150
- train_dataset = WindowsDataset(
151
- Y_df=Y_df,
152
- X_df=None,
153
- S_df=None,
154
- mask_df=None,
155
- f_cols=[],
156
- input_size=mqnhits.n_time_in,
157
- output_size=mqnhits.n_time_out,
158
- sample_freq=1,
159
- complete_windows=True,
160
- verbose=False,
161
- )
162
-
163
- train_loader = TimeSeriesLoader(
164
- dataset=train_dataset, batch_size=1, n_windows=32, shuffle=True
165
- )
166
 
167
  trainer = pl.Trainer(
168
  max_epochs=None,
@@ -174,13 +152,12 @@ def forecast_pretrained_model(
174
  log_every_n_steps=1,
175
  )
176
 
177
- trainer.fit(mqnhits, train_loader)
178
 
179
  # Forecast
180
- forecast_df = mqnhits.forecast(Y_df=Y_df)
181
  forecast_df = scaler.inverse_transform(forecast_df, cols=["y_5", "y_50", "y_95"])
182
 
183
- # Foreoast
184
  n_ts = forecast_df["unique_id"].nunique()
185
  if fh * n_ts > len(forecast_df):
186
  forecast_df = (
@@ -194,7 +171,6 @@ def forecast_pretrained_model(
194
 
195
  return forecast_df
196
 
197
-
198
  if __name__ == "__main__":
199
  df = pd.read_csv(
200
  "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/ercot_COAST.csv"
@@ -208,4 +184,4 @@ if __name__ == "__main__":
208
  assert forecast.shape == (80, 5)
209
  # test multiple time series
210
  multi_forecast = forecast_pretrained_model(multi_df, model=meta["model"], fh=80)
211
- assert multi_forecast.shape == (80 * 2, 5)
 
8
  import pytorch_lightning as pl
9
  from datasetsforecast.utils import download_file
10
  from hyperopt import hp
11
+ from neuralforecast.core import NeuralForecast
12
  from neuralforecast.auto import NHITS as autoNHITS
13
+ from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader
14
+ from neuralforecast.models import NHITS
 
 
15
 
16
  # GLOBAL PARAMETERS
17
  DEFAULT_HORIZON = 30
 
72
  },
73
  }
74
 
 
75
  def download_models():
76
  for _, meta in MODELS.items():
77
  if not Path(f'./models/{meta["model"]}.ckpt').is_file():
 
80
  f'https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/{meta["model"]}.ckpt',
81
  )
82
 
 
83
  download_models()
84
 
 
85
  class StandardScaler:
86
  """This class helps to standardize a dataframe with multiple time series."""
 
87
  def __init__(self):
88
+ self.norm: pd.DataFrame = None
89
 
90
  def fit(self, X: pd.DataFrame) -> "StandardScaler":
91
  self.norm = X.groupby("unique_id").agg({"y": [np.mean, np.std]})
92
  self.norm = self.norm.droplevel(0, 1).reset_index()
93
+ return self
94
 
95
  def transform(self, X: pd.DataFrame) -> pd.DataFrame:
96
  transformed = X.merge(self.norm, how="left", on=["unique_id"])
 
105
  )
106
  return transformed[["unique_id", "ds"] + cols]
107
 
 
108
  def compute_ds_future(Y_df, fh):
109
  if Y_df["unique_id"].nunique() == 1:
110
  ds_ = pd.to_datetime(Y_df["ds"].values)
 
125
  )
126
  return list(ds_future)
127
 
128
+ def forecast_pretrained_model(Y_df: pd.DataFrame, model: str, fh: int, max_steps: int = 0):
 
 
 
129
  if "unique_id" not in Y_df:
130
  Y_df.insert(0, "unique_id", "ts_1")
131
 
 
135
 
136
  # Model
137
  file_ = f"./models/{model}.ckpt"
138
+ nhits = NHITS.load_from_checkpoint(file_)
139
 
140
  # Fit
141
  if max_steps > 0:
142
+ train_dataset = TimeSeriesDataset.from_dataframe(Y_df, input_size=nhits.hparams.n_time_in, output_size=nhits.hparams.n_time_out)
143
+ train_loader = TimeSeriesLoader(dataset=train_dataset, batch_size=1, n_windows=32, shuffle=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  trainer = pl.Trainer(
146
  max_epochs=None,
 
152
  log_every_n_steps=1,
153
  )
154
 
155
+ trainer.fit(nhits, train_loader)
156
 
157
  # Forecast
158
+ forecast_df = nhits.forecast(Y_df=Y_df)
159
  forecast_df = scaler.inverse_transform(forecast_df, cols=["y_5", "y_50", "y_95"])
160
 
 
161
  n_ts = forecast_df["unique_id"].nunique()
162
  if fh * n_ts > len(forecast_df):
163
  forecast_df = (
 
171
 
172
  return forecast_df
173
 
 
174
  if __name__ == "__main__":
175
  df = pd.read_csv(
176
  "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/ercot_COAST.csv"
 
184
  assert forecast.shape == (80, 5)
185
  # test multiple time series
186
  multi_forecast = forecast_pretrained_model(multi_df, model=meta["model"], fh=80)
187
+ assert multi_forecast.shape == (80 * 2, 5)