azrai99 commited on
Commit
fbf37ad
·
verified ·
1 Parent(s): a4e353d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -53,8 +53,11 @@ def load_all_models():
53
 
54
  return nhits_models, timesnet_models, lstm_models, tft_models
55
 
56
- def generate_forecast(model, df):
57
- forecast_df = model.predict(df=df)
 
 
 
58
  return forecast_df
59
 
60
  def determine_frequency(df, ds_col):
@@ -171,7 +174,7 @@ def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
171
 
172
  forecast_results = {}
173
  st.sidebar.write(f"Generating forecast using {model_type} model...")
174
- forecast_results[model_type] = generate_forecast(model, df)
175
 
176
  for model_name, forecast_df in forecast_results.items():
177
  plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
 
53
 
54
  return nhits_models, timesnet_models, lstm_models, tft_models
55
 
56
+ def generate_forecast(model, df,tag=False):
57
+ if tag:
58
+ forecast_df = model.predict()
59
+ else:
60
+ forecast_df = model.predict(df=df)
61
  return forecast_df
62
 
63
  def determine_frequency(df, ds_col):
 
174
 
175
  forecast_results = {}
176
  st.sidebar.write(f"Generating forecast using {model_type} model...")
177
+ forecast_results[model_type] = generate_forecast(model, df, tag='retrain')
178
 
179
  for model_name, forecast_df in forecast_results.items():
180
  plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')