Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -264,7 +264,8 @@ def transfer_learning_forecasting():
|
|
264 |
forecast_results['LSTM'] = generate_forecast(lstm_model, df)
|
265 |
elif model_choice == "TFT":
|
266 |
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
267 |
-
|
|
|
268 |
for model_name, forecast_df in forecast_results.items():
|
269 |
plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')
|
270 |
|
@@ -272,8 +273,6 @@ def transfer_learning_forecasting():
|
|
272 |
time_taken = end_time - start_time
|
273 |
st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
|
274 |
|
275 |
-
st.session_state.forecast_results = forecast_results
|
276 |
-
|
277 |
if 'forecast_results' in st.session_state:
|
278 |
forecast_results = st.session_state.forecast_results
|
279 |
|
@@ -291,12 +290,13 @@ def transfer_learning_forecasting():
|
|
291 |
# )
|
292 |
|
293 |
with tab_forecast:
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
|
|
300 |
|
301 |
|
302 |
def dynamic_forecasting():
|
|
|
264 |
forecast_results['LSTM'] = generate_forecast(lstm_model, df)
|
265 |
elif model_choice == "TFT":
|
266 |
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
267 |
+
|
268 |
+
st.session_state.forecast_results = forecast_results
|
269 |
for model_name, forecast_df in forecast_results.items():
|
270 |
plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')
|
271 |
|
|
|
273 |
time_taken = end_time - start_time
|
274 |
st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
|
275 |
|
|
|
|
|
276 |
if 'forecast_results' in st.session_state:
|
277 |
forecast_results = st.session_state.forecast_results
|
278 |
|
|
|
290 |
# )
|
291 |
|
292 |
with tab_forecast:
|
293 |
+
if model_choice in forecast_results:
|
294 |
+
df_grid = forecast_results[model_choice]
|
295 |
+
st.write(df_grid)
|
296 |
+
# grid_table = AgGrid(
|
297 |
+
# df_grid,
|
298 |
+
# theme="alpine",
|
299 |
+
# )
|
300 |
|
301 |
|
302 |
def dynamic_forecasting():
|