# Hugging Face file-view residue (not part of the program), kept for provenance:
#   azrai99's picture
#   Update app.py
#   971fffe verified
#   raw
#   history blame
#   15.6 kB
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from neuralforecast.core import NeuralForecast
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
from neuralforecast.losses.pytorch import HuberMQLoss
from neuralforecast.utils import AirPassengersDF
import time
from st_aggrid import AgGrid
from nixtla import NixtlaClient
import os
@st.cache_resource
def load_model(path, freq):
    """Load a saved NeuralForecast object from ``path``.

    Args:
        path: Directory handed to ``NeuralForecast.load``.
        freq: Frequency key of the checkpoint. The body does not read it;
            it is only part of the call signature used by callers.

    Returns:
        Whatever ``NeuralForecast.load(path=path)`` returns.
    """
    return NeuralForecast.load(path=path)
@st.cache_resource
def load_all_models():
    """Load every pre-trained checkpoint, grouped by model family.

    Returns:
        A 4-tuple ``(nhits, timesnet, lstm, tft)``. Each element is a dict
        keyed by ``'D'``, ``'M'``, ``'H'``, ``'W'``, ``'Y'`` whose values are
        the objects returned by ``load_model`` for that checkpoint directory.
    """
    # One table instead of four parallel ones; insertion order keeps the
    # loading sequence NHITS -> TimesNet -> LSTM -> TFT, each D, M, H, W, Y.
    checkpoint_dirs = {
        'NHITS': {
            'D': './M4/NHITS/daily',
            'M': './M4/NHITS/monthly',
            'H': './M4/NHITS/hourly',
            'W': './M4/NHITS/weekly',
            'Y': './M4/NHITS/yearly',
        },
        'TimesNet': {
            'D': './M4/TimesNet/daily',
            'M': './M4/TimesNet/monthly',
            'H': './M4/TimesNet/hourly',
            'W': './M4/TimesNet/weekly',
            'Y': './M4/TimesNet/yearly',
        },
        'LSTM': {
            'D': './M4/LSTM/daily',
            'M': './M4/LSTM/monthly',
            'H': './M4/LSTM/hourly',
            'W': './M4/LSTM/weekly',
            'Y': './M4/LSTM/yearly',
        },
        'TFT': {
            'D': './M4/TFT/daily',
            'M': './M4/TFT/monthly',
            'H': './M4/TFT/hourly',
            'W': './M4/TFT/weekly',
            'Y': './M4/TFT/yearly',
        },
    }

    loaded = {}
    for family, dirs_by_freq in checkpoint_dirs.items():
        loaded[family] = {
            freq: load_model(path, freq) for freq, path in dirs_by_freq.items()
        }

    return loaded['NHITS'], loaded['TimesNet'], loaded['LSTM'], loaded['TFT']
def generate_forecast(model, df, tag=False):
    """Run ``model.predict`` and return its result.

    Args:
        model: Object exposing ``predict``.
        df: Data passed as ``predict(df=df)`` unless ``tag == 'retrain'``.
        tag: When equal to the string ``'retrain'``, ``predict()`` is called
            with no arguments; any other value forwards ``df``.

    Returns:
        The value returned by ``model.predict``.
    """
    if tag != 'retrain':
        return model.predict(df=df)
    return model.predict()
def determine_frequency(df):
    """Infer the sampling frequency of ``df['ds']``.

    Side effect: ``df['ds']`` is overwritten in place with its
    ``pd.to_datetime`` conversion, so the caller's frame carries datetime
    values afterwards.

    Args:
        df: DataFrame with a ``'ds'`` column.

    Returns:
        The value returned by ``pd.infer_freq`` for the ``'ds'`` index.
    """
    df['ds'] = pd.to_datetime(df['ds'])
    time_index = df.set_index('ds').index
    return pd.infer_freq(time_index)
def plot_forecasts(forecast_df, train_df, title):
    """Render history plus forecast (and a 90% band, if present) in Streamlit.

    The two frames are concatenated and indexed by ``'ds'``. The forecast
    line is the first column whose name contains ``'median'``; the band uses
    the first columns containing ``'lo-90'`` and ``'hi-90'`` when both exist.

    Args:
        forecast_df: Forecast frame with a ``'ds'`` column.
        train_df: Historical frame with ``'ds'`` and ``'y'`` columns.
        title: Plot title.

    Raises:
        KeyError: If no column name contains ``'median'``.
    """
    plot_df = pd.concat([train_df, forecast_df]).set_index('ds')

    historical_col = 'y'
    forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
    lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
    hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)

    # Validate before allocating the figure so a failure leaves nothing open.
    if forecast_col is None:
        raise KeyError("No forecast column found in the data.")

    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    try:
        plot_df[[historical_col, forecast_col]].plot(ax=ax, linewidth=2, label=['Historical', 'Forecast'])

        if lo_col and hi_col:
            ax.fill_between(
                plot_df.index,
                plot_df[lo_col],
                plot_df[hi_col],
                color='blue',
                alpha=0.3,
                label='90% Confidence Interval'
            )

        ax.set_title(title, fontsize=22)
        ax.set_ylabel('Value', fontsize=20)
        ax.set_xlabel('Timestamp [t]', fontsize=20)
        ax.legend(prop={'size': 15})
        ax.grid()

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until closed; without
        # this, figures accumulate on each Streamlit rerun.
        plt.close(fig)
# Base pandas offset alias -> key used by the pre-trained model dicts.
# Both older and newer pandas spellings are listed (e.g. 'M'/'ME', 'H'/'h',
# 'A'/'Y'/'YE') because ``pd.infer_freq`` output depends on the pandas version.
# Matching is case-sensitive on purpose: 'ms' is milliseconds, 'MS' month start.
_FREQ_ALIAS_TO_MODEL_KEY = {
    'D': 'D',
    'M': 'M', 'ME': 'M', 'MS': 'M',
    'H': 'H', 'h': 'H',
    'W': 'W',
    'Y': 'Y', 'YE': 'Y', 'YS': 'Y', 'A': 'Y', 'AS': 'Y',
}


def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Pick the pre-trained model of each family matching a data frequency.

    Any anchoring suffix (``'W-SUN'``, ``'Y-DEC'``, ``'YE-DEC'``, ...) is
    dropped before the alias is looked up, so every anchor of a supported
    base frequency is accepted.

    Args:
        freq: pandas frequency string, e.g. as returned by ``pd.infer_freq``.
        nhits_models: Mapping with keys ``'D'``, ``'M'``, ``'H'``, ``'W'``, ``'Y'``.
        timesnet_models: Same layout, TimesNet models.
        lstm_models: Same layout, LSTM models.
        tft_models: Same layout, TFT models.

    Returns:
        Tuple ``(nhits, timesnet, lstm, tft)`` for the matched frequency key.

    Raises:
        ValueError: If ``freq`` is not a string or is not a supported alias
            (this includes ``None``).
    """
    base_alias = freq.split('-', 1)[0] if isinstance(freq, str) else None
    model_key = _FREQ_ALIAS_TO_MODEL_KEY.get(base_alias)
    if model_key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return (
        nhits_models[model_key],
        timesnet_models[model_key],
        lstm_models[model_key],
        tft_models[model_key],
    )
def select_model(horizon, model_type, max_steps=50):
    """Build an untrained model of the requested family.

    Every family uses ``HuberMQLoss(level=[90])`` and ``scaler_type='standard'``.
    NHITS, TimesNet and LSTM use an input window of five horizons; TFT uses
    an input window of one horizon.

    Args:
        horizon: Forecast horizon ``h``.
        model_type: One of ``'NHITS'``, ``'TimesNet'``, ``'LSTM'``, ``'TFT'``.
        max_steps: Training steps forwarded to the model (default 50).

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``model_type`` is not one of the four supported names.
    """
    if model_type == 'NHITS':
        nhits_settings = dict(
            input_size=5 * horizon,
            h=horizon,
            max_steps=max_steps,
            stack_types=['identity', 'identity', 'identity'],
            n_blocks=[1, 1, 1],
            mlp_units=[[256, 256], [256, 256], [256, 256]],
            n_pool_kernel_size=[1, 1, 1],
            batch_size=32,
            scaler_type='standard',
            n_freq_downsample=[12, 4, 1],
            loss=HuberMQLoss(level=[90]),
        )
        return NHITS(**nhits_settings)

    if model_type == 'TimesNet':
        timesnet_settings = dict(
            h=horizon,
            input_size=horizon * 5,
            hidden_size=16,
            conv_hidden_size=32,
            loss=HuberMQLoss(level=[90]),
            scaler_type='standard',
            learning_rate=1e-3,
            max_steps=max_steps,
            val_check_steps=200,
            valid_batch_size=64,
            windows_batch_size=128,
            inference_windows_batch_size=512,
        )
        return TimesNet(**timesnet_settings)

    if model_type == 'LSTM':
        lstm_settings = dict(
            h=horizon,
            input_size=horizon * 5,
            loss=HuberMQLoss(level=[90]),
            scaler_type='standard',
            encoder_n_layers=2,
            encoder_hidden_size=64,
            context_size=10,
            decoder_hidden_size=64,
            decoder_layers=2,
            max_steps=max_steps,
        )
        return LSTM(**lstm_settings)

    if model_type == 'TFT':
        tft_settings = dict(
            h=horizon,
            input_size=horizon,
            hidden_size=16,
            loss=HuberMQLoss(level=[90]),
            learning_rate=0.005,
            scaler_type='standard',
            windows_batch_size=128,
            max_steps=max_steps,
            val_check_steps=200,
            valid_batch_size=64,
            enable_progress_bar=True,
        )
        return TFT(**tft_settings)

    raise ValueError(f"Unsupported model type: {model_type}")
def model_train(df, model, freq):
    """Fit a single model on ``df`` and return the fitted wrapper.

    Side effect: ``df['ds']`` is replaced in place by its ``pd.to_datetime``
    conversion before fitting.

    Args:
        df: Training frame with a ``'ds'`` column.
        model: Model instance wrapped as ``NeuralForecast(models=[model], ...)``.
        freq: Frequency forwarded to ``NeuralForecast``.

    Returns:
        The ``NeuralForecast`` object after ``fit(df)`` was called on it.
    """
    forecaster = NeuralForecast(models=[model], freq=freq)
    df['ds'] = pd.to_datetime(df['ds'])
    forecaster.fit(df)
    return forecaster
def forecast_time_series(df, model_type, horizon, max_steps, y_col):
    """Train a fresh model on ``df``, plot its forecast and report the time taken.

    Args:
        df: Frame with ``'ds'`` and ``'y'`` columns (``'ds'`` is converted to
            datetime in place by ``determine_frequency``).
        model_type: Model family name accepted by ``select_model``.
        horizon: Forecast horizon.
        max_steps: Training steps.
        y_col: Original target column name, used only in the plot title.

    Raises:
        ValueError: If no frequency can be inferred from ``df['ds']``, or if
            ``model_type`` is unsupported.
    """
    start_time = time.time()  # Start timing

    freq = determine_frequency(df)
    st.sidebar.write(f"Data frequency: {freq}")
    if freq is None:
        # Fail before training: a missing frequency would otherwise only
        # surface later, inside the forecasting library, with an unclear error.
        raise ValueError(
            "Could not infer a regular frequency from the 'ds' column; "
            "check that the timestamps are evenly spaced."
        )

    selected_model = select_model(horizon, model_type, max_steps)
    model = model_train(df, selected_model, freq)

    st.sidebar.write(f"Generating forecast using {model_type} model...")
    forecast_df = generate_forecast(model, df, tag='retrain')
    plot_forecasts(forecast_df, df, f'{model_type} Forecast for {y_col}')

    time_taken = time.time() - start_time
    st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
@st.cache_data
def load_default():
    """Return a copy of the bundled ``AirPassengersDF`` example dataset."""
    return AirPassengersDF.copy()
def transfer_learning_forecasting():
    """Streamlit page: forecast with the pre-trained models, without retraining.

    Reads an uploaded CSV (or the default dataset), lets the user pick the
    timestamp and target columns, infers the data frequency, selects the
    matching pre-trained model of the chosen family and, on "Submit", plots
    its forecast and reports the elapsed time.
    """
    st.title("Transfer Learning Forecasting")
    nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox(
            "Select Date/Time column",
            options=columns,
            index=columns.index('ds') if 'ds' in columns else 0,
        )
        # Default the target to 'y'. Defaulting it to 'ds' (as before) made both
        # selectors point at the same column, so the rename below lost 'ds'.
        y_col = st.selectbox(
            "Select Target column",
            options=columns,
            index=columns.index('y') if 'y' in columns else 0,
        )
        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col

    # Model selection and forecasting
    st.sidebar.subheader("Model Selection and Forecasting")
    model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
    # NOTE(review): this input is rendered but its value is not read anywhere
    # below — confirm whether it should be wired into the prediction.
    st.sidebar.number_input("Forecast horizon", value=18)

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    df = df[['unique_id', 'ds', 'y']]
    st.session_state.df = df

    # Determine frequency of data
    frequency = determine_frequency(df)
    st.sidebar.write(f"Detected frequency: {frequency}")

    col1, col2 = st.columns([2, 4])

    with col1:
        # The "Forecast" tab is created but currently has no content.
        tab_insample, _tab_forecast = st.tabs(["Input data", "Forecast"])
        with tab_insample:
            st.write(df.drop(columns="unique_id"))

    with col2:
        # Pre-trained model of each family for the detected frequency.
        pretrained = dict(zip(
            ("NHITS", "TimesNet", "LSTM", "TFT"),
            select_model_based_on_frequency(
                frequency, nhits_models, timesnet_models, lstm_models, tft_models
            ),
        ))

        if st.sidebar.button("Submit"):
            start_time = time.time()  # Start timing
            forecast_df = generate_forecast(pretrained[model_choice], df)
            plot_forecasts(forecast_df, df, f'{model_choice} Forecast for {y_col}')
            time_taken = time.time() - start_time
            st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
def dynamic_forecasting():
    """Streamlit page: train a model on the selected data, then forecast.

    Reads an uploaded CSV (or the default dataset), lets the user pick the
    timestamp and target columns, the model family, the horizon and the
    number of training steps, and on "Submit" delegates to
    ``forecast_time_series``.
    """
    st.title("Dynamic Forecasting")
    st.subheader("Speed depends on CPU/GPU availability", divider="gray")

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox(
            "Select Date/Time column",
            options=columns,
            index=columns.index('ds') if 'ds' in columns else 0,
        )
        # Default the target to 'y'. Defaulting it to 'ds' (as before) made both
        # selectors point at the same column, so the rename below lost 'ds'.
        y_col = st.selectbox(
            "Select Target column",
            options=columns,
            index=columns.index('y') if 'y' in columns else 0,
        )

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    df = df[['unique_id', 'ds', 'y']]

    st.session_state.df = df
    st.session_state.ds_col = ds_col
    st.session_state.y_col = y_col

    # Dynamic forecasting
    st.sidebar.subheader("Dynamic Model Selection and Forecasting")
    dynamic_model_choice = st.sidebar.selectbox(
        "Select model for dynamic forecasting",
        ["NHITS", "TimesNet", "LSTM", "TFT"],
        key="dynamic_model_choice",
    )
    dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
    dynamic_max_steps = st.sidebar.number_input('Max steps', value=10)

    if st.sidebar.button("Submit"):
        forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps, y_col)
def timegpt_fcst():
    """Streamlit page: 7-step TimeGPT forecast with a 90% interval.

    The API key is read from the ``NIXTLA_API_KEY`` environment variable.
    Data comes from an uploaded CSV or the default dataset; the user picks
    the timestamp and target columns. The forecast is requested on every
    script run and the resulting plot is shown on the page.
    """
    # The key was previously read into `nixtla_token` but an undefined
    # `api_key` was passed to the client, raising NameError on page load.
    nixtla_token = os.environ.get("NIXTLA_API_KEY")
    nixtla_client = NixtlaClient(api_key=nixtla_token)

    st.title("TimeGPT Forecasting")

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox(
            "Select Date/Time column",
            options=columns,
            index=columns.index('ds') if 'ds' in columns else 0,
        )
        # Default the target to 'y'; defaulting to 'ds' made both selectors
        # pick the same column and the rename below dropped 'ds'.
        y_col = st.selectbox(
            "Select Target column",
            options=columns,
            index=columns.index('y') if 'y' in columns else 0,
        )

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    id_col = 'ts_test'
    df['unique_id'] = id_col
    # Keep only the id/time/target columns, as the NeuralForecast pages do,
    # so unrelated CSV columns are not sent along with the series.
    df = df[['unique_id', 'ds', 'y']]

    freq = determine_frequency(df)

    forecast_df = nixtla_client.forecast(
        df=df,
        h=7,
        freq=freq,
        level=[90]
    )

    # `plot` takes the history first and the forecasts second; its return
    # value has to be handed to Streamlit to appear on the page.
    # NOTE(review): assumes the default matplotlib engine returns a figure
    # accepted by st.pyplot — confirm against the installed nixtla version.
    fig = nixtla_client.plot(
        df,
        forecast_df,
        level=[90],
        max_insample_length=365
    )
    st.pyplot(fig)
def timegpt_anom():
    """Streamlit page: TimeGPT anomaly detection on the selected series.

    The API key is read from the ``NIXTLA_API_KEY`` environment variable.
    Data comes from an uploaded CSV or the default dataset; the user picks
    the timestamp and target columns. Detection is requested on every script
    run and the resulting plot is shown on the page.
    """
    # The key was previously read into `nixtla_token` but an undefined
    # `api_key` was passed to the client, raising NameError on page load.
    nixtla_token = os.environ.get("NIXTLA_API_KEY")
    nixtla_client = NixtlaClient(api_key=nixtla_token)

    # This page is registered as "TimeGPT Anomalies Detection"; it used to be
    # a verbatim copy of the forecast page (title and API call included).
    st.title("TimeGPT Anomaly Detection")

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox(
            "Select Date/Time column",
            options=columns,
            index=columns.index('ds') if 'ds' in columns else 0,
        )
        # Default the target to 'y'; defaulting to 'ds' made both selectors
        # pick the same column and the rename below dropped 'ds'.
        y_col = st.selectbox(
            "Select Target column",
            options=columns,
            index=columns.index('y') if 'y' in columns else 0,
        )

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    id_col = 'ts_test'
    df['unique_id'] = id_col
    # Keep only the id/time/target columns, as the NeuralForecast pages do,
    # so unrelated CSV columns are not sent along with the series.
    df = df[['unique_id', 'ds', 'y']]

    freq = determine_frequency(df)

    # 90% level to match the interval used by the other pages.
    anomalies_df = nixtla_client.detect_anomalies(
        df=df,
        freq=freq,
        level=90
    )

    # NOTE(review): assumes `plot(..., plot_anomalies=True)` picks up the
    # interval columns of `anomalies_df` and returns a figure accepted by
    # st.pyplot — confirm against the installed nixtla version.
    fig = nixtla_client.plot(
        df,
        anomalies_df,
        plot_anomalies=True
    )
    st.pyplot(fig)
# Pages are plain functions, grouped into two navigation sections.
neuralforecast_pages = [
    st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
    st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
]
timegpt_pages = [
    st.Page(timegpt_fcst, title="TimeGPT Forecast", icon=":material/smart_toy:"),
    st.Page(timegpt_anom, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:"),
]

pg = st.navigation({
    "NeuralForecast": neuralforecast_pages,
    "TimeGPT": timegpt_pages,
})

# Top-level boundary: any exception raised while running the selected page is
# reported in the sidebar instead of propagating.
try:
    pg.run()
except Exception as e:
    st.sidebar.error(f"Something went wrong: {e}", icon=":material/error:")