Spaces:
Runtime error
Runtime error
import copy | |
from pathlib import Path | |
import warnings | |
import lightning.pytorch as pl | |
import numpy as np | |
import pandas as pd | |
import torch | |
from prophet.serialize import model_to_json, model_from_json | |
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet | |
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters | |
# Select the compute device once, at import time: the first CUDA GPU when
# torch reports one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
class Model_Load:
    """Load the pre-trained energy and store-item forecasting models from disk.

    Each loader accepts a ``model_option`` selecting the model family:

    * ``'TFT'`` -- a pytorch-forecasting ``TemporalFusionTransformer``
      restored from a Lightning checkpoint file.
    * ``'Prophet'`` -- a Prophet model deserialized from a JSON file with
      ``prophet.serialize.model_from_json``.

    All paths are relative, so they resolve against the current working
    directory of the process.
    """

    # Model file locations for the energy dataset, keyed by ``model_option``.
    _ENERGY_MODEL_PATHS = {
        'TFT': 'models/consumer_final_10/lightning_logs/lightning_logs/version_0/checkpoints/epoch=5-step=49260.ckpt',
        'Prophet': 'models/fb_energy_model.json',
    }

    # Model file locations for the store-item dataset, keyed by ``model_option``.
    _STORE_MODEL_PATHS = {
        'TFT': 'models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt',
        'Prophet': 'models/fb_store_model_new.json',
    }

    def energy_model_load(self, model_option):
        """Load the energy forecasting model selected by *model_option*.

        Args:
            model_option: ``'TFT'`` or ``'Prophet'``.

        Returns:
            The loaded ``TemporalFusionTransformer`` (for ``'TFT'``) or the
            deserialized Prophet model (for ``'Prophet'``).

        Raises:
            ValueError: if *model_option* is not a supported option.
        """
        return self._load(model_option, self._ENERGY_MODEL_PATHS)

    def store_model_load(self, model_option):
        """Load the store-item forecasting model selected by *model_option*.

        Args:
            model_option: ``'TFT'`` or ``'Prophet'``.

        Returns:
            The loaded ``TemporalFusionTransformer`` (for ``'TFT'``) or the
            deserialized Prophet model (for ``'Prophet'``).

        Raises:
            ValueError: if *model_option* is not a supported option.
        """
        return self._load(model_option, self._STORE_MODEL_PATHS)

    def _load(self, model_option, paths):
        """Look up *model_option* in *paths* and dispatch to the matching loader."""
        try:
            model_path = paths[model_option]
        except KeyError:
            # Fail loudly here instead of returning None, which would only
            # surface later as an unrelated AttributeError in the caller.
            raise ValueError(
                f"Unknown model_option {model_option!r}; expected one of {sorted(paths)}"
            ) from None
        if model_option == 'TFT':
            return self._load_tft(model_path)
        return self._load_prophet(model_path)

    @staticmethod
    def _load_tft(checkpoint_path):
        """Restore a TemporalFusionTransformer from a Lightning checkpoint."""
        # map_location places the checkpoint tensors on the module-level
        # ``device`` (cuda:0 or cpu), so a checkpoint saved on a GPU can still
        # be restored on a CPU-only host.
        model = TemporalFusionTransformer.load_from_checkpoint(
            checkpoint_path, map_location=device
        )
        print('Model Load Sucessfully.')
        return model

    @staticmethod
    def _load_prophet(json_path):
        """Deserialize a Prophet model from the JSON file at *json_path*."""
        with open(json_path, 'r', encoding='utf-8') as fin:
            return model_from_json(fin.read())
if __name__ == '__main__':
    # Smoke test: load every model the loader supports so that missing files
    # or incompatible checkpoints surface immediately when run as a script.
    loader = Model_Load()
    for option in ('TFT', 'Prophet'):
        loader.energy_model_load(option)
        loader.store_model_load(option)