|
import copy |
|
from pathlib import Path |
|
import warnings |
|
import lightning.pytorch as pl |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from prophet.serialize import model_to_json, model_from_json |
|
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet |
|
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters |
|
import pickle |
|
|
|
# Target the first CUDA GPU when torch reports one is available; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
class Model_Load:
    """Loader for the pre-trained energy and store forecasting models.

    Each public method takes a ``model_option`` string and returns the model
    read from a hard-coded path under ``models/``. The paths are relative, so
    they are resolved against the current working directory of the process.
    """

    def __init__(self):
        pass

    @staticmethod
    def _load_prophet(json_path):
        """Return the model that ``model_from_json`` builds from the text of *json_path*."""
        with open(json_path, 'r') as fin:
            return model_from_json(fin.read())

    def energy_model_load(self, model_option):
        """Load the energy model selected by *model_option*.

        Args:
            model_option: ``'TFT'`` loads a ``TemporalFusionTransformer`` from
                its Lightning checkpoint; ``'Prophet'`` loads the model stored
                as JSON in ``models/fb_energy_model.json``.

        Returns:
            The loaded model, or ``None`` when *model_option* is neither
            ``'TFT'`` nor ``'Prophet'``.
        """
        if model_option == 'TFT':
            best_model_path = 'models/consumer_final_10/lightning_logs/lightning_logs/version_0/checkpoints/epoch=5-step=49260.ckpt'
            best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
            print('Model Load Sucessfully.')
            return best_tft

        if model_option == 'Prophet':
            return Model_Load._load_prophet('models/fb_energy_model.json')

        # Unrecognised option: callers have always received None here, so keep
        # that contract but make it explicit.
        return None

    def store_model_load(self, model_option):
        """Load the store model selected by *model_option*.

        Args:
            model_option: ``'TFT'`` unpickles the object saved in
                ``models/cpu_finalized_model_v1.sav``; ``'Prophet'`` loads the
                model stored as JSON in ``models/fb_store_model_new.json``.

        Returns:
            The loaded model, or ``None`` when *model_option* is neither
            ``'TFT'`` nor ``'Prophet'``.
        """
        if model_option == 'TFT':
            # NOTE(security): pickle.load can execute arbitrary code embedded in
            # the file. Only load model artifacts from a trusted source.
            # The context manager closes the file handle even if unpickling fails.
            with open("models/cpu_finalized_model_v1.sav", 'rb') as model_file:
                best_tft = pickle.load(model_file)
            print('Model Load Sucessfully.')
            return best_tft

        if model_option == 'Prophet':
            return Model_Load._load_prophet('models/fb_store_model_new.json')

        # Unrecognised option: same explicit None contract as energy_model_load.
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test. Model_Load defines no load() method, so exercise the
    # loaders it actually provides instead of raising AttributeError.
    # NOTE(review): the choice of 'TFT' for both calls is an assumption;
    # adjust to whichever models should be checked when run as a script.
    obj = Model_Load()
    obj.energy_model_load('TFT')
    obj.store_model_load('TFT')
|
|
|
|