Spaces:
Runtime error
Runtime error
model path changed update
Browse files- src/model.py +4 -6
src/model.py
CHANGED
@@ -8,7 +8,7 @@ import torch
|
|
8 |
from prophet.serialize import model_to_json, model_from_json
|
9 |
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
|
10 |
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
|
11 |
-
|
12 |
# at beginning of the script
|
13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
14 |
|
@@ -38,11 +38,9 @@ class Model_Load:
|
|
38 |
def store_model_load(self,model_option):
|
39 |
if model_option=='TFT':
|
40 |
# best_model_path="models/store_item_10_lead_1_v2/lightning_logs/lightning_logs/version_2/checkpoints/epoch=7-step=4472.ckpt"
|
41 |
-
best_model_path="models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt"
|
42 |
-
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
|
43 |
-
|
44 |
-
# best_tft.load_state_dict(torch.load(best_model_path,map_location=torch.device('cpu')))
|
45 |
-
# best_tft.to('cpu')
|
46 |
print('Model Load Sucessfully.')
|
47 |
return best_tft
|
48 |
elif model_option=='Prophet':
|
|
|
8 |
from prophet.serialize import model_to_json, model_from_json
|
9 |
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
|
10 |
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
|
11 |
+
import pickle
|
12 |
# at beginning of the script
|
13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
14 |
|
|
|
38 |
def store_model_load(self,model_option):
|
39 |
if model_option=='TFT':
|
40 |
# best_model_path="models/store_item_10_lead_1_v2/lightning_logs/lightning_logs/version_2/checkpoints/epoch=7-step=4472.ckpt"
|
41 |
+
# best_model_path="models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt"
|
42 |
+
# best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
|
43 |
+
best_tft=pickle.load(open("models/cpu_finalized_model_v1.sav", 'rb'))
|
|
|
|
|
44 |
print('Model Load Sucessfully.')
|
45 |
return best_tft
|
46 |
elif model_option=='Prophet':
|