azrai99 commited on
Commit
1258ec5
·
verified ·
1 Parent(s): 0a98de7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -25
app.py CHANGED
@@ -57,9 +57,9 @@ def generate_forecast(model, df):
57
  forecast_df = model.predict(df=df)
58
  return forecast_df
59
 
60
- def determine_frequency(df):
61
- df['ds'] = pd.to_datetime(df['ds'])
62
- df = df.set_index('ds')
63
  freq = pd.infer_freq(df.index)
64
  return freq
65
 
@@ -155,18 +155,18 @@ def select_model(horizon, model_type, max_steps=200):
155
  else:
156
  raise ValueError(f"Unsupported model type: {model_type}")
157
 
158
- def model_train(df,model):
159
- df['ds'] = pd.to_datetime(df['ds'])
160
  model.fit(df)
161
  return model
162
 
163
- def forecast_time_series(df, model_type, horizon, max_steps=200):
164
  start_time = time.time() # Start timing
165
- freq = determine_frequency(df)
166
  st.sidebar.write(f"Data frequency: {freq}")
167
 
168
  selected_model = select_model(horizon, model_type, max_steps)
169
- model = model_train(df, selected_model)
170
 
171
  forecast_results = {}
172
  st.sidebar.write(f"Generating forecast using {model_type} model...")
@@ -187,22 +187,35 @@ def load_default():
187
  def transfer_learning_forecasting():
188
  st.title("Transfer Learning Forecasting")
189
 
190
- # Upload dataset
191
- uploaded_file = st.sidebar.file_uploader("Upload your time series data (CSV)", type=["csv"])
192
- if uploaded_file:
193
- df = pd.read_csv(uploaded_file)
194
- else:
195
- df = load_default()
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
198
-
199
  # Model selection and forecasting
200
  st.sidebar.subheader("Model Selection and Forecasting")
201
  model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
202
  horizon = st.sidebar.number_input("Forecast horizon", value=18)
203
 
 
 
 
204
  # Determine frequency of data
205
- frequency = determine_frequency(df)
206
  st.sidebar.write(f"Detected frequency: {frequency}")
207
 
208
  # Load pre-trained models
@@ -229,20 +242,35 @@ def transfer_learning_forecasting():
229
  def dynamic_forecasting():
230
  st.title("Dynamic Forecasting")
231
 
232
- # Upload dataset
233
- uploaded_file = st.sidebar.file_uploader("Upload your time series data (CSV)", type=["csv"])
234
- if uploaded_file:
235
- df = pd.read_csv(uploaded_file)
236
- else:
237
- df = load_default()
238
-
 
 
 
 
 
 
 
 
 
 
 
 
239
  # Dynamic forecasting
240
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
241
  dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
242
  dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
243
  dynamic_max_steps = st.sidebar.number_input('Max steps', value=200)
244
 
245
- forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps)
 
 
 
246
 
247
  pg = st.navigation({
248
  "Overview": [
 
57
  forecast_df = model.predict(df=df)
58
  return forecast_df
59
 
60
+ def determine_frequency(df, ds_col):
61
+ df[ds_col] = pd.to_datetime(df[ds_col])
62
+ df = df.set_index(ds_col)
63
  freq = pd.infer_freq(df.index)
64
  return freq
65
 
 
155
  else:
156
  raise ValueError(f"Unsupported model type: {model_type}")
157
 
158
+ def model_train(df,model, ds_col):
159
+ df[ds_col] = pd.to_datetime(df[ds_col])
160
  model.fit(df)
161
  return model
162
 
163
+ def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
164
  start_time = time.time() # Start timing
165
+ freq = determine_frequency(df, ds_col)
166
  st.sidebar.write(f"Data frequency: {freq}")
167
 
168
  selected_model = select_model(horizon, model_type, max_steps)
169
+ model = model_train(df, selected_model, ds_col)
170
 
171
  forecast_results = {}
172
  st.sidebar.write(f"Generating forecast using {model_type} model...")
 
187
  def transfer_learning_forecasting():
188
  st.title("Transfer Learning Forecasting")
189
 
190
+ with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
191
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
192
+ if uploaded_file:
193
+ df = pd.read_csv(uploaded_file)
194
+ st.session_state.df = df
195
+ else:
196
+ df = load_default()
197
+ st.session_state.df = df
198
+
199
+ # Column selection
200
+ columns = df.columns if not df.empty else []
201
+ ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
202
+ y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
203
+ unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
204
+
205
+ st.session_state.ds_col = ds_col
206
+ st.session_state.y_col = y_col
207
+ st.session_state.unique_id_col = unique_id_col
208
 
 
 
209
  # Model selection and forecasting
210
  st.sidebar.subheader("Model Selection and Forecasting")
211
  model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
212
  horizon = st.sidebar.number_input("Forecast horizon", value=18)
213
 
214
+ df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
215
+ st.session_state.df = df
216
+
217
  # Determine frequency of data
218
+ frequency = determine_frequency(df, 'ds')
219
  st.sidebar.write(f"Detected frequency: {frequency}")
220
 
221
  # Load pre-trained models
 
242
  def dynamic_forecasting():
243
  st.title("Dynamic Forecasting")
244
 
245
+ with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
246
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
247
+ if uploaded_file:
248
+ df = pd.read_csv(uploaded_file)
249
+ st.session_state.df = df
250
+ else:
251
+ df = load_default()
252
+ st.session_state.df = df
253
+
254
+ # Column selection
255
+ columns = df.columns if not df.empty else []
256
+ ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
257
+ y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
258
+ unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
259
+
260
+ st.session_state.ds_col = ds_col
261
+ st.session_state.y_col = y_col
262
+ st.session_state.unique_id_col = unique_id_col
263
+
264
  # Dynamic forecasting
265
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
266
  dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
267
  dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
268
  dynamic_max_steps = st.sidebar.number_input('Max steps', value=200)
269
 
270
+ df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
271
+ st.session_state.df = df
272
+
273
+ forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps, ds_col='ds')
274
 
275
  pg = st.navigation({
276
  "Overview": [