azrai99 commited on
Commit
e77081c
·
verified ·
1 Parent(s): ea639d2

Update src/nf.py

Browse files
Files changed (1) hide show
  1. src/nf.py +7 -2
src/nf.py CHANGED
@@ -12,6 +12,7 @@ from neuralforecast.core import NeuralForecast
12
  from neuralforecast.auto import NHITS as autoNHITS
13
  from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader
14
  from neuralforecast.models import NHITS
 
15
 
16
  # GLOBAL PARAMETERS
17
  DEFAULT_HORIZON = 30
@@ -133,9 +134,13 @@ def forecast_pretrained_model(Y_df: pd.DataFrame, model: str, fh: int, max_steps
133
  scaler.fit(Y_df)
134
  Y_df = scaler.transform(Y_df)
135
 
136
- # Model
137
  file_ = f"./models/{model}.ckpt"
138
- nhits = NHITS.load_from_checkpoint(file_)
 
 
 
 
139
 
140
  # Fit
141
  if max_steps > 0:
 
12
  from neuralforecast.auto import NHITS as autoNHITS
13
  from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader
14
  from neuralforecast.models import NHITS
15
+ import torch
16
 
17
  # GLOBAL PARAMETERS
18
  DEFAULT_HORIZON = 30
 
134
  scaler.fit(Y_df)
135
  Y_df = scaler.transform(Y_df)
136
 
137
+ # Load the checkpoint and initialize NHITS with required parameters
138
  file_ = f"./models/{model}.ckpt"
139
+ checkpoint = torch.load(file_)
140
+ h = checkpoint['hyper_parameters']['h']
141
+ input_size = checkpoint['hyper_parameters']['input_size']
142
+ nhits = NHITS(h=h, input_size=input_size)
143
+ nhits.load_state_dict(checkpoint['state_dict'])
144
 
145
  # Fit
146
  if max_steps > 0: