azrai99 commited on
Commit
b0365ad
·
verified ·
1 Parent(s): 2142d3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -39
app.py CHANGED
@@ -226,47 +226,41 @@ def transfer_learning_forecasting():
226
  frequency = determine_frequency(df)
227
  st.sidebar.write(f"Detected frequency: {frequency}")
228
 
229
- tab_insample = st.tabs(
230
- ["Input data"]
231
- )
232
- with tab_insample:
233
- df_grid = df.drop(columns="unique_id")
234
- grid_table = AgGrid(
235
- df_grid,
236
- editable=False,
237
- # theme="streamlit",
238
- fit_columns_on_grid_load=True,
239
- height=360,
240
- )
241
-
242
- # Load pre-trained models
243
- nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
244
- forecast_results = {}
245
-
246
- start_time = time.time() # Start timing
247
- if model_choice == "NHITS":
248
- forecast_results['NHITS'] = generate_forecast(nhits_model, df)
249
- elif model_choice == "TimesNet":
250
- forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
251
- elif model_choice == "LSTM":
252
- forecast_results['LSTM'] = generate_forecast(lstm_model, df)
253
- elif model_choice == "TFT":
254
- forecast_results['TFT'] = generate_forecast(tft_model, df)
255
-
256
- df_grid = df.drop(columns="unique_id")
257
- grid_table = AgGrid(
258
- df_grid,
259
- editable=False,
260
- theme="streamlit",
261
- fit_columns_on_grid_load=True,
262
- height=360,
263
  )
264
- for model_name, forecast_df in forecast_results.items():
265
- plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
 
 
 
 
 
 
266
 
267
- end_time = time.time() # End timing
268
- time_taken = end_time - start_time
269
- st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  def dynamic_forecasting():
272
  st.title("Dynamic Forecasting")
 
226
  frequency = determine_frequency(df)
227
  st.sidebar.write(f"Detected frequency: {frequency}")
228
 
229
+ col1, col2 = st.columns([2,4])
230
+ with col1:
231
+ tab_insample = st.tabs(
232
+ ["Input data"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  )
234
+ with tab_insample:
235
+ df_grid = df.drop(columns="unique_id")
236
+ grid_table = AgGrid(
237
+ df_grid,
238
+ theme="alpine",
239
+ fit_columns_on_grid_load=True,
240
+ height=360,
241
+ )
242
 
243
+ with col2:
244
+ # Load pre-trained models
245
+ nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
246
+ forecast_results = {}
247
+
248
+ start_time = time.time() # Start timing
249
+ if model_choice == "NHITS":
250
+ forecast_results['NHITS'] = generate_forecast(nhits_model, df)
251
+ elif model_choice == "TimesNet":
252
+ forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
253
+ elif model_choice == "LSTM":
254
+ forecast_results['LSTM'] = generate_forecast(lstm_model, df)
255
+ elif model_choice == "TFT":
256
+ forecast_results['TFT'] = generate_forecast(tft_model, df)
257
+
258
+ for model_name, forecast_df in forecast_results.items():
259
+ plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
260
+
261
+ end_time = time.time() # End timing
262
+ time_taken = end_time - start_time
263
+ st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
264
 
265
  def dynamic_forecasting():
266
  st.title("Dynamic Forecasting")