import copy
from pathlib import Path
import warnings

import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
from prophet.serialize import model_to_json, model_from_json
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# at beginning of the script
# Chosen at import time; the loaders below do not currently use it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Model_Load:
    """Load the pre-trained TFT and Prophet models for the energy and store datasets.

    Both public loaders accept ``model_option`` of ``'TFT'`` or ``'Prophet'``.
    They return a ``TemporalFusionTransformer`` restored from a Lightning
    checkpoint, or a Prophet model deserialized from JSON. Any other option
    raises ``ValueError``.

    The file locations are class attributes. Override them on an instance or
    subclass to load models from elsewhere.
    """

    # Model artifact locations, relative to the current working directory.
    ENERGY_TFT_CHECKPOINT = 'models/consumer_final_10/lightning_logs/lightning_logs/version_0/checkpoints/epoch=5-step=49260.ckpt'
    ENERGY_PROPHET_JSON = 'models/fb_energy_model.json'
    STORE_TFT_CHECKPOINT = 'models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt'
    STORE_PROPHET_JSON = 'models/fb_store_model_new.json'

    # Values accepted for ``model_option`` by the public loaders.
    SUPPORTED_OPTIONS = ('TFT', 'Prophet')

    def __init__(self):
        pass

    @staticmethod
    def _load_tft(checkpoint_path):
        """Restore a TemporalFusionTransformer from a Lightning checkpoint file."""
        best_tft = TemporalFusionTransformer.load_from_checkpoint(checkpoint_path)
        print('Model Load Sucessfully.')
        return best_tft

    @staticmethod
    def _load_prophet(json_path):
        """Read a JSON file and deserialize it with ``prophet.serialize.model_from_json``."""
        with open(json_path, 'r', encoding='utf-8') as fin:
            return model_from_json(fin.read())

    def _load(self, model_option, tft_path, prophet_path):
        """Dispatch ``model_option`` to the matching loader.

        Raises:
            ValueError: if ``model_option`` is not one of ``SUPPORTED_OPTIONS``.
        """
        if model_option == 'TFT':
            return self._load_tft(tft_path)
        if model_option == 'Prophet':
            return self._load_prophet(prophet_path)
        # Previously this fell through and silently returned None.
        raise ValueError(
            f"Unsupported model_option {model_option!r}; "
            f"expected one of {self.SUPPORTED_OPTIONS}"
        )

    def energy_model_load(self, model_option):
        """Return the energy-consumption model selected by ``model_option``.

        Args:
            model_option: ``'TFT'`` or ``'Prophet'``.

        Raises:
            ValueError: for any other ``model_option``.
        """
        return self._load(model_option, self.ENERGY_TFT_CHECKPOINT, self.ENERGY_PROPHET_JSON)

    def store_model_load(self, model_option):
        """Return the store/item model selected by ``model_option``.

        Args:
            model_option: ``'TFT'`` or ``'Prophet'``.

        Raises:
            ValueError: for any other ``model_option``.
        """
        return self._load(model_option, self.STORE_TFT_CHECKPOINT, self.STORE_PROPHET_JSON)


if __name__ == '__main__':
    # Smoke test: the class has no ``load`` method, so exercise every real loader instead.
    obj = Model_Load()
    for option in Model_Load.SUPPORTED_OPTIONS:
        obj.energy_model_load(option)
        obj.store_model_load(option)