Sanket45 commited on
Commit
dbfefc3
·
1 Parent(s): 5658340

model path changed update

Browse files
Files changed (1) hide show
  1. src/model.py +4 -6
src/model.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  from prophet.serialize import model_to_json, model_from_json
9
  from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
10
  from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
11
-
12
  # at beginning of the script
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
 
@@ -38,11 +38,9 @@ class Model_Load:
38
  def store_model_load(self,model_option):
39
  if model_option=='TFT':
40
  # best_model_path="models/store_item_10_lead_1_v2/lightning_logs/lightning_logs/version_2/checkpoints/epoch=7-step=4472.ckpt"
41
- best_model_path="models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt"
42
- best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
43
- # best_tft = TemporalFusionTransformer()
44
- # best_tft.load_state_dict(torch.load(best_model_path,map_location=torch.device('cpu')))
45
- # best_tft.to('cpu')
46
  print('Model Load Sucessfully.')
47
  return best_tft
48
  elif model_option=='Prophet':
 
8
  from prophet.serialize import model_to_json, model_from_json
9
  from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
10
  from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
11
+ import pickle
12
  # at beginning of the script
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
 
 
38
  def store_model_load(self,model_option):
39
  if model_option=='TFT':
40
  # best_model_path="models/store_item_10_lead_1_v2/lightning_logs/lightning_logs/version_2/checkpoints/epoch=7-step=4472.ckpt"
41
+ # best_model_path="models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt"
42
+ # best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
43
+ best_tft=pickle.load(open("models/cpu_finalized_model_v1.sav", 'rb'))
 
 
44
  print('Model Load Sucessfully.')
45
  return best_tft
46
  elif model_option=='Prophet':