Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -155,9 +155,10 @@ def select_model(horizon, model_type, max_steps=200):
|
|
155 |
else:
|
156 |
raise ValueError(f"Unsupported model type: {model_type}")
|
157 |
|
158 |
-
def model_train(df,model, ds_col):
|
|
|
159 |
df[ds_col] = pd.to_datetime(df[ds_col])
|
160 |
-
|
161 |
return model
|
162 |
|
163 |
def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
|
@@ -166,7 +167,7 @@ def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
|
|
166 |
st.sidebar.write(f"Data frequency: {freq}")
|
167 |
|
168 |
selected_model = select_model(horizon, model_type, max_steps)
|
169 |
-
model = model_train(df, selected_model, ds_col)
|
170 |
|
171 |
forecast_results = {}
|
172 |
st.sidebar.write(f"Generating forecast using {model_type} model...")
|
@@ -187,7 +188,8 @@ def load_default():
|
|
187 |
def transfer_learning_forecasting():
|
188 |
st.title("Transfer Learning Forecasting")
|
189 |
|
190 |
-
|
|
|
191 |
with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
|
192 |
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
193 |
if uploaded_file:
|
|
|
155 |
else:
|
156 |
raise ValueError(f"Unsupported model type: {model_type}")
|
157 |
|
158 |
+
def model_train(df,model, ds_col, freq):
|
159 |
+
nf = NeuralForecast(models=[model], freq=freq)
|
160 |
df[ds_col] = pd.to_datetime(df[ds_col])
|
161 |
+
nf.fit(df)
|
162 |
return model
|
163 |
|
164 |
def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
|
|
|
167 |
st.sidebar.write(f"Data frequency: {freq}")
|
168 |
|
169 |
selected_model = select_model(horizon, model_type, max_steps)
|
170 |
+
model = model_train(df, selected_model, ds_col,freq)
|
171 |
|
172 |
forecast_results = {}
|
173 |
st.sidebar.write(f"Generating forecast using {model_type} model...")
|
|
|
188 |
def transfer_learning_forecasting():
|
189 |
st.title("Transfer Learning Forecasting")
|
190 |
|
191 |
+
nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
|
192 |
+
|
193 |
with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
|
194 |
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
195 |
if uploaded_file:
|