azrai99 commited on
Commit
c0f7916
·
verified ·
1 Parent(s): c77c48a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -21
app.py CHANGED
@@ -8,8 +8,6 @@ from neuralforecast.losses.pytorch import HuberMQLoss
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
10
 
11
-
12
-
13
  @st.cache_resource
14
  def load_model(path, freq):
15
  nf = NeuralForecast.load(path=path)
@@ -17,7 +15,6 @@ def load_model(path, freq):
17
 
18
  @st.cache_resource
19
  def load_all_models():
20
- # Paths for saving models
21
  nhits_paths = {
22
  'D': './M4/NHITS/daily',
23
  'M': './M4/NHITS/monthly',
@@ -161,16 +158,17 @@ def select_model(horizon, model_type, max_steps=200):
161
  def model_train(df,model):
162
  model.fit(df)
163
  return model
 
164
  def forecast_time_series(df, model_type, horizon, max_steps=200):
165
  start_time = time.time() # Start timing
166
  freq = determine_frequency(df)
167
- st.write(f"Determined frequency: {freq}")
168
 
169
  selected_model = select_model(horizon, model_type, max_steps)
170
- model = model_train(df,selected_model)
171
 
172
  forecast_results = {}
173
- st.write(f"Generating forecast using {model_type} model...")
174
  forecast_results[model_type] = generate_forecast(model, df)
175
 
176
  for model_name, forecast_df in forecast_results.items():
@@ -178,7 +176,7 @@ def forecast_time_series(df, model_type, horizon, max_steps=200):
178
 
179
  end_time = time.time() # End timing
180
  time_taken = end_time - start_time
181
- st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
182
 
183
  @st.cache_data
184
  def load_default():
@@ -189,7 +187,7 @@ def transfer_learning_forecasting():
189
  st.title("Transfer Learning Forecasting")
190
 
191
  # Upload dataset
192
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
193
  if uploaded_file:
194
  df = pd.read_csv(uploaded_file)
195
  else:
@@ -198,13 +196,13 @@ def transfer_learning_forecasting():
198
  nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
199
 
200
  # Model selection and forecasting
201
- st.subheader("Model Selection and Forecasting")
202
- model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
203
- horizon = st.number_input("Forecast horizon", value=18)
204
 
205
  # Determine frequency of data
206
  frequency = determine_frequency(df)
207
- st.write(f"Detected frequency: {frequency}")
208
 
209
  # Load pre-trained models
210
  nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
@@ -225,27 +223,26 @@ def transfer_learning_forecasting():
225
 
226
  end_time = time.time() # End timing
227
  time_taken = end_time - start_time
228
- st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
229
 
230
  def dynamic_forecasting():
231
  st.title("Dynamic Forecasting")
232
 
233
  # Upload dataset
234
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
235
  if uploaded_file:
236
  df = pd.read_csv(uploaded_file)
237
  else:
238
  df = load_default()
239
 
240
  # Dynamic forecasting
241
- st.subheader("Dynamic Model Selection and Forecasting")
242
- dynamic_model_choice = st.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
243
- dynamic_horizon = st.number_input("Forecast horizon", value=18)
244
- dynamic_max_steps = st.number_input('Max steps', value=200)
245
 
246
-
247
  forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps)
248
-
249
  pg = st.navigation({
250
  "Overview": [
251
  # Load pages from functions
@@ -257,4 +254,4 @@ pg = st.navigation({
257
  try:
258
  pg.run()
259
  except Exception as e:
260
- st.error(f"Something went wrong: {str(e)}", icon=":material/error:")
 
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
10
 
 
 
11
  @st.cache_resource
12
  def load_model(path, freq):
13
  nf = NeuralForecast.load(path=path)
 
15
 
16
  @st.cache_resource
17
  def load_all_models():
 
18
  nhits_paths = {
19
  'D': './M4/NHITS/daily',
20
  'M': './M4/NHITS/monthly',
 
158
  def model_train(df,model):
159
  model.fit(df)
160
  return model
161
+
162
  def forecast_time_series(df, model_type, horizon, max_steps=200):
163
  start_time = time.time() # Start timing
164
  freq = determine_frequency(df)
165
+ st.sidebar.write(f"Determined frequency: {freq}")
166
 
167
  selected_model = select_model(horizon, model_type, max_steps)
168
+ model = model_train(df, selected_model)
169
 
170
  forecast_results = {}
171
+ st.sidebar.write(f"Generating forecast using {model_type} model...")
172
  forecast_results[model_type] = generate_forecast(model, df)
173
 
174
  for model_name, forecast_df in forecast_results.items():
 
176
 
177
  end_time = time.time() # End timing
178
  time_taken = end_time - start_time
179
+ st.sidebar.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
180
 
181
  @st.cache_data
182
  def load_default():
 
187
  st.title("Transfer Learning Forecasting")
188
 
189
  # Upload dataset
190
+ uploaded_file = st.sidebar.file_uploader("Upload your time series data (CSV)", type=["csv"])
191
  if uploaded_file:
192
  df = pd.read_csv(uploaded_file)
193
  else:
 
196
  nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
197
 
198
  # Model selection and forecasting
199
+ st.sidebar.subheader("Model Selection and Forecasting")
200
+ model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
201
+ horizon = st.sidebar.number_input("Forecast horizon", value=18)
202
 
203
  # Determine frequency of data
204
  frequency = determine_frequency(df)
205
+ st.sidebar.write(f"Detected frequency: {frequency}")
206
 
207
  # Load pre-trained models
208
  nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
 
223
 
224
  end_time = time.time() # End timing
225
  time_taken = end_time - start_time
226
+ st.sidebar.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
227
 
228
  def dynamic_forecasting():
229
  st.title("Dynamic Forecasting")
230
 
231
  # Upload dataset
232
+ uploaded_file = st.sidebar.file_uploader("Upload your time series data (CSV)", type=["csv"])
233
  if uploaded_file:
234
  df = pd.read_csv(uploaded_file)
235
  else:
236
  df = load_default()
237
 
238
  # Dynamic forecasting
239
+ st.sidebar.subheader("Dynamic Model Selection and Forecasting")
240
+ dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
241
+ dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
242
+ dynamic_max_steps = st.sidebar.number_input('Max steps', value=200)
243
 
 
244
  forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps)
245
+
246
  pg = st.navigation({
247
  "Overview": [
248
  # Load pages from functions
 
254
  try:
255
  pg.run()
256
  except Exception as e:
257
+ st.sidebar.error(f"Something went wrong: {e}", icon=":material/error:")