Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
|
|
7 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
8 |
from neuralforecast.utils import AirPassengersDF
|
9 |
import time
|
|
|
10 |
|
11 |
@st.cache_resource
|
12 |
def load_model(path, freq):
|
@@ -225,6 +226,15 @@ def transfer_learning_forecasting():
|
|
225 |
frequency = determine_frequency(df)
|
226 |
st.sidebar.write(f"Detected frequency: {frequency}")
|
227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
# Load pre-trained models
|
229 |
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
230 |
forecast_results = {}
|
@@ -239,6 +249,14 @@ def transfer_learning_forecasting():
|
|
239 |
elif model_choice == "TFT":
|
240 |
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
for model_name, forecast_df in forecast_results.items():
|
243 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
|
244 |
|
|
|
7 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
8 |
from neuralforecast.utils import AirPassengersDF
|
9 |
import time
|
10 |
+
from st_aggrid import AgGrid
|
11 |
|
12 |
@st.cache_resource
|
13 |
def load_model(path, freq):
|
|
|
226 |
frequency = determine_frequency(df)
|
227 |
st.sidebar.write(f"Detected frequency: {frequency}")
|
228 |
|
229 |
+
df_grid = df.drop(columns="unique_id")
|
230 |
+
grid_table = AgGrid(
|
231 |
+
df_grid,
|
232 |
+
editable=False,
|
233 |
+
theme="streamlit",
|
234 |
+
fit_columns_on_grid_load=True,
|
235 |
+
height=360,
|
236 |
+
)
|
237 |
+
|
238 |
# Load pre-trained models
|
239 |
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
240 |
forecast_results = {}
|
|
|
249 |
elif model_choice == "TFT":
|
250 |
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
251 |
|
252 |
+
df_grid = df.drop(columns="unique_id")
|
253 |
+
grid_table = AgGrid(
|
254 |
+
df_grid,
|
255 |
+
editable=False,
|
256 |
+
theme="streamlit",
|
257 |
+
fit_columns_on_grid_load=True,
|
258 |
+
height=360,
|
259 |
+
)
|
260 |
for model_name, forecast_df in forecast_results.items():
|
261 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
|
262 |
|