azrai99 commited on
Commit
8d3eb2c
·
verified ·
1 Parent(s): 3127dc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -54,7 +54,7 @@ def load_all_models():
54
  return nhits_models, timesnet_models, lstm_models, tft_models
55
 
56
  def generate_forecast(model, df,tag=False):
57
- if tag:
58
  forecast_df = model.predict()
59
  else:
60
  forecast_df = model.predict(df=df)
@@ -162,7 +162,7 @@ def model_train(df,model, ds_col, freq):
162
  nf = NeuralForecast(models=[model], freq=freq)
163
  df[ds_col] = pd.to_datetime(df[ds_col])
164
  nf.fit(df)
165
- return model
166
 
167
  def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds',y_col):
168
  start_time = time.time() # Start timing
 
54
  return nhits_models, timesnet_models, lstm_models, tft_models
55
 
56
  def generate_forecast(model, df,tag=False):
57
+ if tag == 'retrain':
58
  forecast_df = model.predict()
59
  else:
60
  forecast_df = model.predict(df=df)
 
162
  nf = NeuralForecast(models=[model], freq=freq)
163
  df[ds_col] = pd.to_datetime(df[ds_col])
164
  nf.fit(df)
165
+ return nf
166
 
167
  def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds',y_col):
168
  start_time = time.time() # Start timing