azrai99 committed
Commit b2108ae · verified · 1 parent: 634cf9e

Update app.py

Files changed (1)
  1. app.py +372 -374
app.py CHANGED
@@ -1,374 +1,372 @@
  from time import time

  import numpy as np
  import pandas as pd
  import plotly.express as px
  import plotly.graph_objects as go
  import streamlit as st
  from datasetsforecast.losses import rmse, mae, smape, mse, mape
  from st_aggrid import AgGrid

  from src.nf import MODELS, forecast_pretrained_model
  from src.model_descriptions import model_cards

  DATASETS = {
      "Electricity (Ercot COAST)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/ercot_COAST.csv",
      # "Electricity (ERCOT, multiple markets)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/ercot_multiple_ts.csv",
      "Web Traffic (Peyton Manning)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/peyton_manning.csv",
      "Demand (AirPassengers)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv",
      "Finance (Exchange USD-EUR)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/usdeur.csv",
  }


  @st.cache_data
  def convert_df(df):
      # IMPORTANT: Cache the conversion to prevent computation on every rerun
      return df.to_csv(index=False).encode("utf-8")


  def plot(df, uid, df_forecast, model):
      figs = []
      figs += [
          go.Scatter(
              x=df["ds"],
              y=df["y"],
              mode="lines",
              marker=dict(color="#236796"),
              legendrank=1,
              name=uid,
          ),
      ]
      if df_forecast is not None:
          ds_f = df_forecast["ds"].to_list()
          lo = df_forecast["forecast_lo_90"].to_list()
          hi = df_forecast["forecast_hi_90"].to_list()
          figs += [
              go.Scatter(
                  x=ds_f + ds_f[::-1],
                  y=hi + lo[::-1],
                  fill="toself",
                  fillcolor="#E7C4C0",
                  mode="lines",
                  line=dict(color="#E7C4C0"),
                  name="Prediction Intervals (90%)",
                  legendrank=5,
                  opacity=0.5,
                  hoverinfo="skip",
              ),
              go.Scatter(
                  x=ds_f,
                  y=df_forecast["forecast"],
                  mode="lines",
                  legendrank=4,
                  marker=dict(color="#E7C4C0"),
                  name=f"Forecast {uid}",
              ),
          ]
      fig = go.Figure(figs)
      fig.update_layout(
          {"plot_bgcolor": "rgba(0, 0, 0, 0)", "paper_bgcolor": "rgba(0, 0, 0, 0)"}
      )
      fig.update_layout(
          title=f"Forecasts for {uid} using Transfer Learning (from {model})",
          legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
          margin=dict(l=20, b=20),
          xaxis=dict(rangeslider=dict(visible=True)),
      )
      initial_range = [df.tail(200)["ds"].iloc[0], ds_f[-1]]
      fig["layout"]["xaxis"].update(range=initial_range)
      return fig

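The 90% band in plot() above uses Plotly's reversed-path trick: the upper bound is traced left to right, then the lower bound right to left, and fill="toself" closes and fills the resulting polygon. A minimal self-contained sketch of the same technique, with made-up values:

import plotly.graph_objects as go

ds = [1, 2, 3, 4]
lo = [9.0, 10.0, 11.0, 12.0]   # lower interval bound
hi = [11.0, 13.0, 14.0, 16.0]  # upper interval bound

band = go.Scatter(
    x=ds + ds[::-1],  # walk forward along x, then back again
    y=hi + lo[::-1],  # upper bound on the way out, lower bound on the way back
    fill="toself",    # fill the closed polygon between the two paths
    mode="lines",
    line=dict(width=0),
    name="90% interval",
)
go.Figure([band]).show()

Note that plot() only defines ds_f inside the if df_forecast is not None: branch but reads it unconditionally when building initial_range, so the function effectively requires a forecast frame; both call sites below satisfy that.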
  def st_transfer_learning():
      st.set_page_config(
-         page_title="Time Series Visualization",
+         page_title="Time Series Forecasting",
          page_icon="🔮",
          layout="wide",
          initial_sidebar_state="expanded",
      )

      st.title(
-         "Transfer Learning: Revolutionizing Time Series by Nixtla"
+         "Transfer Learning: Revolutionizing Time Series"
      )
      st.write(
          "<style>div.block-container{padding-top:2rem;}</style>", unsafe_allow_html=True
      )

      intro = """
      The success of startups like OpenAI and Stability AI highlights the potential for transfer learning (TL) techniques to have a similar impact on the field of time series forecasting.

      TL can achieve lightning-fast predictions with a fraction of the computational cost by pre-training a flexible model on a large dataset and then using it on another dataset with little to no additional training.

      In this live demo, you can use pre-trained models by Nixtla (trained on the M4 dataset) to forecast your own datasets. You can also see how the models perform on unseen example datasets.
      """
      st.write(intro)

      required_cols = ["ds", "y"]

      with st.sidebar.expander("Dataset", expanded=False):
          data_selection = st.selectbox("Select example dataset", DATASETS.keys())
          data_url = DATASETS[data_selection]
          url_json = st.text_input("Data (you can pass your own url here)", data_url)
          st.write(
              "You can also upload a CSV file like [this one](https://github.com/Nixtla/transfer-learning-time-series/blob/main/datasets/air_passengers.csv)."
          )

          uploaded_file = st.file_uploader("Upload CSV")
          with st.form("Data"):

              if uploaded_file is not None:
                  df = pd.read_csv(uploaded_file)
                  cols = df.columns
                  timestamp_col = st.selectbox("Timestamp column", options=cols)
                  value_col = st.selectbox("Value column", options=cols)
              else:
                  timestamp_col = st.text_input("Timestamp column", value="timestamp")
                  value_col = st.text_input("Value column", value="value")
              st.write("You must press Submit each time you want to forecast.")
              submitted = st.form_submit_button("Submit")
              if submitted:
                  if uploaded_file is None:
                      st.write("Please provide a dataframe.")
                      if url_json.endswith("json"):
                          df = pd.read_json(url_json)
                      else:
                          df = pd.read_csv(url_json)
                      df = df.rename(
                          columns=dict(zip([timestamp_col, value_col], required_cols))
                      )
                  else:
                      # df = pd.read_csv(uploaded_file)
                      df = df.rename(
                          columns=dict(zip([timestamp_col, value_col], required_cols))
                      )
              else:
                  if url_json.endswith("json"):
                      df = pd.read_json(url_json)
                  else:
                      df = pd.read_csv(url_json)
                  cols = df.columns
                  if "unique_id" in cols:
                      cols = cols[-2:]
                  df = df.rename(columns=dict(zip(cols, required_cols)))

          if "unique_id" not in df:
              df.insert(0, "unique_id", "ts_0")

          df["ds"] = pd.to_datetime(df["ds"])
          df = df.sort_values(["unique_id", "ds"])

      with st.sidebar:
          st.write("Define the pretrained model you want to use to forecast your data")
          model_name = st.selectbox("Select your model", tuple(MODELS.keys()))
          model_file = MODELS[model_name]["model"]
          st.write("Choose how many steps you want to forecast")
          fh = st.number_input("Forecast horizon", value=18)
          st.write(
              "Choose how many steps the pretrained model will be updated using your data (use 0 for fast computation)"
          )
          max_steps = st.number_input("N-shot inference", value=0)

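Whatever the source, the form above normalizes the input to the long format the rest of the app expects: a unique_id series key, a ds timestamp and a y value. A minimal frame that would pass through the pipeline unchanged (hypothetical values):

import pandas as pd

df = pd.DataFrame(
    {
        "unique_id": ["ts_0"] * 4,  # one series; the app inserts this key when it is missing
        "ds": pd.date_range("2023-01-01", periods=4, freq="D"),
        "y": [112.0, 118.0, 132.0, 129.0],
    }
)
df["ds"] = pd.to_datetime(df["ds"])       # same coercion the app applies
df = df.sort_values(["unique_id", "ds"])  # same ordering the app applies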
      # tabs
-     tab_fcst, tab_cv, tab_docs, tab_nixtla = st.tabs(
+     tab_fcst, tab_cv, tab_docs = st.tabs(
          [
              "📈 Forecast",
              "🔎 Cross Validation",
              "📚 Documentation",
-             "🔮 Nixtlaverse",
          ]
      )

      uids = df["unique_id"].unique()
      fcst_cols = ["forecast_lo_90", "forecast", "forecast_hi_90"]

      with tab_fcst:
          uid = uids[0]  # st.selectbox("Dataset", options=uids)
          col1, col2 = st.columns([2, 4])
          with col1:
              tab_insample, tab_forecast = st.tabs(
                  ["Modify input data", "Modify forecasts"]
              )
              with tab_insample:
                  df_grid = df.query("unique_id == @uid").drop(columns="unique_id")
                  grid_table = AgGrid(
                      df_grid,
                      editable=True,
                      theme="streamlit",
                      fit_columns_on_grid_load=True,
                      height=360,
                  )
                  df.loc[df["unique_id"] == uid, "y"] = (
                      grid_table["data"].sort_values("ds")["y"].values
                  )
              # forecast code
              init = time()
              df_forecast = forecast_pretrained_model(df, model_file, fh, max_steps)
              end = time()
              df_forecast = df_forecast.rename(
                  columns=dict(zip(["y_5", "y_50", "y_95"], fcst_cols))
              )
              with tab_forecast:
                  df_fcst_grid = df_forecast.query("unique_id == @uid").filter(
                      ["ds", "forecast"]
                  )
                  grid_fcst_table = AgGrid(
                      df_fcst_grid,
                      editable=True,
                      theme="streamlit",
                      fit_columns_on_grid_load=True,
                      height=360,
                  )
                  changes = (
                      df_forecast.query("unique_id == @uid")["forecast"].values
                      - grid_fcst_table["data"].sort_values("ds")["forecast"].values
                  )
                  for col in fcst_cols:
                      df_forecast.loc[df_forecast["unique_id"] == uid, col] = (
                          df_forecast.loc[df_forecast["unique_id"] == uid, col] - changes
                      )
          with col2:
              st.plotly_chart(
                  plot(
                      df.query("unique_id == @uid"),
                      uid,
                      df_forecast.query("unique_id == @uid"),
                      model_name,
                  ),
                  use_container_width=True,
              )
          st.success(f'Done! Approximate inference time CPU: {0.7*(end-init):.2f} seconds.')

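forecast_pretrained_model comes from the app's own src.nf module. As called above, it takes the history, a checkpoint file, the horizon fh and the number of update steps, and returns quantile columns y_5, y_50 and y_95, which the app renames to the 90% interval bounds plus the point forecast. A usage sketch under those assumptions (only the positional call is taken from the code; the comments describe the max_steps semantics stated in the sidebar):

# Zero-shot: the pretrained weights are used as-is (fastest path).
df_zero_shot = forecast_pretrained_model(df, model_file, fh, 0)

# N-shot: the checkpoint is first updated for a few steps on the user's data,
# trading speed for adaptation to the new series.
df_n_shot = forecast_pretrained_model(df, model_file, fh, 10)
df_n_shot = df_n_shot.rename(columns=dict(zip(["y_5", "y_50", "y_95"], fcst_cols)))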
      with tab_cv:
          col_uid, col_n_windows = st.columns(2)
          uid = uids[0]
          # with col_uid:
          #     uid = st.selectbox("Time series to analyse", options=uids, key="uid_cv")
          with col_n_windows:
              n_windows = st.number_input("Cross validation windows", value=1)
          df_forecast = []
          for i_window in range(n_windows, 0, -1):
              test = df.groupby("unique_id").tail(i_window * fh)
              df_forecast_w = forecast_pretrained_model(
                  df.drop(test.index), model_file, fh, max_steps
              )
              df_forecast_w = df_forecast_w.rename(
                  columns=dict(zip(["y_5", "y_50", "y_95"], fcst_cols))
              )
              df_forecast_w.insert(2, "window", i_window)
              df_forecast.append(df_forecast_w)
          df_forecast = pd.concat(df_forecast)
          df_forecast["ds"] = pd.to_datetime(df_forecast["ds"])
          df_forecast = df_forecast.merge(df, how="left", on=["unique_id", "ds"])
          metrics = [mae, mape, rmse, smape]
          evaluation = df_forecast.groupby(["unique_id", "window"]).apply(
              lambda df: [f'{fn(df["y"].values, df["forecast"]):.2f}' for fn in metrics]
          )
          evaluation = evaluation.rename("eval").reset_index()
          evaluation["eval"] = evaluation["eval"].str.join(",")
          evaluation[["MAE", "MAPE", "RMSE", "sMAPE"]] = evaluation["eval"].str.split(
              ",", expand=True
          )
          col_eval, col_plot = st.columns([2, 4])
          with col_eval:
              st.write("Evaluation metrics for each cross validation window")
              st.dataframe(
                  evaluation.query("unique_id == @uid")
                  .drop(columns=["unique_id", "eval"])
                  .set_index("window")
              )
          with col_plot:
              st.plotly_chart(
                  plot(
                      df.query("unique_id == @uid"),
                      uid,
                      df_forecast.query("unique_id == @uid").drop(columns="y"),
                      model_name,
                  ),
                  use_container_width=True,
              )
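The loop above is a hand-rolled rolling-origin backtest: for each window it drops the last i_window * fh rows of every series, forecasts fh steps from what remains, and then joins the predictions back onto the actuals to score them with mae, mape, rmse and smape. The same split logic in isolation, as a sketch over any frame with unique_id/ds/y columns:

def rolling_origin_splits(df, horizon, n_windows):
    """Yield (train, test) pairs per backtest window, oldest window first."""
    for i_window in range(n_windows, 0, -1):
        # the last i_window * horizon rows of every series are held out
        holdout = df.groupby("unique_id").tail(i_window * horizon)
        train = df.drop(holdout.index)
        # this window's test set is the first `horizon` held-out rows per series
        test = holdout.groupby("unique_id").head(horizon)
        yield train, test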
      with tab_docs:
          tab_transfer, tab_desc, tab_ref = st.tabs(
              [
                  "🚀 Transfer Learning",
                  "🔎 Description of the model",
                  "📚 References",
              ]
          )

          with tab_desc:
              model_card_name = MODELS[model_name]["card"]
              st.subheader("Abstract")
              st.write(f"""{model_cards[model_card_name]['Abstract']}""")
              st.subheader("Intended use")
              st.write(f"""{model_cards[model_card_name]['Intended use']}""")
              st.subheader("Secondary use")
              st.write(f"""{model_cards[model_card_name]['Secondary use']}""")
              st.subheader("Limitations")
              st.write(f"""{model_cards[model_card_name]['Limitations']}""")
              st.subheader("Training data")
              st.write(f"""{model_cards[model_card_name]['Training data']}""")
              st.subheader("BibTex/Citation Info")
              st.code(f"""{model_cards[model_card_name]['Citation Info']}""")

          with tab_transfer:
              transfer_text = """
              Transfer learning refers to the process of pre-training a flexible model on a large dataset and using it later on other data with little to no training. It is one of the most outstanding 🚀 achievements in Machine Learning 🧠 and has many practical applications.

              For time series forecasting, the technique allows you to get lightning-fast predictions ⚡, bypassing the tradeoff between accuracy and speed.

              [This notebook](https://colab.research.google.com/drive/1uFCO2UBpH-5l2fk3KmxfU0oupsOC6v2n?authuser=0&pli=1#cell-5=) shows how to generate a pre-trained model and store it in a checkpoint so it can be made publicly available to forecast new time series never seen by the model.
-             **You can contribute with your pre-trained models by following [this Notebook](https://github.com/Nixtla/transfer-learning-time-series/blob/main/nbs/Transfer_Learning.ipynb) and sending us an email at federico[at]nixtla.io**

              You can also take a look at the list of pretrained models below. Currently we have these available in our [API](https://docs.nixtla.io/reference/neural_transfer_neural_transfer_post) or [Demo](http://nixtla.io/transfer-learning/). You can also download the `.ckpt` files:
              - [Pretrained N-HiTS M4 Hourly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_hourly.ckpt)
              - [Pretrained N-HiTS M4 Hourly (Tiny)](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_hourly_tiny.ckpt)
              - [Pretrained N-HiTS M4 Daily](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_daily.ckpt)
              - [Pretrained N-HiTS M4 Monthly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_monthly.ckpt)
              - [Pretrained N-HiTS M4 Yearly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_yearly.ckpt)
              - [Pretrained N-BEATS M4 Hourly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_hourly.ckpt)
              - [Pretrained N-BEATS M4 Daily](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_daily.ckpt)
              - [Pretrained N-BEATS M4 Weekly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_weekly.ckpt)
              - [Pretrained N-BEATS M4 Monthly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_monthly.ckpt)
              - [Pretrained N-BEATS M4 Yearly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_yearly.ckpt)
              """
              st.write(transfer_text)

          with tab_ref:
              ref_text = """
              If you are interested in the transfer learning literature applied to time series forecasting, take a look at these papers:
              - [Meta-learning framework with applications to zero-shot time-series forecasting](https://arxiv.org/abs/2002.02887)
              - [N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting](https://arxiv.org/abs/2201.12886)
              """
              st.write(ref_text)
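The listed .ckpt files are PyTorch Lightning checkpoints; how this app loads them is encapsulated in src.nf. Independently of the app, such a file can be inspected with plain torch, a sketch assuming only the standard Lightning checkpoint layout:

import torch

# A Lightning checkpoint is a pickled dict; the weights live under "state_dict".
ckpt = torch.load("nhits_m4_hourly.ckpt", map_location="cpu")
print(sorted(ckpt.keys()))      # typically includes "state_dict" and "hyper_parameters"
print(len(ckpt["state_dict"]))  # number of weight tensors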

-     with tab_nixtla:
-         nixtla_text = """
-         Nixtla is a startup that is building forecasting software for Data Scientists and Devs.
-
-         We have been developing different open source libraries for machine learning, statistical and deep learning forecasting.
-
-         In our [GitHub repo](https://github.com/Nixtla), you can find the projects that support this APP.
-         """
-         st.write(nixtla_text)
-         st.image(
-             "https://files.readme.io/168cdb2-Screen_Shot_2022-09-30_at_10.40.09.png",
-             width=800,
-         )
+     # with tab_dummy:
+     #     nixtla_text = """
+     #     Nixtla is a startup that is building forecasting software for Data Scientists and Devs.
+
+     #     We have been developing different open source libraries for machine learning, statistical and deep learning forecasting.
+
+     #     In our [GitHub repo](https://github.com/Nixtla), you can find the projects that support this APP.
+     #     """
+     #     st.write(nixtla_text)
+     #     st.image(
+     #         "https://files.readme.io/168cdb2-Screen_Shot_2022-09-30_at_10.40.09.png",
+     #         width=800,
+     #     )

      with st.sidebar:
          st.download_button(
              label="Download historical data as CSV",
              data=convert_df(df),
              file_name="history.csv",
              mime="text/csv",
          )
          st.download_button(
              label="Download forecasts as CSV",
              data=convert_df(df_forecast),
              file_name="forecasts.csv",
              mime="text/csv",
          )


  if __name__ == "__main__":
      st_transfer_learning()
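The download buttons reuse convert_df from the top of the file; because it is wrapped in st.cache_data, the CSV serialization runs once per distinct dataframe instead of on every Streamlit rerun. The same pattern in isolation (label and file name are illustrative):

csv_bytes = convert_df(df)  # cached: recomputed only when df changes
st.download_button(
    label="Download data as CSV",
    data=csv_bytes,
    file_name="data.csv",
    mime="text/csv",
)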
 