# NOTE: page-scrape residue from the hosting site's header, not part of the program:
# "Spaces: Runtime error"
import math
import os
import tempfile
import warnings

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import (
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.integrations import INTEGRATION_TO_CALLBACK
from tsfm_public import (
    TimeSeriesPreprocessor,
    TrackingCallback,
    count_parameters,
    get_datasets,
)

# For M4 Hourly Example
from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
from tsfm_public.toolkit.get_model import get_model
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions
# Silence library warnings so they do not clutter the dashboard output.
warnings.filterwarnings("ignore")

# Fixed seed, applied once at import time, for reproducible runs.
SEED = 42
set_seed(SEED)

# Default checkpoint and window sizes used by the zero-shot mode.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
DEFAULT_CONTEXT_LENGTH = 512
DEFAULT_PREDICTION_LENGTH = 96

# Root folder for generated plots and training artefacts; created eagerly.
OUT_DIR = "dashboard_outputs"
os.makedirs(OUT_DIR, exist_ok=True)
# --------------------------
# Helper: Interactive Plot
def interactive_plot(actual, forecast, title="Forecast vs Actual", channel=0):
    """Build a Plotly line chart overlaying an actual series and its forecast.

    Inputs are coerced with ``np.asarray``. When an input has more than one
    axis, column ``channel`` along its second axis is plotted, so multi-channel
    windows no longer make ``pd.DataFrame`` raise. Both series are truncated
    to their common length so mismatched inputs still render.

    Args:
        actual: Observed values (sequence, array, or array-convertible tensor).
        forecast: Predicted values, same accepted forms as ``actual``.
        title: Chart title.
        channel: Column to plot when an input is multi-dimensional.

    Returns:
        The figure produced by ``plotly.express.line``.
    """

    def _as_series(values):
        series = np.asarray(values)
        if series.ndim > 1:
            series = series[:, channel]
        return series.reshape(-1)

    actual_series = _as_series(actual)
    forecast_series = _as_series(forecast)
    # A DataFrame needs equally long columns; keep only the overlapping steps.
    n_points = min(len(actual_series), len(forecast_series))
    df = pd.DataFrame(
        {
            "Time": range(n_points),
            "Actual": actual_series[:n_points],
            "Forecast": forecast_series[:n_points],
        }
    )
    return px.line(df, x="Time", y=["Actual", "Forecast"], title=title)
# --------------------------
# Mode 1: Zero-shot Evaluation
def _first_channel(window, channel=0):
    """Return one channel of a forecast window as a flat NumPy array."""
    values = np.asarray(window)
    if values.ndim > 1:
        values = values[:, channel]
    return values.reshape(-1)


def _lookup_actual_window(dataset, index):
    """Return the ground-truth horizon stored in ``dataset[index]``, or None."""
    try:
        sample = dataset[index]
    except (IndexError, KeyError):
        return None
    if isinstance(sample, dict):
        # NOTE(review): tsfm_public forecasting datasets are expected to expose
        # the horizon as "future_values" - confirm against the installed
        # version. "target" is kept because the original code looked for it.
        for key in ("future_values", "target"):
            if key in sample:
                return sample[key]
        return None
    try:
        return sample[0]
    except (IndexError, KeyError, TypeError):
        return None


def run_zero_shot_forecasting(
    data,
    context_length,
    prediction_length,
    batch_size,
    selected_target_columns,
    selected_conditional_columns,
    rolling_forecast_extension,
    selected_forecast_index,
):
    """Evaluate the pre-trained TTM model on ``data`` without fine-tuning.

    The frame is split 70/10/20 into train/valid/test, pre-processed with a
    ``TimeSeriesPreprocessor``, scored with ``Trainer.evaluate`` on the test
    split, and one forecast window is plotted interactively and statically.

    Args:
        data: DataFrame with a ``"date"`` timestamp column.
        context_length: History length passed to the preprocessor and model.
        prediction_length: Forecast horizon passed to the preprocessor and model.
        batch_size: Evaluation batch size.
        selected_target_columns: Columns to forecast; when empty, every column
            except the timestamp and the control columns is used.
        selected_conditional_columns: Columns passed as ``control_columns``.
        rolling_forecast_extension: Extra steps requested; only reported (the
            rolling logic is a placeholder).
        selected_forecast_index: Index of the test window to plot; values
            outside the available windows are clamped with a warning.

    Returns:
        None. Results and errors are rendered through Streamlit.
    """
    plot_channel = 0  # Channel shown in both the interactive and static plots.

    st.write("### Preparing Data for Forecasting")
    timestamp_column = "date"
    id_columns = []  # Modify if needed.

    # Incorporate exogenous/control columns.
    conditional_columns = list(selected_conditional_columns or [])
    if selected_target_columns:
        target_columns = list(selected_target_columns)
    else:
        # Default to every remaining column, but never reuse a control column
        # as a target.
        target_columns = [
            col
            for col in data.columns
            if col != timestamp_column and col not in conditional_columns
        ]
    if not target_columns:
        st.error("No target columns available for forecasting.")
        return

    # Define column specifiers (if your preprocessor supports static columns, add here)
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": target_columns,
        "control_columns": conditional_columns,
    }
    n = len(data)
    split_config = {
        "train": [0, int(n * 0.7)],
        "valid": [int(n * 0.7), int(n * 0.8)],
        "test": [int(n * 0.8), n],
    }
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=prediction_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    dset_train, dset_valid, dset_test = get_datasets(tsp, data, split_config)
    st.write("Data split into train, validation, and test sets.")

    st.write("### Loading the Pre-trained TTM Model")
    model = get_model(
        TTM_MODEL_PATH,
        context_length=context_length,
        prediction_length=prediction_length,
    )

    # The Trainer only needs a scratch output directory; remove it afterwards
    # instead of leaving one temp folder behind per run.
    with tempfile.TemporaryDirectory() as temp_dir:
        training_args = TrainingArguments(
            output_dir=temp_dir,
            per_device_eval_batch_size=batch_size,
            seed=SEED,
            report_to="none",
        )
        trainer = Trainer(model=model, args=training_args)

        st.write("### Running Zero-shot Evaluation")
        st.info("Evaluating on the test set...")
        eval_output = trainer.evaluate(dset_test)
        st.write("**Zero-shot Evaluation Metrics:**")
        st.json(eval_output)

        st.write("### Generating Forecast Predictions")
        predictions_dict = trainer.predict(dset_test)

    try:
        predictions_np = predictions_dict.predictions[0]
    except Exception as e:
        st.error("Error extracting predictions: " + str(e))
        return
    st.write("Predictions shape:", predictions_np.shape)

    if rolling_forecast_extension > 0:
        st.write(
            f"### Rolling Forecast Extension: {rolling_forecast_extension} extra steps"
        )
        st.info("Rolling forecast logic can be implemented here.")

    # Interactive plot for a selected forecast index. The UI slider is bounded
    # by the number of rows, which can exceed the number of test windows.
    num_windows = len(predictions_np)
    if num_windows == 0:
        st.warning("The test split produced no forecast windows to plot.")
        return
    idx = int(selected_forecast_index)
    if not 0 <= idx < num_windows:
        clamped = min(max(idx, 0), num_windows - 1)
        st.warning(
            f"Forecast index {idx} is outside the {num_windows} available "
            f"test windows; showing index {clamped} instead."
        )
        idx = clamped

    forecast_series = _first_channel(predictions_np[idx], plot_channel)
    actual_window = _lookup_actual_window(dset_test, idx)
    if actual_window is None:
        # Keep the original fallback, but tell the user the lines coincide.
        st.warning(
            "Ground-truth values are unavailable for this window; the chart "
            "shows the forecast for both series."
        )
        actual_series = forecast_series
    else:
        actual_series = _first_channel(actual_window, plot_channel)
    # Equal-length 1-D inputs are what the plotting helper's DataFrame needs.
    common_length = min(len(actual_series), len(forecast_series))
    fig = interactive_plot(
        actual_series[:common_length],
        forecast_series[:common_length],
        title=f"Forecast vs Actual for index {idx}",
    )
    st.plotly_chart(fig)

    # Static plots (generated via plot_predictions)
    plot_dir = os.path.join(OUT_DIR, "zero_shot_plots")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=trainer.model,
            dset=dset_test,
            plot_dir=plot_dir,
            plot_prefix="test_zeroshot",
            indices=[idx],
            channel=plot_channel,
        )
    except Exception as e:
        st.error("Error during static plotting: " + str(e))
        return
    # NOTE(review): only .png files are shown, including any left over from
    # earlier runs - confirm the file format plot_predictions writes.
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
# --------------------------
# Mode 2: Channel-Mix Finetuning Example
def run_channel_mix_finetuning():
    """Fine-tune a channel-mixing TTM decoder on the bike sharing data.

    Downloads the hourly bike sharing CSV, splits it 50/25/25, loads TTM-R1
    with ``decoder_mode="mix_channel"``, freezes the backbone, trains with
    AdamW + OneCycleLR and early stopping, evaluates on the test split, and
    displays the generated prediction plots.

    Returns:
        None. Progress, metrics and errors are rendered through Streamlit; a
        dataset download or plotting failure ends the run early.
    """
    st.write("## Channel-Mix Finetuning Example (Bike Sharing Data)")

    # Load bike sharing dataset
    target_dataset = "bike_sharing"
    DATA_ROOT_PATH = (
        "https://raw.githubusercontent.com/blobibob/bike-sharing-dataset/main/hour.csv"
    )
    timestamp_column = "dteday"
    id_columns = []
    try:
        data = pd.read_csv(DATA_ROOT_PATH, parse_dates=[timestamp_column])
    except Exception as e:
        st.error("Error loading bike sharing dataset: " + str(e))
        return
    data[timestamp_column] = pd.to_datetime(data[timestamp_column])
    # Adjust timestamps (to add hourly information): rows sharing a calendar
    # day get consecutive hour offsets 0, 1, 2, ... in file order.
    data[timestamp_column] = data[timestamp_column] + pd.to_timedelta(
        data.groupby(data[timestamp_column].dt.date).cumcount(), unit="h"
    )
    st.write("### Bike Sharing Data Preview")
    st.dataframe(data.head())

    # Define columns: targets and conditional (exogenous) channels
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": ["casual", "registered", "cnt"],
        "conditional_columns": [
            "season",
            "yr",
            "mnth",
            "holiday",
            "weekday",
            "workingday",
            "weathersit",
            "temp",
            "atemp",
            "hum",
            "windspeed",
        ],
    }
    n = len(data)
    split_config = {
        "train": [0, int(n * 0.5)],
        "valid": [int(n * 0.5), int(n * 0.75)],
        "test": [int(n * 0.75), n],
    }
    context_length = 512
    forecast_length = 96
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    train_dataset, valid_dataset, test_dataset = get_datasets(tsp, data, split_config)
    st.write("Data split completed.")

    # For channel-mix finetuning, we use TTM-R1 (as per provided script)
    TTM_MODEL_PATH_CM = "ibm-granite/granite-timeseries-ttm-r1"
    finetune_forecast_model = get_model(
        TTM_MODEL_PATH_CM,
        context_length=context_length,
        prediction_length=forecast_length,
        num_input_channels=tsp.num_input_channels,
        decoder_mode="mix_channel",
        prediction_channel_indices=tsp.prediction_channel_indices,
    )
    st.write(
        "Number of params before freezing backbone:",
        count_parameters(finetune_forecast_model),
    )
    # Freeze the backbone so its weights receive no gradient updates.
    for param in finetune_forecast_model.backbone.parameters():
        param.requires_grad = False
    st.write(
        "Number of params after freezing backbone:",
        count_parameters(finetune_forecast_model),
    )

    num_epochs = 50
    batch_size = 64
    learning_rate = 0.001
    optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)
    scheduler = OneCycleLR(
        optimizer,
        learning_rate,
        epochs=num_epochs,
        steps_per_epoch=math.ceil(len(train_dataset) / batch_size),
    )

    out_dir = os.path.join(OUT_DIR, target_dataset)
    os.makedirs(out_dir, exist_ok=True)

    # Newer transformers releases call this argument ``eval_strategy`` and
    # reject ``evaluation_strategy``; older ones only know the latter. Pick
    # whichever field the installed TrainingArguments dataclass declares.
    eval_strategy_key = (
        "eval_strategy"
        if "eval_strategy" in getattr(TrainingArguments, "__dataclass_fields__", {})
        else "evaluation_strategy"
    )
    finetune_args = TrainingArguments(
        output_dir=os.path.join(out_dir, "output"),
        overwrite_output_dir=True,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        do_eval=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        dataloader_num_workers=8,
        report_to="none",
        save_strategy="epoch",
        logging_strategy="epoch",
        save_total_limit=1,
        logging_dir=os.path.join(out_dir, "logs"),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        seed=SEED,
        **{eval_strategy_key: "epoch"},
    )
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=10,
        early_stopping_threshold=1e-5,
    )
    tracking_callback = TrackingCallback()
    finetune_trainer = Trainer(
        model=finetune_forecast_model,
        args=finetune_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        callbacks=[early_stopping_callback, tracking_callback],
        optimizers=(optimizer, scheduler),
    )
    # Drop the codecarbon integration callback when the mapping provides one;
    # a missing entry must not abort the run with a KeyError.
    codecarbon_callback = INTEGRATION_TO_CALLBACK.get("codecarbon")
    if codecarbon_callback is not None:
        finetune_trainer.remove_callback(codecarbon_callback)

    st.write("Starting channel-mix finetuning...")
    finetune_trainer.train()
    st.write("Evaluating finetuned model on test set...")
    eval_output = finetune_trainer.evaluate(test_dataset)
    st.write("Few-shot (channel-mix) evaluation metrics:")
    st.json(eval_output)

    # Plot predictions
    plot_dir = os.path.join(out_dir, "channel_mix_plots")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=finetune_trainer.model,
            dset=test_dataset,
            plot_dir=plot_dir,
            plot_prefix="test_channel_mix",
            indices=[0],
            channel=0,
        )
    except Exception as e:
        st.error("Error plotting channel mix predictions: " + str(e))
        return
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
# --------------------------
# Mode 3: M4 Hourly Example
def run_m4_hourly_example():
    """Run a simplified zero-shot evaluation labelled as the M4 hourly example.

    The data currently comes from a placeholder URL pointing at an ETTh1 CSV,
    not from a real M4 hourly file. The frame is split 70/15/15, the TTM-V1
    checkpoint is loaded with a 48-step prediction filter, the test split is
    scored with ``Trainer.evaluate``, and the prediction plots are displayed.

    Returns:
        None. Metrics and errors are rendered through Streamlit; a dataset
        download or plotting failure ends the run early.
    """
    st.write("## M4 Hourly Example")
    st.info("This example reproduces a simplified version of the M4 hourly evaluation.")

    # For demonstration, we attempt to load an M4 hourly dataset from a URL.
    # (In practice, you would need to download and prepare the dataset.)
    M4_DATASET_URL = "https://raw.githubusercontent.com/IBM/TSFM-public/main/tsfm_public/notebooks/ETTh1.csv"  # Placeholder URL
    try:
        m4_data = pd.read_csv(M4_DATASET_URL, parse_dates=["date"])
    except Exception as e:
        st.error("Could not load M4 hourly dataset: " + str(e))
        return
    st.write("### M4 Hourly Data Preview")
    st.dataframe(m4_data.head())

    context_length = 512
    forecast_length = 48  # M4 hourly forecast horizon
    timestamp_column = "date"
    id_columns = []
    target_columns = [col for col in m4_data.columns if col != timestamp_column]
    n = len(m4_data)
    split_config = {
        "train": [0, int(n * 0.7)],
        "valid": [int(n * 0.7), int(n * 0.85)],
        "test": [int(n * 0.85), n],
    }
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": target_columns,
        "control_columns": [],
    }
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    dset_train, dset_valid, dset_test = get_datasets(tsp, m4_data, split_config)
    st.write("Data split completed.")

    # Load model from Hugging Face TTM Model Repository (TTM-V1 for M4)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TinyTimeMixerForPrediction.from_pretrained(
        "ibm-granite/granite-timeseries-ttm-v1",
        revision="main",
        prediction_filter_length=forecast_length,
    ).to(device)

    st.write("Running zero-shot evaluation on M4 hourly data...")
    # The Trainer only needs a scratch output directory; remove it afterwards
    # instead of leaving one temp folder behind per run.
    with tempfile.TemporaryDirectory() as temp_dir:
        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir=temp_dir,
                per_device_eval_batch_size=64,
                report_to="none",
            ),
        )
        eval_output = trainer.evaluate(dset_test)
    st.write("Zero-shot evaluation metrics on M4 hourly:")
    st.json(eval_output)

    plot_dir = os.path.join(OUT_DIR, "m4_hourly", "zero_shot")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=trainer.model,
            dset=dset_test,
            plot_dir=plot_dir,
            plot_prefix="m4_zero_shot",
            indices=[0],
            channel=0,
        )
    except Exception as e:
        st.error("Error plotting M4 zero-shot predictions: " + str(e))
        return
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
    st.info("Fine-tuning on M4 hourly data can be added similarly.")
# --------------------------
# Main UI
def _load_zero_shot_dataset():
    """Render the dataset-source widgets and load the chosen dataset.

    Returns:
        ``(data, selected_target_columns)`` on success, or None when nothing
        can be loaded yet (no upload, or a read error already shown to the
        user).
    """
    # Allow user to choose dataset source
    dataset_source = st.radio(
        "Dataset Source", options=["Default (ETTh1)", "Upload CSV"]
    )
    if dataset_source == "Default (ETTh1)":
        DATASET_PATH = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"
        try:
            data = pd.read_csv(DATASET_PATH, parse_dates=["date"])
        except Exception as e:
            # Include the cause, as the other loaders in this file do.
            st.error("Error loading default dataset: " + str(e))
            return None
        st.write("### Default Dataset Preview")
        st.dataframe(data.head())
        selected_target_columns = [
            "HUFL",
            "HULL",
            "MUFL",
            "MULL",
            "LUFL",
            "LULL",
            "OT",
        ]
        return data, selected_target_columns

    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if not uploaded_file:
        st.info("Awaiting CSV file upload.")
        return None
    try:
        # Fails when the file is malformed or has no "date" column.
        data = pd.read_csv(uploaded_file, parse_dates=["date"])
    except Exception as e:
        st.error(
            "Error reading uploaded CSV (a 'date' column is required): " + str(e)
        )
        return None
    st.write("### Uploaded Data Preview")
    st.dataframe(data.head())
    available_columns = [col for col in data.columns if col != "date"]
    selected_target_columns = st.multiselect(
        "Select Target Column(s)",
        options=available_columns,
        default=available_columns,
    )
    return data, selected_target_columns


def main():
    """Render the dashboard and dispatch to the selected forecasting mode."""
    st.title("Interactive Time-Series Forecasting Dashboard")
    st.markdown(
        """
        This dashboard lets you run advanced forecasting experiments using the Granite-TimeSeries-TTM model.
        Select one of the modes below:
        - **Zero-shot Evaluation**
        - **Channel-Mix Finetuning Example**
        - **M4 Hourly Example**
        """
    )
    mode = st.selectbox(
        "Select Evaluation Mode",
        options=[
            "Zero-shot Evaluation",
            "Channel-Mix Finetuning Example",
            "M4 Hourly Example",
        ],
    )
    if mode == "Zero-shot Evaluation":
        loaded = _load_zero_shot_dataset()
        if loaded is None:
            return
        data, selected_target_columns = loaded
        if data.empty:
            # An empty frame would give the index slider max_value < min_value.
            st.error("The selected dataset contains no rows.")
            return

        # Advanced options
        available_exog = [
            col
            for col in data.columns
            if col not in (["date"] + selected_target_columns)
        ]
        selected_conditional_columns = st.multiselect(
            "Select Exogenous/Control Columns", options=available_exog, default=[]
        )
        rolling_extension = st.number_input(
            "Rolling Forecast Extension (Extra Steps)", value=0, min_value=0, step=1
        )
        forecast_index = st.slider(
            "Select Forecast Index for Plotting",
            min_value=0,
            max_value=len(data) - 1,
            value=0,
        )
        # Lengths and batch size must be positive; reject 0/negatives in the UI.
        context_length = st.number_input(
            "Context Length", value=DEFAULT_CONTEXT_LENGTH, min_value=1, step=64
        )
        prediction_length = st.number_input(
            "Prediction Length", value=DEFAULT_PREDICTION_LENGTH, min_value=1, step=1
        )
        batch_size = st.number_input("Batch Size", value=64, min_value=1, step=1)
        if st.button("Run Zero-shot Evaluation"):
            with st.spinner("Running zero-shot evaluation..."):
                run_zero_shot_forecasting(
                    data,
                    context_length,
                    prediction_length,
                    batch_size,
                    selected_target_columns,
                    selected_conditional_columns,
                    rolling_extension,
                    forecast_index,
                )
    elif mode == "Channel-Mix Finetuning Example":
        if st.button("Run Channel-Mix Finetuning Example"):
            with st.spinner("Running channel-mix finetuning..."):
                run_channel_mix_finetuning()
    elif mode == "M4 Hourly Example":
        if st.button("Run M4 Hourly Example"):
            with st.spinner("Running M4 hourly example..."):
                run_m4_hourly_example()


if __name__ == "__main__":
    main()