Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -57,9 +57,9 @@ def generate_forecast(model, df):
|
|
57 |
forecast_df = model.predict(df=df)
|
58 |
return forecast_df
|
59 |
|
60 |
-
def determine_frequency(df):
|
61 |
-
df[
|
62 |
-
df = df.set_index(
|
63 |
freq = pd.infer_freq(df.index)
|
64 |
return freq
|
65 |
|
@@ -155,18 +155,18 @@ def select_model(horizon, model_type, max_steps=200):
|
|
155 |
else:
|
156 |
raise ValueError(f"Unsupported model type: {model_type}")
|
157 |
|
158 |
-
def model_train(df,model):
|
159 |
-
df[
|
160 |
model.fit(df)
|
161 |
return model
|
162 |
|
163 |
-
def forecast_time_series(df, model_type, horizon, max_steps=200):
|
164 |
start_time = time.time() # Start timing
|
165 |
-
freq = determine_frequency(df)
|
166 |
st.sidebar.write(f"Data frequency: {freq}")
|
167 |
|
168 |
selected_model = select_model(horizon, model_type, max_steps)
|
169 |
-
model = model_train(df, selected_model)
|
170 |
|
171 |
forecast_results = {}
|
172 |
st.sidebar.write(f"Generating forecast using {model_type} model...")
|
@@ -187,22 +187,35 @@ def load_default():
|
|
187 |
def transfer_learning_forecasting():
|
188 |
st.title("Transfer Learning Forecasting")
|
189 |
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
-
nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
|
198 |
-
|
199 |
# Model selection and forecasting
|
200 |
st.sidebar.subheader("Model Selection and Forecasting")
|
201 |
model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
202 |
horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
203 |
|
|
|
|
|
|
|
204 |
# Determine frequency of data
|
205 |
-
frequency = determine_frequency(df)
|
206 |
st.sidebar.write(f"Detected frequency: {frequency}")
|
207 |
|
208 |
# Load pre-trained models
|
@@ -229,20 +242,35 @@ def transfer_learning_forecasting():
|
|
229 |
def dynamic_forecasting():
|
230 |
st.title("Dynamic Forecasting")
|
231 |
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
# Dynamic forecasting
|
240 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
241 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
242 |
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
243 |
dynamic_max_steps = st.sidebar.number_input('Max steps', value=200)
|
244 |
|
245 |
-
|
|
|
|
|
|
|
246 |
|
247 |
pg = st.navigation({
|
248 |
"Overview": [
|
|
|
57 |
forecast_df = model.predict(df=df)
|
58 |
return forecast_df
|
59 |
|
60 |
+
def determine_frequency(df, ds_col):
|
61 |
+
df[ds_col] = pd.to_datetime(df[ds_col])
|
62 |
+
df = df.set_index(ds_col)
|
63 |
freq = pd.infer_freq(df.index)
|
64 |
return freq
|
65 |
|
|
|
155 |
else:
|
156 |
raise ValueError(f"Unsupported model type: {model_type}")
|
157 |
|
158 |
+
def model_train(df,model, ds_col):
|
159 |
+
df[ds_col] = pd.to_datetime(df[ds_col])
|
160 |
model.fit(df)
|
161 |
return model
|
162 |
|
163 |
+
def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
|
164 |
start_time = time.time() # Start timing
|
165 |
+
freq = determine_frequency(df, ds_col)
|
166 |
st.sidebar.write(f"Data frequency: {freq}")
|
167 |
|
168 |
selected_model = select_model(horizon, model_type, max_steps)
|
169 |
+
model = model_train(df, selected_model, ds_col)
|
170 |
|
171 |
forecast_results = {}
|
172 |
st.sidebar.write(f"Generating forecast using {model_type} model...")
|
|
|
187 |
def transfer_learning_forecasting():
|
188 |
st.title("Transfer Learning Forecasting")
|
189 |
|
190 |
+
with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
|
191 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
192 |
+
if uploaded_file:
|
193 |
+
df = pd.read_csv(uploaded_file)
|
194 |
+
st.session_state.df = df
|
195 |
+
else:
|
196 |
+
df = load_default()
|
197 |
+
st.session_state.df = df
|
198 |
+
|
199 |
+
# Column selection
|
200 |
+
columns = df.columns if not df.empty else []
|
201 |
+
ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
|
202 |
+
y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
|
203 |
+
unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
|
204 |
+
|
205 |
+
st.session_state.ds_col = ds_col
|
206 |
+
st.session_state.y_col = y_col
|
207 |
+
st.session_state.unique_id_col = unique_id_col
|
208 |
|
|
|
|
|
209 |
# Model selection and forecasting
|
210 |
st.sidebar.subheader("Model Selection and Forecasting")
|
211 |
model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
212 |
horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
213 |
|
214 |
+
df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
|
215 |
+
st.session_state.df = df
|
216 |
+
|
217 |
# Determine frequency of data
|
218 |
+
frequency = determine_frequency(df, 'ds')
|
219 |
st.sidebar.write(f"Detected frequency: {frequency}")
|
220 |
|
221 |
# Load pre-trained models
|
|
|
242 |
def dynamic_forecasting():
|
243 |
st.title("Dynamic Forecasting")
|
244 |
|
245 |
+
with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
|
246 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
247 |
+
if uploaded_file:
|
248 |
+
df = pd.read_csv(uploaded_file)
|
249 |
+
st.session_state.df = df
|
250 |
+
else:
|
251 |
+
df = load_default()
|
252 |
+
st.session_state.df = df
|
253 |
+
|
254 |
+
# Column selection
|
255 |
+
columns = df.columns if not df.empty else []
|
256 |
+
ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
|
257 |
+
y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1)
|
258 |
+
unique_id_col = st.text_input("Unique ID column (default: '1')", value="1")
|
259 |
+
|
260 |
+
st.session_state.ds_col = ds_col
|
261 |
+
st.session_state.y_col = y_col
|
262 |
+
st.session_state.unique_id_col = unique_id_col
|
263 |
+
|
264 |
# Dynamic forecasting
|
265 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
266 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
267 |
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
268 |
dynamic_max_steps = st.sidebar.number_input('Max steps', value=200)
|
269 |
|
270 |
+
df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
|
271 |
+
st.session_state.df = df
|
272 |
+
|
273 |
+
forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps, ds_col='ds')
|
274 |
|
275 |
pg = st.navigation({
|
276 |
"Overview": [
|