azrai99 commited on
Commit
31040fb
·
verified ·
1 Parent(s): e77081c

Update src/nf.py

Browse files
Files changed (1) hide show
  1. src/nf.py +1 -5
src/nf.py CHANGED
@@ -136,11 +136,7 @@ def forecast_pretrained_model(Y_df: pd.DataFrame, model: str, fh: int, max_steps
136
 
137
  # Load the checkpoint and initialize NHITS with required parameters
138
  file_ = f"./models/{model}.ckpt"
139
- checkpoint = torch.load(file_)
140
- h = checkpoint['hyper_parameters']['h']
141
- input_size = checkpoint['hyper_parameters']['input_size']
142
- nhits = NHITS(h=h, input_size=input_size)
143
- nhits.load_state_dict(checkpoint['state_dict'])
144
 
145
  # Fit
146
  if max_steps > 0:
 
136
 
137
  # Load the checkpoint and initialize NHITS with required parameters
138
  file_ = f"./models/{model}.ckpt"
139
+ nhits = NeuralForecast.load_from_checkpoint(file_)
 
 
 
 
140
 
141
  # Fit
142
  if max_steps > 0: