azrai99 commited on
Commit
64e36b4
·
verified ·
1 Parent(s): a752441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -0
app.py CHANGED
@@ -7,6 +7,7 @@ from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
7
  from neuralforecast.losses.pytorch import HuberMQLoss
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
 
10
 
11
  @st.cache_resource
12
  def load_model(path, freq):
@@ -225,6 +226,15 @@ def transfer_learning_forecasting():
225
  frequency = determine_frequency(df)
226
  st.sidebar.write(f"Detected frequency: {frequency}")
227
 
 
 
 
 
 
 
 
 
 
228
  # Load pre-trained models
229
  nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
230
  forecast_results = {}
@@ -239,6 +249,14 @@ def transfer_learning_forecasting():
239
  elif model_choice == "TFT":
240
  forecast_results['TFT'] = generate_forecast(tft_model, df)
241
 
 
 
 
 
 
 
 
 
242
  for model_name, forecast_df in forecast_results.items():
243
  plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
244
 
 
7
  from neuralforecast.losses.pytorch import HuberMQLoss
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
10
+ from st_aggrid import AgGrid
11
 
12
  @st.cache_resource
13
  def load_model(path, freq):
 
226
  frequency = determine_frequency(df)
227
  st.sidebar.write(f"Detected frequency: {frequency}")
228
 
229
+ df_grid = df.drop(columns="unique_id")
230
+ grid_table = AgGrid(
231
+ df_grid,
232
+ editable=False,
233
+ theme="streamlit",
234
+ fit_columns_on_grid_load=True,
235
+ height=360,
236
+ )
237
+
238
  # Load pre-trained models
239
  nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
240
  forecast_results = {}
 
249
  elif model_choice == "TFT":
250
  forecast_results['TFT'] = generate_forecast(tft_model, df)
251
 
252
+ df_grid = df.drop(columns="unique_id")
253
+ grid_table = AgGrid(
254
+ df_grid,
255
+ editable=False,
256
+ theme="streamlit",
257
+ fit_columns_on_grid_load=True,
258
+ height=360,
259
+ )
260
  for model_name, forecast_df in forecast_results.items():
261
  plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
262