Spaces:
Sleeping
Sleeping
Update src/nf.py
Browse files
src/nf.py
CHANGED
@@ -12,6 +12,7 @@ from neuralforecast.core import NeuralForecast
|
|
12 |
from neuralforecast.auto import NHITS as autoNHITS
|
13 |
from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader
|
14 |
from neuralforecast.models import NHITS
|
|
|
15 |
|
16 |
# GLOBAL PARAMETERS
|
17 |
DEFAULT_HORIZON = 30
|
@@ -133,9 +134,13 @@ def forecast_pretrained_model(Y_df: pd.DataFrame, model: str, fh: int, max_steps
|
|
133 |
scaler.fit(Y_df)
|
134 |
Y_df = scaler.transform(Y_df)
|
135 |
|
136 |
-
#
|
137 |
file_ = f"./models/{model}.ckpt"
|
138 |
-
|
|
|
|
|
|
|
|
|
139 |
|
140 |
# Fit
|
141 |
if max_steps > 0:
|
|
|
12 |
from neuralforecast.auto import NHITS as autoNHITS
|
13 |
from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader
|
14 |
from neuralforecast.models import NHITS
|
15 |
+
import torch
|
16 |
|
17 |
# GLOBAL PARAMETERS
|
18 |
DEFAULT_HORIZON = 30
|
|
|
134 |
scaler.fit(Y_df)
|
135 |
Y_df = scaler.transform(Y_df)
|
136 |
|
137 |
+
# Load the checkpoint and initialize NHITS with required parameters
|
138 |
file_ = f"./models/{model}.ckpt"
|
139 |
+
checkpoint = torch.load(file_)
|
140 |
+
h = checkpoint['hyper_parameters']['h']
|
141 |
+
input_size = checkpoint['hyper_parameters']['input_size']
|
142 |
+
nhits = NHITS(h=h, input_size=input_size)
|
143 |
+
nhits.load_state_dict(checkpoint['state_dict'])
|
144 |
|
145 |
# Fit
|
146 |
if max_steps > 0:
|