Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -53,8 +53,11 @@ def load_all_models():
|
|
53 |
|
54 |
return nhits_models, timesnet_models, lstm_models, tft_models
|
55 |
|
56 |
-
def generate_forecast(model, df):
|
57 |
-
|
|
|
|
|
|
|
58 |
return forecast_df
|
59 |
|
60 |
def determine_frequency(df, ds_col):
|
@@ -171,7 +174,7 @@ def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
|
|
171 |
|
172 |
forecast_results = {}
|
173 |
st.sidebar.write(f"Generating forecast using {model_type} model...")
|
174 |
-
forecast_results[model_type] = generate_forecast(model, df)
|
175 |
|
176 |
for model_name, forecast_df in forecast_results.items():
|
177 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
|
|
|
53 |
|
54 |
return nhits_models, timesnet_models, lstm_models, tft_models
|
55 |
|
56 |
+
def generate_forecast(model, df,tag=False):
|
57 |
+
if tag:
|
58 |
+
forecast_df = model.predict()
|
59 |
+
else:
|
60 |
+
forecast_df = model.predict(df=df)
|
61 |
return forecast_df
|
62 |
|
63 |
def determine_frequency(df, ds_col):
|
|
|
174 |
|
175 |
forecast_results = {}
|
176 |
st.sidebar.write(f"Generating forecast using {model_type} model...")
|
177 |
+
forecast_results[model_type] = generate_forecast(model, df, tag='retrain')
|
178 |
|
179 |
for model_name, forecast_df in forecast_results.items():
|
180 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
|