azrai99 commited on
Commit
a4e353d
·
verified ·
1 Parent(s): 106863f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -155,9 +155,10 @@ def select_model(horizon, model_type, max_steps=200):
155
  else:
156
  raise ValueError(f"Unsupported model type: {model_type}")
157
 
158
- def model_train(df,model, ds_col):
 
159
  df[ds_col] = pd.to_datetime(df[ds_col])
160
- model.fit(df)
161
  return model
162
 
163
  def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
@@ -166,7 +167,7 @@ def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
166
  st.sidebar.write(f"Data frequency: {freq}")
167
 
168
  selected_model = select_model(horizon, model_type, max_steps)
169
- model = model_train(df, selected_model, ds_col)
170
 
171
  forecast_results = {}
172
  st.sidebar.write(f"Generating forecast using {model_type} model...")
@@ -187,7 +188,8 @@ def load_default():
187
  def transfer_learning_forecasting():
188
  st.title("Transfer Learning Forecasting")
189
 
190
- nhits_model, timesnet_model, lstm_model, tft_model = load_all_models()
 
191
  with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
192
  uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
193
  if uploaded_file:
 
155
  else:
156
  raise ValueError(f"Unsupported model type: {model_type}")
157
 
158
+ def model_train(df,model, ds_col, freq):
159
+ nf = NeuralForecast(models=[model], freq=freq)
160
  df[ds_col] = pd.to_datetime(df[ds_col])
161
+ nf.fit(df)
162
  return model
163
 
164
  def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
 
167
  st.sidebar.write(f"Data frequency: {freq}")
168
 
169
  selected_model = select_model(horizon, model_type, max_steps)
170
+ model = model_train(df, selected_model, ds_col,freq)
171
 
172
  forecast_results = {}
173
  st.sidebar.write(f"Generating forecast using {model_type} model...")
 
188
  def transfer_learning_forecasting():
189
  st.title("Transfer Learning Forecasting")
190
 
191
+ nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
192
+
193
  with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
194
  uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
195
  if uploaded_file: