azrai99 commited on
Commit
106863f
·
verified ·
1 Parent(s): 0b84d55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -187,6 +187,7 @@ def load_default():
187
  def transfer_learning_forecasting():
188
  st.title("Transfer Learning Forecasting")
189
 
 
190
  with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
191
  uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
192
  if uploaded_file:
@@ -200,11 +201,10 @@ def transfer_learning_forecasting():
200
  columns = df.columns.tolist() # Convert Index to list
201
  ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
202
  y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
203
- unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
204
 
205
  st.session_state.ds_col = ds_col
206
  st.session_state.y_col = y_col
207
- st.session_state.unique_id_col = unique_id_col
208
 
209
  # Model selection and forecasting
210
  st.sidebar.subheader("Model Selection and Forecasting")
@@ -212,6 +212,7 @@ def transfer_learning_forecasting():
212
  horizon = st.sidebar.number_input("Forecast horizon", value=18)
213
 
214
  df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
 
215
  st.session_state.df = df
216
 
217
  # Determine frequency of data
@@ -255,11 +256,11 @@ def dynamic_forecasting():
255
  columns = df.columns.tolist() # Convert Index to list
256
  ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
257
  y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
258
- unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
259
 
 
260
  st.session_state.ds_col = ds_col
261
  st.session_state.y_col = y_col
262
- st.session_state.unique_id_col = unique_id_col
263
 
264
  # Dynamic forecasting
265
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
 
187
  def transfer_learning_forecasting():
188
  st.title("Transfer Learning Forecasting")
189
 
190
+ nhits_model, timesnet_model, lstm_model, tft_model = load_all_models()
191
  with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
192
  uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
193
  if uploaded_file:
 
201
  columns = df.columns.tolist() # Convert Index to list
202
  ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
203
  y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
204
+ # unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
205
 
206
  st.session_state.ds_col = ds_col
207
  st.session_state.y_col = y_col
 
208
 
209
  # Model selection and forecasting
210
  st.sidebar.subheader("Model Selection and Forecasting")
 
212
  horizon = st.sidebar.number_input("Forecast horizon", value=18)
213
 
214
  df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
215
+ df['unique_id']=1
216
  st.session_state.df = df
217
 
218
  # Determine frequency of data
 
256
  columns = df.columns.tolist() # Convert Index to list
257
  ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
258
  y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
259
+ # unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
260
 
261
+ df['unique_id']=1
262
  st.session_state.ds_col = ds_col
263
  st.session_state.y_col = y_col
 
264
 
265
  # Dynamic forecasting
266
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")