azrai99 commited on
Commit
5616e81
·
verified ·
1 Parent(s): 1210377

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -41
app.py CHANGED
@@ -248,50 +248,49 @@ def transfer_learning_forecasting():
248
  frequency = determine_frequency(df)
249
  st.sidebar.write(f"Detected frequency: {frequency}")
250
 
251
- col1, col2 = st.columns([2,4])
252
- with col1:
253
- tab_insample, tab_forecast = st.tabs(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  ["Input data", "Forecast"]
255
  )
256
- with tab_insample:
257
- df_grid = df.drop(columns="unique_id")
258
- st.write(df_grid)
259
- # grid_table = AgGrid(
260
- # df_grid,
261
- # theme="alpine",
262
- # )
263
-
264
- with tab_forecast:
265
- df_grid = df.drop(columns="unique_id")
266
- # grid_table = AgGrid(
267
- # df_grid,
268
- # theme="alpine",
269
- # )
270
-
271
- with col2:
272
- # Load pre-trained models
273
- nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
274
- forecast_results = {}
275
-
276
 
277
-
278
- if st.sidebar.button("Submit"):
279
- start_time = time.time() # Start timing
280
- if model_choice == "NHITS":
281
- forecast_results['NHITS'] = generate_forecast(nhits_model, df)
282
- elif model_choice == "TimesNet":
283
- forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
284
- elif model_choice == "LSTM":
285
- forecast_results['LSTM'] = generate_forecast(lstm_model, df)
286
- elif model_choice == "TFT":
287
- forecast_results['TFT'] = generate_forecast(tft_model, df)
288
-
289
- for model_name, forecast_df in forecast_results.items():
290
- plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
291
-
292
- end_time = time.time() # End timing
293
- time_taken = end_time - start_time
294
- st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
295
 
296
 
297
  def dynamic_forecasting():
 
248
  frequency = determine_frequency(df)
249
  st.sidebar.write(f"Detected frequency: {frequency}")
250
 
251
+
252
+ nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
253
+ forecast_results = {}
254
+
255
+
256
+
257
+ if st.sidebar.button("Submit"):
258
+ start_time = time.time() # Start timing
259
+ if model_choice == "NHITS":
260
+ forecast_results['NHITS'] = generate_forecast(nhits_model, df)
261
+ elif model_choice == "TimesNet":
262
+ forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
263
+ elif model_choice == "LSTM":
264
+ forecast_results['LSTM'] = generate_forecast(lstm_model, df)
265
+ elif model_choice == "TFT":
266
+ forecast_results['TFT'] = generate_forecast(tft_model, df)
267
+
268
+ for model_name, forecast_df in forecast_results.items():
269
+ plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
270
+
271
+ end_time = time.time() # End timing
272
+ time_taken = end_time - start_time
273
+ st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
274
+
275
+ tab_insample, tab_forecast = st.tabs(
276
  ["Input data", "Forecast"]
277
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
+ with tab_insample:
280
+ df_grid = df.drop(columns="unique_id")
281
+ st.write(df_grid)
282
+ # grid_table = AgGrid(
283
+ # df_grid,
284
+ # theme="alpine",
285
+ # )
286
+
287
+ with tab_forecast:
288
+ df_grid = forecast_results[model_choice].drop(columns="unique_id")
289
+ st.write(df_grid)
290
+ # grid_table = AgGrid(
291
+ # df_grid,
292
+ # theme="alpine",
293
+ # )
 
 
 
294
 
295
 
296
  def dynamic_forecasting():