azrai99 committed on
Commit e5e340b · verified · 1 Parent(s): a4bae1a

Update app.py

Files changed (1):
  app.py +213 -269

app.py CHANGED
@@ -1,276 +1,220 @@
- from time import time
-
- import numpy as np
- import pandas as pd
- import plotly.express as px
- import plotly.graph_objects as go
  import streamlit as st
- from neuralforecast.losses.pytorch import MAE, RMSE, MAPE, SMAPE, MASE
- from st_aggrid import AgGrid
-
- from src.nf import MODELS, forecast_pretrained_model
- from src.model_descriptions import model_cards
-
- DATASETS = {
-     "Electricity (Ercot COAST)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/ercot_COAST.csv",
-     "Web Traffic (Peyton Manning)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/peyton_manning.csv",
-     "Demand (AirPassengers)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv",
-     "Finance (Exchange USD-EUR)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/usdeur.csv",
  }


- @st.cache_data
- def convert_df(df):
-     # IMPORTANT: Cache the conversion to prevent computation on every rerun
-     return df.to_csv(index=False).encode("utf-8")
-
-
- def plot(df, uid, df_forecast, model):
-     figs = []
-     figs += [
-         go.Scatter(
-             x=df["ds"],
-             y=df["y"],
-             mode="lines",
-             marker=dict(color="#236796"),
-             legendrank=1,
-             name=uid,
-         ),
-     ]
-     if df_forecast is not None:
-         ds_f = df_forecast["ds"].to_list()
-         lo = df_forecast["forecast_lo_90"].to_list()
-         hi = df_forecast["forecast_hi_90"].to_list()
-         figs += [
-             go.Scatter(
-                 x=ds_f + ds_f[::-1],
-                 y=hi + lo[::-1],
-                 fill="toself",
-                 fillcolor="#E7C4C0",
-                 mode="lines",
-                 line=dict(color="#E7C4C0"),
-                 name="Prediction Intervals (90%)",
-                 legendrank=5,
-                 opacity=0.5,
-                 hoverinfo="skip",
-             ),
-             go.Scatter(
-                 x=ds_f,
-                 y=df_forecast["forecast"],
-                 mode="lines",
-                 legendrank=4,
-                 marker=dict(color="#E7C4C0"),
-                 name=f"Forecast {uid}",
-             ),
-         ]
-     fig = go.Figure(figs)
-     fig.update_layout(
-         {"plot_bgcolor": "rgba(0, 0, 0, 0)", "paper_bgcolor": "rgba(0, 0, 0, 0)"}
-     )
-     fig.update_layout(
-         title=f"Forecasts for {uid} using Transfer Learning (from {model})",
-         legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
-         margin=dict(l=20, b=20),
-         xaxis=dict(rangeslider=dict(visible=True)),
-     )
-     initial_range = [df.tail(200)["ds"].iloc[0], ds_f[-1]]
-     fig["layout"]["xaxis"].update(range=initial_range)
-     return fig
-
-
- def st_transfer_learning():
-     st.set_page_config(
-         page_title="Time Series Forecasting",
-         page_icon="🔮",
-         layout="wide",
-         initial_sidebar_state="expanded",
-     )
-
-     st.title(
-         "Transfer Learning: Revolutionizing Time Series"
-     )
-     st.write(
-         "<style>div.block-container{padding-top:2rem;}</style>", unsafe_allow_html=True
-     )
-
-     intro = """
-     The success of startups like Open AI and Stability highlights the potential for transfer learning (TL) techniques to have a similar impact on the field of time series forecasting.
-     TL can achieve lightning-fast predictions with a fraction of the computational cost by pre-training a flexible model on a large dataset and then using it on another dataset with little to no additional training.
-     In this live demo, you can use pre-trained models by Nixtla (trained on the M4 dataset) to predict your own datasets. You can also see how the models perform on unseen example datasets.
-     """
-     st.write(intro)
-
-     required_cols = ["ds", "y"]
-
-     with st.sidebar.expander("Dataset", expanded=False):
-         data_selection = st.selectbox("Select example dataset", DATASETS.keys())
-         data_url = DATASETS[data_selection]
-         url_json = st.text_input("Data (you can pass your own url here)", data_url)
-         st.write(
-             "You can also upload a CSV file like [this one](https://github.com/Nixtla/transfer-learning-time-series/blob/main/datasets/air_passengers.csv)."
-         )
-
-         uploaded_file = st.file_uploader("Upload CSV")
-         with st.form("Data"):
-
-             if uploaded_file is not None:
-                 df = pd.read_csv(uploaded_file)
-                 cols = df.columns
-                 timestamp_col = st.selectbox("Timestamp column", options=cols)
-                 value_col = st.selectbox("Value column", options=cols)
-             else:
-                 timestamp_col = st.text_input("Timestamp column", value="timestamp")
-                 value_col = st.text_input("Value column", value="value")
-             st.write("You must press Submit each time you want to forecast.")
-             submitted = st.form_submit_button("Submit")
-             if submitted:
-                 if uploaded_file is None:
-                     st.write("Please provide a dataframe.")
-                     if url_json.endswith("json"):
-                         df = pd.read_json(url_json)
-                     else:
-                         df = pd.read_csv(url_json)
-                     df = df.rename(
-                         columns=dict(zip([timestamp_col, value_col], required_cols))
-                     )
-                 else:
-                     df = df.rename(
-                         columns=dict(zip([timestamp_col, value_col], required_cols))
-                     )
-             else:
-                 if url_json.endswith("json"):
-                     df = pd.read_json(url_json)
-                 else:
-                     df = pd.read_csv(url_json)
-                 cols = df.columns
-                 if "unique_id" in cols:
-                     cols = cols[-2:]
-                 df = df.rename(columns=dict(zip(cols, required_cols)))
-
-     if "unique_id" not in df:
-         df.insert(0, "unique_id", "ts_0")
-
-     df["ds"] = pd.to_datetime(df["ds"])
-     df = df.sort_values(["unique_id", "ds"])
-
-     with st.sidebar:
-         st.write("Define the pretrained model you want to use to forecast your data")
-         model_name = st.selectbox("Select your model", tuple(MODELS.keys()))
-         model_file = MODELS[model_name]["model"]
-         st.write("Choose how many steps you want to forecast")
-         fh = st.number_input("Forecast horizon", value=18)
-         st.write(
-             "Choose for how many steps the pretrained model will be updated using your data (use 0 for fast computation)"
-         )
-         max_steps = st.number_input("N-shot inference", value=0)
-
-     # tabs
-     tab_fcst, tab_cv, tab_docs = st.tabs(
-         [
-             "📈 Forecast",
-             "🔎 Cross Validation",
-             "📚 Documentation",
-         ]
-     )
-
-     uids = df["unique_id"].unique()
-     fcst_cols = ["forecast_lo_90", "forecast", "forecast_hi_90"]
-
-     with tab_fcst:
-         uid = uids[0]
-         col1, col2 = st.columns([2, 4])
-         with col1:
-             tab_insample, tab_forecast = st.tabs(
-                 ["Modify input data", "Modify forecasts"]
-             )
-             with tab_insample:
-                 df_grid = df.query("unique_id == @uid").drop(columns="unique_id")
-                 grid_table = AgGrid(
-                     df_grid,
-                     editable=True,
-                     theme="alpine",
-                     fit_columns_on_grid_load=True,
-                     height=360,
-                 )
-                 df.loc[df["unique_id"] == uid, "y"] = (
-                     grid_table["data"].sort_values("ds")["y"].values
-                 )
-             # forecast code
-             init = time()
-             df_forecast = forecast_pretrained_model(df, model_file, fh, max_steps)
-             end = time()
-             df_forecast = df_forecast.rename(
-                 columns=dict(zip(["y_5", "y_50", "y_95"], fcst_cols))
-             )
-             with tab_forecast:
-                 df_fcst_grid = df_forecast.query("unique_id == @uid").filter(
-                     ["ds", "forecast"]
-                 )
-                 grid_fcst_table = AgGrid(
-                     df_fcst_grid,
-                     editable=True,
-                     theme="alpine",
-                     fit_columns_on_grid_load=True,
-                     height=360,
-                 )
-                 changes = (
-                     df_forecast.query("unique_id == @uid")["forecast"].values
-                     - grid_fcst_table["data"].sort_values("ds")["forecast"].values
-                 )
-                 for col in fcst_cols:
-                     df_forecast.loc[df_forecast["unique_id"] == uid, col] = (
-                         df_forecast.loc[df_forecast["unique_id"] == uid, col] - changes
-                     )
-         with col2:
-             if uploaded_file is not None:
-                 fct_name = value_col
-             else:
-                 fct_name=uid
-             st.plotly_chart(
-                 plot(
-                     df.query("unique_id == @uid"),
-                     fct_name,
-                     df_forecast.query("unique_id == @uid"),
-                     model_file,
-                 ),
-                 use_container_width=True,
-             )
-             st.write(f"Done in: {np.round(end-init, 2)} secs.")
-             st.write(f"Forecast for {fh} steps ahead.")
-             st.write("You can download the forecast for the entire dataframe here:")
-             csv = convert_df(
-                 df_forecast[["unique_id", "ds"] + fcst_cols].sort_values(
-                     ["unique_id", "ds"]
-                 )
-             )
-             st.download_button(
-                 label="Download CSV",
-                 data=csv,
-                 file_name="forecast.csv",
-                 mime="text/csv",
-             )
-             st.write(df_forecast[["unique_id", "ds"] + fcst_cols].tail(10))
-
-     with tab_cv:
-         st.write(
-             "To enable Cross Validation, use the advanced forecasting tool at our [site](https://nixtla.github.io/transfer-learning-time-series/)."
-         )
-         df_forecast_cv = None

-     with tab_docs:
-         st.write("Documentation (Work in progress)")
-         st.write(model_cards[model_name])

-     with st.sidebar.expander("Data info", expanded=False):
-         st.write(df.describe())
-         csv = convert_df(df)
-         st.download_button(
-             label="Download data as CSV",
-             data=csv,
-             file_name="dataset.csv",
-             mime="text/csv",
          )
-
-
- if __name__ == "__main__":
-     st_transfer_learning()
  import streamlit as st
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from neuralforecast.core import NeuralForecast
+ from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
+ from neuralforecast.losses.pytorch import HuberMQLoss
+ from neuralforecast.utils import AirPassengersDF  # default dataset used when no file is uploaded
+ import time
+
+ # Paths to the saved pre-trained models, one per frequency
+ nhits_paths = {
+     'D': './M4/NHITS/daily',
+     'M': './M4/NHITS/monthly',
+     'H': './M4/NHITS/hourly',
+     'W': './M4/NHITS/weekly',
+     'Y': './M4/NHITS/yearly'
  }

+ timesnet_paths = {
+     'D': './M4/TimesNet/daily',
+     'M': './M4/TimesNet/monthly',
+     'H': './M4/TimesNet/hourly',
+     'W': './M4/TimesNet/weekly',
+     'Y': './M4/TimesNet/yearly'
+ }
+
+ lstm_paths = {
+     'D': './M4/LSTM/daily',
+     'M': './M4/LSTM/monthly',
+     'H': './M4/LSTM/hourly',
+     'W': './M4/LSTM/weekly',
+     'Y': './M4/LSTM/yearly'
+ }
+
+ tft_paths = {
+     'D': './M4/TFT/daily',
+     'M': './M4/TFT/monthly',
+     'H': './M4/TFT/hourly',
+     'W': './M4/TFT/weekly',
+     'Y': './M4/TFT/yearly'
+ }
+
+ @st.cache_resource
+ def load_model(path, freq):
+     nf = NeuralForecast.load(path=path)
+     return nf
+
+ nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
+ timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
+ lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
+ tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}
+
+ def generate_forecast(model, df):
+     forecast_df = model.predict(df=df)
+     return forecast_df
+
+ def determine_frequency(df):
+     df['ds'] = pd.to_datetime(df['ds'])
+     df = df.set_index('ds')
+     freq = pd.infer_freq(df.index)
+     return freq
+
+ def plot_forecasts(forecast_df, train_df, title):
+     fig, ax = plt.subplots(1, 1, figsize=(20, 7))
+     plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
+     historical_col = 'y'
+     forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
+     lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
+     hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
+     if forecast_col is None:
+         raise KeyError("No forecast column found in the data.")
+     plot_df[[historical_col, forecast_col]].plot(ax=ax, linewidth=2, label=['Historical', 'Forecast'])
+     if lo_col and hi_col:
+         ax.fill_between(
+             plot_df.index,
+             plot_df[lo_col],
+             plot_df[hi_col],
+             color='blue',
+             alpha=0.3,
+             label='90% Confidence Interval'
          )
+     ax.set_title(title, fontsize=22)
+     ax.set_ylabel('Value', fontsize=20)
+     ax.set_xlabel('Timestamp [t]', fontsize=20)
+     ax.legend(prop={'size': 15})
+     ax.grid()
+     st.pyplot(fig)
+
+ def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
+     if freq == 'D':
+         return nhits_models['D'], timesnet_models['D'], lstm_models['D'], tft_models['D']
+     elif freq == 'M':
+         return nhits_models['M'], timesnet_models['M'], lstm_models['M'], tft_models['M']
+     elif freq == 'H':
+         return nhits_models['H'], timesnet_models['H'], lstm_models['H'], tft_models['H']
+     elif freq in ['W', 'W-SUN']:
+         return nhits_models['W'], timesnet_models['W'], lstm_models['W'], tft_models['W']
+     elif freq in ['Y', 'Y-DEC']:
+         return nhits_models['Y'], timesnet_models['Y'], lstm_models['Y'], tft_models['Y']
+     else:
+         raise ValueError(f"Unsupported frequency: {freq}")
+
+ def select_model(horizon, model_type, max_steps=200):
+     if model_type == 'NHITS':
+         return NHITS(input_size=5 * horizon,
+                      h=horizon,
+                      max_steps=max_steps,
+                      stack_types=3*['identity'],
+                      n_blocks=3*[1],
+                      mlp_units=[[256, 256] for _ in range(3)],
+                      n_pool_kernel_size=3*[1],
+                      batch_size=32,
+                      scaler_type='standard',
+                      n_freq_downsample=[12, 4, 1],
+                      loss=HuberMQLoss(level=[90]))
+     elif model_type == 'TimesNet':
+         return TimesNet(h=horizon,
+                         input_size=horizon * 5,
+                         hidden_size=16,
+                         conv_hidden_size=32,
+                         loss=HuberMQLoss(level=[90]),
+                         scaler_type='standard',
+                         learning_rate=1e-3,
+                         max_steps=max_steps,
+                         val_check_steps=200,
+                         valid_batch_size=64,
+                         windows_batch_size=128,
+                         inference_windows_batch_size=512)
+     elif model_type == 'LSTM':
+         return LSTM(h=horizon,
+                     input_size=horizon * 5,
+                     loss=HuberMQLoss(level=[90]),
+                     scaler_type='standard',
+                     encoder_n_layers=2,
+                     encoder_hidden_size=64,
+                     context_size=10,
+                     decoder_hidden_size=64,
+                     decoder_layers=2,
+                     max_steps=max_steps)
+     elif model_type == 'TFT':
+         return TFT(h=horizon,
+                    input_size=horizon,
+                    hidden_size=16,
+                    loss=HuberMQLoss(level=[90]),
+                    learning_rate=0.005,
+                    scaler_type='standard',
+                    windows_batch_size=128,
+                    max_steps=max_steps,
+                    val_check_steps=200,
+                    valid_batch_size=64,
+                    enable_progress_bar=True)
+     else:
+         raise ValueError(f"Unsupported model type: {model_type}")
+
+ def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
+     start_time = time.time()  # Start timing
+     if freq:
+         df['ds'] = pd.date_range(start='1970-01-01', periods=len(df), freq=freq)
+     else:
+         freq = determine_frequency(df)
+         st.write(f"Determined frequency: {freq}")
+     df['ds'] = pd.to_datetime(df['ds'], errors='coerce')
+     df = df.dropna(subset=['ds'])
+     model = select_model(horizon, model_type, max_steps)
+     forecast_results = {}
+     st.write(f"Generating forecast using {model_type} model...")
+     # The model returned by select_model starts untrained, so fit it on the
+     # provided data before predicting.
+     nf = NeuralForecast(models=[model], freq=freq)
+     nf.fit(df=df)
+     forecast_results[model_type] = nf.predict()
+
+     for model_name, forecast_df in forecast_results.items():
+         plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
+
+     end_time = time.time()  # End timing
+     time_taken = end_time - start_time
+     st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
+
+ # Streamlit App
+ st.title("Dynamic and Automatic Time Series Forecasting")
+
+ # Upload dataset
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
+ if uploaded_file:
+     df = pd.read_csv(uploaded_file)
+ else:
+     st.warning("Using default data")
+     df = AirPassengersDF.copy()
+
+ # Model selection and forecasting
+ st.subheader("Transfer Learning Forecasting")
+ model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
+ horizon = st.slider("Forecast horizon", 1, 100, 10)
+
+ # Determine frequency of data
+ frequency = determine_frequency(df)
+ st.write(f"Detected frequency: {frequency}")
+
+ # Load pre-trained models
+ nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
+ forecast_results = {}
+
+ start_time = time.time()  # Start timing
+ if model_choice == "NHITS":
+     forecast_results['NHITS'] = generate_forecast(nhits_model, df)
+ elif model_choice == "TimesNet":
+     forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
+ elif model_choice == "LSTM":
+     forecast_results['LSTM'] = generate_forecast(lstm_model, df)
+ elif model_choice == "TFT":
+     forecast_results['TFT'] = generate_forecast(tft_model, df)
+
+ for model_name, forecast_df in forecast_results.items():
+     plot_forecasts(forecast_df, df, f'{model_name} Forecast')
+
+ end_time = time.time()  # End timing
+ time_taken = end_time - start_time
+ st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
+
+ # Dynamic forecasting
+ st.subheader("Dynamic Forecasting")
+ dynamic_model_choice = st.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
+ dynamic_horizon = st.slider("Forecast horizon for dynamic forecasting", 1, 100, 10, key="dynamic_horizon")
+ forecast_time_series(df, dynamic_model_choice, frequency, dynamic_horizon)
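
Note on input data: the rewritten app passes the uploaded CSV straight to the loaded NeuralForecast models, so the file should already be in the library's long format with unique_id, ds and y columns (the default AirPassengersDF already is). A minimal preparation sketch, assuming a raw export with hypothetical "timestamp" and "value" columns:

import pandas as pd

raw = pd.read_csv("raw_export.csv")           # hypothetical file with 'timestamp' and 'value' columns
prepared = pd.DataFrame({
    "unique_id": "ts_0",                      # single-series identifier, broadcast to every row
    "ds": pd.to_datetime(raw["timestamp"]),   # timestamps
    "y": raw["value"],                        # target values
})
prepared.to_csv("prepared.csv", index=False)  # upload this file in the app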