azrai99 committed on
Commit
f4d5347
·
verified ·
1 Parent(s): 4f9e43c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -52
app.py CHANGED
@@ -15,95 +15,262 @@ st.set_page_config(layout='wide')
15
 
16
  @st.cache_resource
17
  def load_model(path, freq):
18
- return NeuralForecast.load(path=path)
 
19
 
20
  @st.cache_resource
21
  def load_all_models():
22
- model_paths = {
23
- 'D': './M4/{model}/daily',
24
- 'M': './M4/{model}/monthly',
25
- 'H': './M4/{model}/hourly',
26
- 'W': './M4/{model}/weekly',
27
- 'Y': './M4/{model}/yearly'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  }
29
 
30
- models = ['NHITS', 'TimesNet', 'LSTM', 'TFT']
31
- all_models = {model: {freq: load_model(model_paths[freq].format(model=model), freq) for freq in model_paths} for model in models}
32
- return all_models
 
 
 
 
 
 
 
 
 
 
33
 
34
def generate_forecast(model, df, tag=False):
    """Return the model's predictions.

    When ``tag`` equals ``'retrain'`` the model's ``predict()`` is called
    without arguments; otherwise ``df`` is passed as the ``df`` keyword.
    """
    if tag == 'retrain':
        return model.predict()
    return model.predict(df=df)
 
 
 
 
36
 
37
  def determine_frequency(df):
38
- df['ds'] = pd.to_datetime(df['ds']).drop_duplicates().set_index('ds')
39
- freq = pd.infer_freq(df.index) or 'D'
 
 
 
 
 
 
 
 
 
 
 
 
40
  if not freq:
41
- st.warning('Default Daily forecast due to date inconsistency.')
 
 
42
  return freq
43
 
 
 
 
44
  def plot_forecasts(forecast_df, train_df, title):
 
45
  plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
 
 
46
  historical_col = 'y'
47
  forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
48
  lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
49
  hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
50
 
51
- if forecast_col:
52
- fig = go.Figure()
53
- fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[historical_col], mode='lines', name='Historical'))
54
- fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[forecast_col], mode='lines', name='Forecast'))
55
- if lo_col and hi_col:
56
- fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[hi_col], mode='lines', line=dict(color='rgba(0,100,80,0.2)'), showlegend=False))
57
- fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[lo_col], mode='lines', fill='tonexty', fillcolor='rgba(0,100,80,0.2)', name='90% Confidence Interval'))
58
- fig.update_layout(title=title, xaxis_title='Timestamp', yaxis_title='Value', template='plotly_white')
59
- st.plotly_chart(fig)
60
-
61
def select_model_based_on_frequency(freq, models):
    """Pick, for every model family, the entry matching ``freq``.

    ``models`` maps ``model_name -> {freq_code: model}``.  An exact key
    match is used when available.  Otherwise the alias is reduced to its
    base (text before the first '-') and translated to one of the codes
    'D', 'M', 'H', 'W', 'Y', so that aliases such as 'ME', 'W-SUN', 'h'
    or 'YE-DEC' resolve to the stored models.

    Raises:
        KeyError: when neither the alias nor its translation is a key.
    """
    canonical = {
        'D': 'D',
        'M': 'M', 'ME': 'M', 'MS': 'M',
        'H': 'H', 'h': 'H',
        'W': 'W',
        'Y': 'Y', 'YE': 'Y', 'YS': 'Y', 'A': 'Y', 'AS': 'Y',
    }
    base = freq.split('-', 1)[0] if isinstance(freq, str) else freq
    fallback_key = canonical.get(base, freq)

    def _pick(per_frequency):
        # Prefer the caller's exact key to stay compatible with stores
        # that already use full pandas aliases.
        if freq in per_frequency:
            return per_frequency[freq]
        return per_frequency[fallback_key]

    return {model: _pick(models[model]) for model in models}
63
-
64
def model_train(df, model, freq):
    """Fit a new NeuralForecast wrapper holding ``model`` on ``df``.

    Side effect: ``df['ds']`` is converted to datetime in place before
    fitting.

    Returns:
        The fitted NeuralForecast object.
    """
    trainer = NeuralForecast(models=[model], freq=freq)
    df['ds'] = pd.to_datetime(df['ds'])
    trainer.fit(df)
    return trainer
69
-
70
def forecast_time_series(df, model_type, horizon, max_steps, y_col):
    """Train the chosen model on ``df``, forecast, store and plot the result.

    The detected frequency is written to the sidebar, the forecast frame
    is saved in ``st.session_state.forecast_results`` and its first
    ``horizon`` rows are plotted together with ``df``.

    NOTE(review): ``select_model`` is not defined in the visible part of
    the file; it is assumed to build a model from (horizon, model_type,
    max_steps) - confirm.
    """
    detected_freq = determine_frequency(df)
    st.sidebar.write(f"Data frequency: {detected_freq}")

    untrained = select_model(horizon, model_type, max_steps)
    fitted = model_train(df, untrained, detected_freq)

    predictions = generate_forecast(fitted, df, tag='retrain')
    st.session_state.forecast_results = predictions

    chart_title = f'{model_type} Forecast for {y_col}'
    plot_forecasts(predictions.iloc[:horizon, :], df, chart_title)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  @st.cache_data
83
  def load_default():
84
- return AirPassengersDF.copy()
 
85
 
86
  def transfer_learning_forecasting():
87
  st.title("Zero-shot Forecasting")
88
- all_models = load_all_models()
89
-
90
- df = st.session_state.get('df', load_default())
91
- ds_col, y_col = st.sidebar.selectbox("Date/Time column", df.columns), st.sidebar.selectbox("Target column", df.columns)
92
-
93
- df = df.rename(columns={ds_col: 'ds', y_col: 'y'}).assign(unique_id=1)[['unique_id', 'ds', 'y']]
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  frequency = determine_frequency(df)
96
- models = select_model_based_on_frequency(frequency, all_models)
 
 
 
 
 
97
 
 
98
  if st.sidebar.button("Submit"):
99
- model_choice = st.sidebar.selectbox("Select model", models.keys())
100
- forecast_time_series(df, model_choice, st.sidebar.number_input("Horizon", value=12), 50, y_col)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
 
 
102
 
103
  pg = st.navigation({
104
  "Neuralforecast": [
105
  # Load pages from functions
106
  st.Page(transfer_learning_forecasting, title="Zero-shot Forecasting", default=True, icon=":material/query_stats:"),
 
107
  ],
108
  })
109
 
 
15
 
16
  @st.cache_resource
17
  def load_model(path, freq):
18
+ nf = NeuralForecast.load(path=path)
19
+ return nf
20
 
21
  @st.cache_resource
22
  def load_all_models():
23
+ nhits_paths = {
24
+ 'D': './M4/NHITS/daily',
25
+ 'M': './M4/NHITS/monthly',
26
+ 'H': './M4/NHITS/hourly',
27
+ 'W': './M4/NHITS/weekly',
28
+ 'Y': './M4/NHITS/yearly'
29
+ }
30
+
31
+ timesnet_paths = {
32
+ 'D': './M4/TimesNet/daily',
33
+ 'M': './M4/TimesNet/monthly',
34
+ 'H': './M4/TimesNet/hourly',
35
+ 'W': './M4/TimesNet/weekly',
36
+ 'Y': './M4/TimesNet/yearly'
37
+ }
38
+
39
+ lstm_paths = {
40
+ 'D': './M4/LSTM/daily',
41
+ 'M': './M4/LSTM/monthly',
42
+ 'H': './M4/LSTM/hourly',
43
+ 'W': './M4/LSTM/weekly',
44
+ 'Y': './M4/LSTM/yearly'
45
  }
46
 
47
+ tft_paths = {
48
+ 'D': './M4/TFT/daily',
49
+ 'M': './M4/TFT/monthly',
50
+ 'H': './M4/TFT/hourly',
51
+ 'W': './M4/TFT/weekly',
52
+ 'Y': './M4/TFT/yearly'
53
+ }
54
+ nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
55
+ timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
56
+ lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
57
+ tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}
58
+
59
+ return nhits_models, timesnet_models, lstm_models, tft_models
60
 
61
def generate_forecast(model, df, tag=False):
    """Return the model's predictions.

    When ``tag`` equals ``'retrain'`` the model's ``predict()`` is called
    without arguments; otherwise ``df`` is passed as the ``df`` keyword.
    """
    return model.predict() if tag == 'retrain' else model.predict(df=df)
67
 
68
  def determine_frequency(df):
69
+ df['ds'] = pd.to_datetime(df['ds'])
70
+ df = df.drop_duplicates(subset='ds')
71
+ df = df.set_index('ds')
72
+
73
+ # # Create a complete date range
74
+ # full_range = pd.date_range(start=df.index.min(), end=df.index.max(),freq=freq)
75
+
76
+ # # Reindex the DataFrame to this full date range
77
+ # df_full = df.reindex(full_range)
78
+
79
+ # Infer the frequency
80
+ # freq = pd.infer_freq(df_full.index)
81
+
82
+ freq = pd.infer_freq(df.index)
83
  if not freq:
84
+ st.warning('The forecast will use default Daily forecast due to date inconsistency. Please check your data.',icon="⚠️")
85
+ freq = 'D'
86
+
87
  return freq
88
 
89
+
90
+ import plotly.graph_objects as go
91
+
92
  def plot_forecasts(forecast_df, train_df, title):
93
+ # Combine historical and forecast data
94
  plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
95
+
96
+ # Find relevant columns
97
  historical_col = 'y'
98
  forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
99
  lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
100
  hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
101
 
102
+ if forecast_col is None:
103
+ raise KeyError("No forecast column found in the data.")
104
+
105
+ # Create Plotly figure
106
+ fig = go.Figure()
107
+
108
+ # Add historical data
109
+ fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[historical_col], mode='lines', name='Historical'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ # Add forecast data
112
+ fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[forecast_col], mode='lines', name='Forecast'))
113
 
114
+ # Add confidence interval if available
115
+ if lo_col and hi_col:
116
+ fig.add_trace(go.Scatter(
117
+ x=plot_df.index,
118
+ y=plot_df[hi_col],
119
+ mode='lines',
120
+ line=dict(color='rgba(0,100,80,0.2)'),
121
+ showlegend=False
122
+ ))
123
+ fig.add_trace(go.Scatter(
124
+ x=plot_df.index,
125
+ y=plot_df[lo_col],
126
+ mode='lines',
127
+ line=dict(color='rgba(0,100,80,0.2)'),
128
+ fill='tonexty',
129
+ fillcolor='rgba(0,100,80,0.2)',
130
+ name='90% Confidence Interval'
131
+ ))
132
 
133
+ # Update layout
134
+ fig.update_layout(
135
+ title=title,
136
+ xaxis_title='Timestamp [t]',
137
+ yaxis_title='Value',
138
+ template='plotly_white'
139
+ )
140
+
141
+ # Display the plot
142
+ st.plotly_chart(fig)
143
+
144
+
145
def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Return the (NHITS, TimesNet, LSTM, TFT) models matching ``freq``.

    ``freq`` is a pandas frequency alias.  Only the part before the first
    '-' is considered, so anchored aliases ('W-MON', 'YE-DEC', ...) are
    accepted.  Both the legacy and the current pandas spellings are
    mapped to the codes used as keys of the model dictionaries:

        daily   'D'
        monthly 'M', 'ME', 'MS'
        hourly  'H', 'h'
        weekly  'W'
        yearly  'Y', 'YE', 'YS', 'A', 'AS'

    Raises:
        ValueError: for any other frequency.
    """
    canonical = {
        'D': 'D',
        'M': 'M', 'ME': 'M', 'MS': 'M',
        'H': 'H', 'h': 'H',
        'W': 'W',
        'Y': 'Y', 'YE': 'Y', 'YS': 'Y', 'A': 'Y', 'AS': 'Y',
    }
    base = freq.split('-', 1)[0] if isinstance(freq, str) else None
    key = canonical.get(base)
    if key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return nhits_models[key], timesnet_models[key], lstm_models[key], tft_models[key]
158
 
159
  @st.cache_data
160
  def load_default():
161
+ df = AirPassengersDF.copy()
162
+ return df
163
 
164
  def transfer_learning_forecasting():
165
  st.title("Zero-shot Forecasting")
166
+ st.markdown("""
167
+ Instant time series forecasting and visualization by using various pre-trained deep neural network-based model trained on M4 data.
168
+ """)
169
+
170
+ nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
 
171
 
172
+ with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
173
+ if 'uploaded_file' not in st.session_state:
174
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
175
+ if uploaded_file:
176
+ df = pd.read_csv(uploaded_file)
177
+ st.session_state.df = df
178
+ st.session_state.uploaded_file = uploaded_file
179
+ else:
180
+ df = load_default()
181
+ st.session_state.df = df
182
+ else:
183
+ if st.checkbox("Upload a new file (CSV)"):
184
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
185
+ if uploaded_file:
186
+ df = pd.read_csv(uploaded_file)
187
+ st.session_state.df = df
188
+ st.session_state.uploaded_file = uploaded_file
189
+ else:
190
+ df = st.session_state.df
191
+ else:
192
+ df = st.session_state.df
193
+
194
+ columns = df.columns.tolist()
195
+ ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
196
+ target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
197
+ y_col = st.selectbox("Select Target column", options=target_columns, index=0)
198
+
199
+ st.session_state.ds_col = ds_col
200
+ st.session_state.y_col = y_col
201
+
202
+ # Model selection and forecasting
203
+ st.sidebar.subheader("Model Selection and Forecasting")
204
+ model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
205
+ horizon = st.sidebar.number_input("Forecast horizon", value=12)
206
+
207
+ df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
208
+ df['unique_id']=1
209
+ df = df[['unique_id','ds','y']]
210
+
211
+ # Determine frequency of data
212
  frequency = determine_frequency(df)
213
+ st.sidebar.write(f"Detected frequency: {frequency}")
214
+
215
+
216
+ nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
217
+ forecast_results = {}
218
+
219
 
220
+
221
  if st.sidebar.button("Submit"):
222
+ start_time = time.time() # Start timing
223
+ if model_choice == "NHITS":
224
+ forecast_results['NHITS'] = generate_forecast(nhits_model, df)
225
+ elif model_choice == "TimesNet":
226
+ forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
227
+ elif model_choice == "LSTM":
228
+ forecast_results['LSTM'] = generate_forecast(lstm_model, df)
229
+ elif model_choice == "TFT":
230
+ forecast_results['TFT'] = generate_forecast(tft_model, df)
231
+
232
+ st.session_state.forecast_results = forecast_results
233
+ for model_name, forecast_df in forecast_results.items():
234
+ plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')
235
+
236
+ end_time = time.time() # End timing
237
+ time_taken = end_time - start_time
238
+ st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
239
+
240
+ if 'forecast_results' in st.session_state:
241
+ forecast_results = st.session_state.forecast_results
242
+
243
+ st.markdown('You can download Input and Forecast Data below')
244
+ tab_insample, tab_forecast = st.tabs(
245
+ ["Input data", "Forecast"]
246
+ )
247
+
248
+ with tab_insample:
249
+ df_grid = df.drop(columns="unique_id")
250
+ st.write(df_grid)
251
+ # grid_table = AgGrid(
252
+ # df_grid,
253
+ # theme="alpine",
254
+ # )
255
+
256
+ with tab_forecast:
257
+ if model_choice in forecast_results:
258
+ df_grid = forecast_results[model_choice]
259
+ st.write(df_grid)
260
+ # grid_table = AgGrid(
261
+ # df_grid,
262
+ # theme="alpine",
263
+ # )
264
 
265
def personalized_forecasting():
    """Render the placeholder "Personalized Forecasting" page."""
    st.title('Personalized Forecasting')
    st.subheader("Coming soon. Stay tuned")
268
 
269
  pg = st.navigation({
270
  "Neuralforecast": [
271
  # Load pages from functions
272
  st.Page(transfer_learning_forecasting, title="Zero-shot Forecasting", default=True, icon=":material/query_stats:"),
273
+ st.Page(personalized_forecasting, title="Personalized Forecasting", default=True, icon=":material/robots:")
274
  ],
275
  })
276