azrai99 commited on
Commit
0e2e1e6
·
verified ·
1 Parent(s): d066a0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -264,7 +264,8 @@ def transfer_learning_forecasting():
264
  forecast_results['LSTM'] = generate_forecast(lstm_model, df)
265
  elif model_choice == "TFT":
266
  forecast_results['TFT'] = generate_forecast(tft_model, df)
267
-
 
268
  for model_name, forecast_df in forecast_results.items():
269
  plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')
270
 
@@ -272,8 +273,6 @@ def transfer_learning_forecasting():
272
  time_taken = end_time - start_time
273
  st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
274
 
275
- st.session_state.forecast_results = forecast_results
276
-
277
  if 'forecast_results' in st.session_state:
278
  forecast_results = st.session_state.forecast_results
279
 
@@ -291,12 +290,13 @@ def transfer_learning_forecasting():
291
  # )
292
 
293
  with tab_forecast:
294
- df_grid = forecast_results[model_choice]
295
- st.write(df_grid)
296
- # grid_table = AgGrid(
297
- # df_grid,
298
- # theme="alpine",
299
- # )
 
300
 
301
 
302
  def dynamic_forecasting():
 
264
  forecast_results['LSTM'] = generate_forecast(lstm_model, df)
265
  elif model_choice == "TFT":
266
  forecast_results['TFT'] = generate_forecast(tft_model, df)
267
+
268
+ st.session_state.forecast_results = forecast_results
269
  for model_name, forecast_df in forecast_results.items():
270
  plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')
271
 
 
273
  time_taken = end_time - start_time
274
  st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
275
 
 
 
276
  if 'forecast_results' in st.session_state:
277
  forecast_results = st.session_state.forecast_results
278
 
 
290
  # )
291
 
292
  with tab_forecast:
293
+ if model_choice in forecast_results:
294
+ df_grid = forecast_results[model_choice]
295
+ st.write(df_grid)
296
+ # grid_table = AgGrid(
297
+ # df_grid,
298
+ # theme="alpine",
299
+ # )
300
 
301
 
302
  def dynamic_forecasting():