# Provenance (copied from a Hugging Face file view):
# author Sanket45, commit dbfefc3 "model path changed update", 2.62 kB.
import copy
from pathlib import Path
import warnings
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
from prophet.serialize import model_to_json, model_from_json
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
import pickle
# Module-wide torch device: the first CUDA GPU when torch can see one,
# otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
class Model_Load:
    """Load the trained forecasting models for the energy and store datasets.

    Each loader takes a ``model_option`` string: ``'TFT'`` for the
    pytorch-forecasting Temporal Fusion Transformer, or ``'Prophet'`` for the
    Prophet model. All artifact paths are relative, so they are resolved
    against the current working directory at call time.
    """

    # Model artifact locations (relative to the current working directory).
    ENERGY_TFT_CHECKPOINT = (
        'models/consumer_final_10/lightning_logs/lightning_logs/'
        'version_0/checkpoints/epoch=5-step=49260.ckpt'
    )
    ENERGY_PROPHET_JSON = 'models/fb_energy_model.json'
    STORE_TFT_PICKLE = 'models/cpu_finalized_model_v1.sav'
    STORE_PROPHET_JSON = 'models/fb_store_model_new.json'
    # Values accepted for ``model_option`` by both loaders.
    MODEL_OPTIONS = ('TFT', 'Prophet')

    def __init__(self):
        pass

    def energy_model_load(self, model_option):
        """Load a model trained for the energy dataset.

        Args:
            model_option: ``'TFT'`` restores the Temporal Fusion Transformer
                from its Lightning checkpoint; ``'Prophet'`` deserializes the
                Prophet model from JSON.

        Returns:
            The loaded model object.

        Raises:
            ValueError: If ``model_option`` is not one of ``MODEL_OPTIONS``.
        """
        if model_option == 'TFT':
            best_tft = TemporalFusionTransformer.load_from_checkpoint(
                self.ENERGY_TFT_CHECKPOINT
            )
            print('Model Load Sucessfully.')
            return best_tft
        if model_option == 'Prophet':
            return self._load_prophet_json(self.ENERGY_PROPHET_JSON)
        self._reject_option(model_option)

    def store_model_load(self, model_option):
        """Load a model trained for the store dataset.

        Args:
            model_option: ``'TFT'`` unpickles the saved Temporal Fusion
                Transformer object; ``'Prophet'`` deserializes the Prophet
                model from JSON.

        Returns:
            The loaded model object.

        Raises:
            ValueError: If ``model_option`` is not one of ``MODEL_OPTIONS``.
        """
        if model_option == 'TFT':
            # This TFT is stored as a pickled object, not a Lightning
            # checkpoint. Unpickling can execute arbitrary code, so this
            # path must only ever point at a trusted, locally produced file.
            # The ``with`` block guarantees the file handle is closed.
            with open(self.STORE_TFT_PICKLE, 'rb') as fin:
                best_tft = pickle.load(fin)
            print('Model Load Sucessfully.')
            return best_tft
        if model_option == 'Prophet':
            return self._load_prophet_json(self.STORE_PROPHET_JSON)
        self._reject_option(model_option)

    @staticmethod
    def _load_prophet_json(path):
        """Deserialize a Prophet model stored as JSON at ``path``."""
        with open(path, 'r', encoding='utf-8') as fin:
            return model_from_json(fin.read())

    @classmethod
    def _reject_option(cls, model_option):
        """Raise ValueError for an unsupported ``model_option``.

        Unsupported options previously fell through and returned ``None``,
        which only failed later when the caller used the "model".
        """
        raise ValueError(
            f'Unknown model_option {model_option!r}; '
            f'expected one of {cls.MODEL_OPTIONS}'
        )
if __name__ == '__main__':
    # Smoke check: load every model this module provides, using only methods
    # that actually exist on Model_Load.
    obj = Model_Load()
    for loader in (obj.energy_model_load, obj.store_model_load):
        for option in ('TFT', 'Prophet'):
            model = loader(option)
            print(f'{loader.__name__}({option!r}) -> {type(model).__name__}')