fraware committed
Commit 3389ae7 · 1 Parent(s): 4a982ac

Added Docker support for Hugging Face Spaces

Files changed (4)
  1. Dockerfile +22 -0
  2. EETh1.csv +0 -0
  3. app.py +528 -0
  4. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.9-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy the requirements file and install dependencies
+ COPY requirements.txt requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application code
+ COPY . /app
+
+ # Expose the port that Streamlit uses
+ EXPOSE 8501
+
+ # Command to run the Streamlit app
+ CMD ["streamlit", "run", "app.py", "--server.enableCORS", "false"]
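A quick local smoke test for this image (the ttm-dashboard tag is illustrative, not part of this commit):

docker build -t ttm-dashboard .
docker run -p 8501:8501 ttm-dashboard

Note that Hugging Face Docker Spaces route traffic to port 7860 unless the Space README front matter sets app_port, so a Space using this Dockerfile presumably also sets app_port: 8501; that file is not part of this diff.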
EETh1.csv ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,528 @@
+ import os
+ import math
+ import tempfile
+ import warnings
+ import streamlit as st
+ import pandas as pd
+ import torch
+ import plotly.express as px
+
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import OneCycleLR
+ from transformers import (
+     EarlyStoppingCallback,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+ from transformers.integrations import INTEGRATION_TO_CALLBACK
+
+ from tsfm_public import (
+     TimeSeriesPreprocessor,
+     TrackingCallback,
+     count_parameters,
+     get_datasets,
+ )
+ from tsfm_public.toolkit.get_model import get_model
+ from tsfm_public.toolkit.lr_finder import optimal_lr_finder
+ from tsfm_public.toolkit.visualization import plot_predictions
+
+ # For M4 Hourly Example
+ from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
+
+ # Suppress warnings and set a reproducible seed
+ warnings.filterwarnings("ignore")
+ SEED = 42
+ set_seed(SEED)
+
+ # Default model parameters and output directory
+ TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
+ DEFAULT_CONTEXT_LENGTH = 512
+ DEFAULT_PREDICTION_LENGTH = 96
+ OUT_DIR = "dashboard_outputs"
+ os.makedirs(OUT_DIR, exist_ok=True)
+
+
+ # --------------------------
+ # Helper: Interactive Plot
+ def interactive_plot(actual, forecast, title="Forecast vs Actual"):
+     df = pd.DataFrame(
+         {"Time": range(len(actual)), "Actual": actual, "Forecast": forecast}
+     )
+     fig = px.line(df, x="Time", y=["Actual", "Forecast"], title=title)
+     return fig
+
+
+ # --------------------------
+ # Mode 1: Zero-shot Evaluation
+ def run_zero_shot_forecasting(
+     data,
+     context_length,
+     prediction_length,
+     batch_size,
+     selected_target_columns,
+     selected_conditional_columns,
+     rolling_forecast_extension,
+     selected_forecast_index,
+ ):
+     st.write("### Preparing Data for Forecasting")
+     timestamp_column = "date"
+     id_columns = []  # Modify if needed.
+     # Use selected target columns; default to all columns (except "date") if not provided.
+     if not selected_target_columns:
+         target_columns = [col for col in data.columns if col != timestamp_column]
+     else:
+         target_columns = selected_target_columns
+
+     # Incorporate exogenous/control columns.
+     conditional_columns = selected_conditional_columns
+
+     # Define column specifiers (if your preprocessor supports static columns, add here)
+     column_specifiers = {
+         "timestamp_column": timestamp_column,
+         "id_columns": id_columns,
+         "target_columns": target_columns,
+         "control_columns": conditional_columns,
+     }
+
+     n = len(data)
+     split_config = {
+         "train": [0, int(n * 0.7)],
+         "valid": [int(n * 0.7), int(n * 0.8)],
+         "test": [int(n * 0.8), n],
+     }
+
+     tsp = TimeSeriesPreprocessor(
+         **column_specifiers,
+         context_length=context_length,
+         prediction_length=prediction_length,
+         scaling=True,
+         encode_categorical=False,
+         scaler_type="standard",
+     )
+     dset_train, dset_valid, dset_test = get_datasets(tsp, data, split_config)
+     st.write("Data split into train, validation, and test sets.")
+
+     st.write("### Loading the Pre-trained TTM Model")
+     model = get_model(
+         TTM_MODEL_PATH,
+         context_length=context_length,
+         prediction_length=prediction_length,
+     )
+     temp_dir = tempfile.mkdtemp()
+     training_args = TrainingArguments(
+         output_dir=temp_dir,
+         per_device_eval_batch_size=batch_size,
+         seed=SEED,
+         report_to="none",
+     )
+     trainer = Trainer(model=model, args=training_args)
+
+     st.write("### Running Zero-shot Evaluation")
+     st.info("Evaluating on the test set...")
+     eval_output = trainer.evaluate(dset_test)
+     st.write("**Zero-shot Evaluation Metrics:**")
+     st.json(eval_output)
+
+     st.write("### Generating Forecast Predictions")
+     predictions_dict = trainer.predict(dset_test)
+     try:
+         predictions_np = predictions_dict.predictions[0]
+     except Exception as e:
+         st.error("Error extracting predictions: " + str(e))
+         return
+     st.write("Predictions shape:", predictions_np.shape)
+
+     if rolling_forecast_extension > 0:
+         st.write(
+             f"### Rolling Forecast Extension: {rolling_forecast_extension} extra steps"
+         )
+         st.info("Rolling forecast logic can be implemented here.")
+
+     # Interactive plot for a selected forecast index.
+     idx = selected_forecast_index
+     try:
+         # This example assumes dset_test[idx] is a dict with a "target" key; adjust as needed.
+         actual = (
+             dset_test[idx]["target"]
+             if isinstance(dset_test[idx], dict)
+             else dset_test[idx][0]
+         )
+     except Exception:
+         actual = predictions_np[idx]  # Fallback if actual is not available.
+     fig = interactive_plot(
+         actual, predictions_np[idx], title=f"Forecast vs Actual for index {idx}"
+     )
+     st.plotly_chart(fig)
+
+     # Static plots (generated via plot_predictions)
+     plot_dir = os.path.join(OUT_DIR, "zero_shot_plots")
+     os.makedirs(plot_dir, exist_ok=True)
+     try:
+         plot_predictions(
+             model=trainer.model,
+             dset=dset_test,
+             plot_dir=plot_dir,
+             plot_prefix="test_zeroshot",
+             indices=[idx],
+             channel=0,
+         )
+     except Exception as e:
+         st.error("Error during static plotting: " + str(e))
+         return
+     for file in os.listdir(plot_dir):
+         if file.endswith(".png"):
+             st.image(os.path.join(plot_dir, file), caption=file)
+
+
+ # --------------------------
+ # Mode 2: Channel-Mix Finetuning Example
+ def run_channel_mix_finetuning():
+     st.write("## Channel-Mix Finetuning Example (Bike Sharing Data)")
+     # Load bike sharing dataset
+     target_dataset = "bike_sharing"
+     DATA_ROOT_PATH = (
+         "https://raw.githubusercontent.com/blobibob/bike-sharing-dataset/main/hour.csv"
+     )
+     timestamp_column = "dteday"
+     id_columns = []
+     try:
+         data = pd.read_csv(DATA_ROOT_PATH, parse_dates=[timestamp_column])
+     except Exception as e:
+         st.error("Error loading bike sharing dataset: " + str(e))
+         return
+     data[timestamp_column] = pd.to_datetime(data[timestamp_column])
+     # Adjust timestamps (to add hourly information)
+     data[timestamp_column] = data[timestamp_column] + pd.to_timedelta(
+         data.groupby(data[timestamp_column].dt.date).cumcount(), unit="h"
+     )
+     st.write("### Bike Sharing Data Preview")
+     st.dataframe(data.head())
+
+     # Define columns: targets and conditional (exogenous) channels
+     column_specifiers = {
+         "timestamp_column": timestamp_column,
+         "id_columns": id_columns,
+         "target_columns": ["casual", "registered", "cnt"],
+         "conditional_columns": [
+             "season",
+             "yr",
+             "mnth",
+             "holiday",
+             "weekday",
+             "workingday",
+             "weathersit",
+             "temp",
+             "atemp",
+             "hum",
+             "windspeed",
+         ],
+     }
+     n = len(data)
+     split_config = {
+         "train": [0, int(n * 0.5)],
+         "valid": [int(n * 0.5), int(n * 0.75)],
+         "test": [int(n * 0.75), n],
+     }
+     context_length = 512
+     forecast_length = 96
+
+     tsp = TimeSeriesPreprocessor(
+         **column_specifiers,
+         context_length=context_length,
+         prediction_length=forecast_length,
+         scaling=True,
+         encode_categorical=False,
+         scaler_type="standard",
+     )
+     train_dataset, valid_dataset, test_dataset = get_datasets(tsp, data, split_config)
+     st.write("Data split completed.")
+
+     # For channel-mix finetuning, we use TTM-R1 (as per provided script)
+     TTM_MODEL_PATH_CM = "ibm-granite/granite-timeseries-ttm-r1"
+     finetune_forecast_model = get_model(
+         TTM_MODEL_PATH_CM,
+         context_length=context_length,
+         prediction_length=forecast_length,
+         num_input_channels=tsp.num_input_channels,
+         decoder_mode="mix_channel",
+         prediction_channel_indices=tsp.prediction_channel_indices,
+     )
+     st.write(
+         "Number of params before freezing backbone:",
+         count_parameters(finetune_forecast_model),
+     )
+     for param in finetune_forecast_model.backbone.parameters():
+         param.requires_grad = False
+     st.write(
+         "Number of params after freezing backbone:",
+         count_parameters(finetune_forecast_model),
+     )
+
+     num_epochs = 50
+     batch_size = 64
+     learning_rate = 0.001
+     optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)
+     scheduler = OneCycleLR(
+         optimizer,
+         learning_rate,
+         epochs=num_epochs,
+         steps_per_epoch=math.ceil(len(train_dataset) / batch_size),
+     )
+     out_dir = os.path.join(OUT_DIR, target_dataset)
+     os.makedirs(out_dir, exist_ok=True)
+     finetune_args = TrainingArguments(
+         output_dir=os.path.join(out_dir, "output"),
+         overwrite_output_dir=True,
+         learning_rate=learning_rate,
+         num_train_epochs=num_epochs,
+         do_eval=True,
+         evaluation_strategy="epoch",
+         per_device_train_batch_size=batch_size,
+         per_device_eval_batch_size=batch_size,
+         dataloader_num_workers=8,
+         report_to="none",
+         save_strategy="epoch",
+         logging_strategy="epoch",
+         save_total_limit=1,
+         logging_dir=os.path.join(out_dir, "logs"),
+         load_best_model_at_end=True,
+         metric_for_best_model="eval_loss",
+         greater_is_better=False,
+         seed=SEED,
+     )
+     early_stopping_callback = EarlyStoppingCallback(
+         early_stopping_patience=10,
+         early_stopping_threshold=1e-5,
+     )
+     tracking_callback = TrackingCallback()
+     finetune_trainer = Trainer(
+         model=finetune_forecast_model,
+         args=finetune_args,
+         train_dataset=train_dataset,
+         eval_dataset=valid_dataset,
+         callbacks=[early_stopping_callback, tracking_callback],
+         optimizers=(optimizer, scheduler),
+     )
+     finetune_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"])
+     st.write("Starting channel-mix finetuning...")
+     finetune_trainer.train()
+     st.write("Evaluating finetuned model on test set...")
+     eval_output = finetune_trainer.evaluate(test_dataset)
+     st.write("Few-shot (channel-mix) evaluation metrics:")
+     st.json(eval_output)
+     # Plot predictions
+     plot_dir = os.path.join(out_dir, "channel_mix_plots")
+     os.makedirs(plot_dir, exist_ok=True)
+     try:
+         plot_predictions(
+             model=finetune_trainer.model,
+             dset=test_dataset,
+             plot_dir=plot_dir,
+             plot_prefix="test_channel_mix",
+             indices=[0],
+             channel=0,
+         )
+     except Exception as e:
+         st.error("Error plotting channel mix predictions: " + str(e))
+         return
+     for file in os.listdir(plot_dir):
+         if file.endswith(".png"):
+             st.image(os.path.join(plot_dir, file), caption=file)
+
+
+ # --------------------------
+ # Mode 3: M4 Hourly Example
+ def run_m4_hourly_example():
+     st.write("## M4 Hourly Example")
+     st.info("This example reproduces a simplified version of the M4 hourly evaluation.")
+     # For demonstration, we attempt to load an M4 hourly dataset from a URL.
+     # (In practice, you would need to download and prepare the dataset.)
+     M4_DATASET_URL = "https://raw.githubusercontent.com/IBM/TSFM-public/main/tsfm_public/notebooks/ETTh1.csv"  # Placeholder URL
+     try:
+         m4_data = pd.read_csv(M4_DATASET_URL, parse_dates=["date"])
+     except Exception as e:
+         st.error("Could not load M4 hourly dataset: " + str(e))
+         return
+     st.write("### M4 Hourly Data Preview")
+     st.dataframe(m4_data.head())
+     context_length = 512
+     forecast_length = 48  # M4 hourly forecast horizon
+     timestamp_column = "date"
+     id_columns = []
+     target_columns = [col for col in m4_data.columns if col != timestamp_column]
+     n = len(m4_data)
+     split_config = {
+         "train": [0, int(n * 0.7)],
+         "valid": [int(n * 0.7), int(n * 0.85)],
+         "test": [int(n * 0.85), n],
+     }
+     column_specifiers = {
+         "timestamp_column": timestamp_column,
+         "id_columns": id_columns,
+         "target_columns": target_columns,
+         "control_columns": [],
+     }
+     tsp = TimeSeriesPreprocessor(
+         **column_specifiers,
+         context_length=context_length,
+         prediction_length=forecast_length,
+         scaling=True,
+         encode_categorical=False,
+         scaler_type="standard",
+     )
+     dset_train, dset_valid, dset_test = get_datasets(tsp, m4_data, split_config)
+     st.write("Data split completed.")
+
+     # Load model from Hugging Face TTM Model Repository (TTM-V1 for M4)
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = TinyTimeMixerForPrediction.from_pretrained(
+         "ibm-granite/granite-timeseries-ttm-v1",
+         revision="main",
+         prediction_filter_length=forecast_length,
+     ).to(device)
+     st.write("Running zero-shot evaluation on M4 hourly data...")
+     temp_dir = tempfile.mkdtemp()
+     trainer = Trainer(
+         model=model,
+         args=TrainingArguments(
+             output_dir=temp_dir,
+             per_device_eval_batch_size=64,
+             report_to="none",
+         ),
+     )
+     eval_output = trainer.evaluate(dset_test)
+     st.write("Zero-shot evaluation metrics on M4 hourly:")
+     st.json(eval_output)
+     plot_dir = os.path.join(OUT_DIR, "m4_hourly", "zero_shot")
+     os.makedirs(plot_dir, exist_ok=True)
+     try:
+         plot_predictions(
+             model=trainer.model,
+             dset=dset_test,
+             plot_dir=plot_dir,
+             plot_prefix="m4_zero_shot",
+             indices=[0],
+             channel=0,
+         )
+     except Exception as e:
+         st.error("Error plotting M4 zero-shot predictions: " + str(e))
+         return
+     for file in os.listdir(plot_dir):
+         if file.endswith(".png"):
+             st.image(os.path.join(plot_dir, file), caption=file)
+     st.info("Fine-tuning on M4 hourly data can be added similarly.")
+
+
+ # --------------------------
+ # Main UI
+ def main():
+     st.title("Interactive Time-Series Forecasting Dashboard")
+     st.markdown(
+         """
+         This dashboard lets you run advanced forecasting experiments using the Granite-TimeSeries-TTM model.
+         Select one of the modes below:
+         - **Zero-shot Evaluation**
+         - **Channel-Mix Finetuning Example**
+         - **M4 Hourly Example**
+         """
+     )
+
+     mode = st.selectbox(
+         "Select Evaluation Mode",
+         options=[
+             "Zero-shot Evaluation",
+             "Channel-Mix Finetuning Example",
+             "M4 Hourly Example",
+         ],
+     )
+
+     if mode == "Zero-shot Evaluation":
+         # Allow user to choose dataset source
+         dataset_source = st.radio(
+             "Dataset Source", options=["Default (ETTh1)", "Upload CSV"]
+         )
+         if dataset_source == "Default (ETTh1)":
+             DATASET_PATH = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"
+             try:
+                 data = pd.read_csv(DATASET_PATH, parse_dates=["date"])
+             except Exception as e:
+                 st.error("Error loading default dataset: " + str(e))
+                 return
+             st.write("### Default Dataset Preview")
+             st.dataframe(data.head())
+             selected_target_columns = [
+                 "HUFL",
+                 "HULL",
+                 "MUFL",
+                 "MULL",
+                 "LUFL",
+                 "LULL",
+                 "OT",
+             ]
+         else:
+             uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
+             if not uploaded_file:
+                 st.info("Awaiting CSV file upload.")
+                 return
+             data = pd.read_csv(uploaded_file, parse_dates=["date"])
+             st.write("### Uploaded Data Preview")
+             st.dataframe(data.head())
+             available_columns = [col for col in data.columns if col != "date"]
+             selected_target_columns = st.multiselect(
+                 "Select Target Column(s)",
+                 options=available_columns,
+                 default=available_columns,
+             )
+
+         # Advanced options
+         available_exog = [
+             col
+             for col in data.columns
+             if col not in (["date"] + selected_target_columns)
+         ]
+         selected_conditional_columns = st.multiselect(
+             "Select Exogenous/Control Columns", options=available_exog, default=[]
+         )
+         rolling_extension = st.number_input(
+             "Rolling Forecast Extension (Extra Steps)", value=0, min_value=0, step=1
+         )
+         forecast_index = st.slider(
+             "Select Forecast Index for Plotting",
+             min_value=0,
+             max_value=len(data) - 1,
+             value=0,
+         )
+         context_length = st.number_input(
+             "Context Length", value=DEFAULT_CONTEXT_LENGTH, step=64
+         )
+         prediction_length = st.number_input(
+             "Prediction Length", value=DEFAULT_PREDICTION_LENGTH, step=1
+         )
+         batch_size = st.number_input("Batch Size", value=64, step=1)
+         if st.button("Run Zero-shot Evaluation"):
+             with st.spinner("Running zero-shot evaluation..."):
+                 run_zero_shot_forecasting(
+                     data,
+                     context_length,
+                     prediction_length,
+                     batch_size,
+                     selected_target_columns,
+                     selected_conditional_columns,
+                     rolling_extension,
+                     forecast_index,
+                 )
+
+     elif mode == "Channel-Mix Finetuning Example":
+         if st.button("Run Channel-Mix Finetuning Example"):
+             with st.spinner("Running channel-mix finetuning..."):
+                 run_channel_mix_finetuning()
+
+     elif mode == "M4 Hourly Example":
+         if st.button("Run M4 Hourly Example"):
+             with st.spinner("Running M4 hourly example..."):
+                 run_m4_hourly_example()
+
+
+ if __name__ == "__main__":
+     main()
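The rolling-forecast extension in run_zero_shot_forecasting above is left as a stub ("Rolling forecast logic can be implemented here."). A minimal sketch of one way to fill it in, assuming the model's forward accepts a past_values tensor of shape (1, context_length, n_channels) and returns an output with a prediction_outputs field, as TinyTimeMixerForPrediction does; the rolling_forecast helper name is hypothetical and not part of this commit:

import torch

def rolling_forecast(model, context, extra_steps, prediction_length):
    # Autoregressively extend a forecast: predict a block, append it to the
    # context window, slide the window forward, and repeat until extra_steps
    # new points have been produced. context: (1, context_length, n_channels).
    model.eval()
    window = context.clone()
    chunks = []
    produced = 0
    with torch.no_grad():
        while produced < extra_steps:
            pred = model(past_values=window).prediction_outputs  # assumed TTM-style output
            take = min(prediction_length, extra_steps - produced)
            chunks.append(pred[:, :take, :])
            # Drop the oldest points and append the newly predicted ones.
            window = torch.cat([window[:, take:, :], pred[:, :take, :]], dim=1)
            produced += take
    return torch.cat(chunks, dim=1)  # (1, extra_steps, n_channels)

Because the window is refilled with the model's own outputs, this applies as written only when every channel is a forecast target; with exogenous/control channels the loop would also need their future values.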
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ streamlit
+ pandas
+ torch
+ transformers
+ plotly
+ tsfm_public @ git+https://github.com/ibm-granite/granite-tsfm.git
+ fastapi
+ uvicorn[standard]