Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -15,95 +15,262 @@ st.set_page_config(layout='wide')
|
|
15 |
|
16 |
@st.cache_resource
|
17 |
def load_model(path, freq):
|
18 |
-
|
|
|
19 |
|
20 |
@st.cache_resource
|
21 |
def load_all_models():
|
22 |
-
|
23 |
-
'D': './M4/
|
24 |
-
'M': './M4/
|
25 |
-
'H': './M4/
|
26 |
-
'W': './M4/
|
27 |
-
'Y': './M4/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
}
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
def generate_forecast(model, df,
|
35 |
-
|
|
|
|
|
|
|
|
|
36 |
|
37 |
def determine_frequency(df):
|
38 |
-
df['ds'] = pd.to_datetime(df['ds'])
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
if not freq:
|
41 |
-
st.warning('
|
|
|
|
|
42 |
return freq
|
43 |
|
|
|
|
|
|
|
44 |
def plot_forecasts(forecast_df, train_df, title):
|
|
|
45 |
plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
|
|
|
|
|
46 |
historical_col = 'y'
|
47 |
forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
|
48 |
lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
|
49 |
hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
|
50 |
|
51 |
-
if forecast_col:
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
st.plotly_chart(fig)
|
60 |
-
|
61 |
-
def select_model_based_on_frequency(freq, models):
|
62 |
-
return {model: models[model][freq] for model in models}
|
63 |
-
|
64 |
-
def model_train(df, model, freq):
|
65 |
-
nf = NeuralForecast(models=[model], freq=freq)
|
66 |
-
df['ds'] = pd.to_datetime(df['ds'])
|
67 |
-
nf.fit(df)
|
68 |
-
return nf
|
69 |
-
|
70 |
-
def forecast_time_series(df, model_type, horizon, max_steps, y_col):
|
71 |
-
freq = determine_frequency(df)
|
72 |
-
st.sidebar.write(f"Data frequency: {freq}")
|
73 |
|
74 |
-
|
75 |
-
|
76 |
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
@st.cache_data
|
83 |
def load_default():
|
84 |
-
|
|
|
85 |
|
86 |
def transfer_learning_forecasting():
|
87 |
st.title("Zero-shot Forecasting")
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
df = df.rename(columns={ds_col: 'ds', y_col: 'y'}).assign(unique_id=1)[['unique_id', 'ds', 'y']]
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
frequency = determine_frequency(df)
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
97 |
|
|
|
98 |
if st.sidebar.button("Submit"):
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
|
|
|
|
|
|
102 |
|
103 |
pg = st.navigation({
|
104 |
"Neuralforecast": [
|
105 |
# Load pages from functions
|
106 |
st.Page(transfer_learning_forecasting, title="Zero-shot Forecasting", default=True, icon=":material/query_stats:"),
|
|
|
107 |
],
|
108 |
})
|
109 |
|
|
|
15 |
|
16 |
@st.cache_resource
|
17 |
def load_model(path, freq):
|
18 |
+
nf = NeuralForecast.load(path=path)
|
19 |
+
return nf
|
20 |
|
21 |
@st.cache_resource
|
22 |
def load_all_models():
|
23 |
+
nhits_paths = {
|
24 |
+
'D': './M4/NHITS/daily',
|
25 |
+
'M': './M4/NHITS/monthly',
|
26 |
+
'H': './M4/NHITS/hourly',
|
27 |
+
'W': './M4/NHITS/weekly',
|
28 |
+
'Y': './M4/NHITS/yearly'
|
29 |
+
}
|
30 |
+
|
31 |
+
timesnet_paths = {
|
32 |
+
'D': './M4/TimesNet/daily',
|
33 |
+
'M': './M4/TimesNet/monthly',
|
34 |
+
'H': './M4/TimesNet/hourly',
|
35 |
+
'W': './M4/TimesNet/weekly',
|
36 |
+
'Y': './M4/TimesNet/yearly'
|
37 |
+
}
|
38 |
+
|
39 |
+
lstm_paths = {
|
40 |
+
'D': './M4/LSTM/daily',
|
41 |
+
'M': './M4/LSTM/monthly',
|
42 |
+
'H': './M4/LSTM/hourly',
|
43 |
+
'W': './M4/LSTM/weekly',
|
44 |
+
'Y': './M4/LSTM/yearly'
|
45 |
}
|
46 |
|
47 |
+
tft_paths = {
|
48 |
+
'D': './M4/TFT/daily',
|
49 |
+
'M': './M4/TFT/monthly',
|
50 |
+
'H': './M4/TFT/hourly',
|
51 |
+
'W': './M4/TFT/weekly',
|
52 |
+
'Y': './M4/TFT/yearly'
|
53 |
+
}
|
54 |
+
nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
|
55 |
+
timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
|
56 |
+
lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
|
57 |
+
tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}
|
58 |
+
|
59 |
+
return nhits_models, timesnet_models, lstm_models, tft_models
|
60 |
|
61 |
+
def generate_forecast(model, df,tag=False):
|
62 |
+
if tag == 'retrain':
|
63 |
+
forecast_df = model.predict()
|
64 |
+
else:
|
65 |
+
forecast_df = model.predict(df=df)
|
66 |
+
return forecast_df
|
67 |
|
68 |
def determine_frequency(df):
|
69 |
+
df['ds'] = pd.to_datetime(df['ds'])
|
70 |
+
df = df.drop_duplicates(subset='ds')
|
71 |
+
df = df.set_index('ds')
|
72 |
+
|
73 |
+
# # Create a complete date range
|
74 |
+
# full_range = pd.date_range(start=df.index.min(), end=df.index.max(),freq=freq)
|
75 |
+
|
76 |
+
# # Reindex the DataFrame to this full date range
|
77 |
+
# df_full = df.reindex(full_range)
|
78 |
+
|
79 |
+
# Infer the frequency
|
80 |
+
# freq = pd.infer_freq(df_full.index)
|
81 |
+
|
82 |
+
freq = pd.infer_freq(df.index)
|
83 |
if not freq:
|
84 |
+
st.warning('The forecast will use default Daily forecast due to date inconsistency. Please check your data.',icon="⚠️")
|
85 |
+
freq = 'D'
|
86 |
+
|
87 |
return freq
|
88 |
|
89 |
+
|
90 |
+
import plotly.graph_objects as go
|
91 |
+
|
92 |
def plot_forecasts(forecast_df, train_df, title):
|
93 |
+
# Combine historical and forecast data
|
94 |
plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
|
95 |
+
|
96 |
+
# Find relevant columns
|
97 |
historical_col = 'y'
|
98 |
forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
|
99 |
lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
|
100 |
hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
|
101 |
|
102 |
+
if forecast_col is None:
|
103 |
+
raise KeyError("No forecast column found in the data.")
|
104 |
+
|
105 |
+
# Create Plotly figure
|
106 |
+
fig = go.Figure()
|
107 |
+
|
108 |
+
# Add historical data
|
109 |
+
fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[historical_col], mode='lines', name='Historical'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
+
# Add forecast data
|
112 |
+
fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[forecast_col], mode='lines', name='Forecast'))
|
113 |
|
114 |
+
# Add confidence interval if available
|
115 |
+
if lo_col and hi_col:
|
116 |
+
fig.add_trace(go.Scatter(
|
117 |
+
x=plot_df.index,
|
118 |
+
y=plot_df[hi_col],
|
119 |
+
mode='lines',
|
120 |
+
line=dict(color='rgba(0,100,80,0.2)'),
|
121 |
+
showlegend=False
|
122 |
+
))
|
123 |
+
fig.add_trace(go.Scatter(
|
124 |
+
x=plot_df.index,
|
125 |
+
y=plot_df[lo_col],
|
126 |
+
mode='lines',
|
127 |
+
line=dict(color='rgba(0,100,80,0.2)'),
|
128 |
+
fill='tonexty',
|
129 |
+
fillcolor='rgba(0,100,80,0.2)',
|
130 |
+
name='90% Confidence Interval'
|
131 |
+
))
|
132 |
|
133 |
+
# Update layout
|
134 |
+
fig.update_layout(
|
135 |
+
title=title,
|
136 |
+
xaxis_title='Timestamp [t]',
|
137 |
+
yaxis_title='Value',
|
138 |
+
template='plotly_white'
|
139 |
+
)
|
140 |
+
|
141 |
+
# Display the plot
|
142 |
+
st.plotly_chart(fig)
|
143 |
+
|
144 |
+
|
145 |
+
def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
|
146 |
+
if freq == 'D':
|
147 |
+
return nhits_models['D'], timesnet_models['D'], lstm_models['D'], tft_models['D']
|
148 |
+
elif freq == 'ME':
|
149 |
+
return nhits_models['M'], timesnet_models['M'], lstm_models['M'], tft_models['M']
|
150 |
+
elif freq == 'H':
|
151 |
+
return nhits_models['H'], timesnet_models['H'], lstm_models['H'], tft_models['H']
|
152 |
+
elif freq in ['W', 'W-SUN']:
|
153 |
+
return nhits_models['W'], timesnet_models['W'], lstm_models['W'], tft_models['W']
|
154 |
+
elif freq in ['Y', 'Y-DEC']:
|
155 |
+
return nhits_models['Y'], timesnet_models['Y'], lstm_models['Y'], tft_models['Y']
|
156 |
+
else:
|
157 |
+
raise ValueError(f"Unsupported frequency: {freq}")
|
158 |
|
159 |
@st.cache_data
|
160 |
def load_default():
|
161 |
+
df = AirPassengersDF.copy()
|
162 |
+
return df
|
163 |
|
164 |
def transfer_learning_forecasting():
|
165 |
st.title("Zero-shot Forecasting")
|
166 |
+
st.markdown("""
|
167 |
+
Instant time series forecasting and visualization by using various pre-trained deep neural network-based model trained on M4 data.
|
168 |
+
""")
|
169 |
+
|
170 |
+
nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
|
|
|
171 |
|
172 |
+
with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
|
173 |
+
if 'uploaded_file' not in st.session_state:
|
174 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
175 |
+
if uploaded_file:
|
176 |
+
df = pd.read_csv(uploaded_file)
|
177 |
+
st.session_state.df = df
|
178 |
+
st.session_state.uploaded_file = uploaded_file
|
179 |
+
else:
|
180 |
+
df = load_default()
|
181 |
+
st.session_state.df = df
|
182 |
+
else:
|
183 |
+
if st.checkbox("Upload a new file (CSV)"):
|
184 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
185 |
+
if uploaded_file:
|
186 |
+
df = pd.read_csv(uploaded_file)
|
187 |
+
st.session_state.df = df
|
188 |
+
st.session_state.uploaded_file = uploaded_file
|
189 |
+
else:
|
190 |
+
df = st.session_state.df
|
191 |
+
else:
|
192 |
+
df = st.session_state.df
|
193 |
+
|
194 |
+
columns = df.columns.tolist()
|
195 |
+
ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
|
196 |
+
target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
|
197 |
+
y_col = st.selectbox("Select Target column", options=target_columns, index=0)
|
198 |
+
|
199 |
+
st.session_state.ds_col = ds_col
|
200 |
+
st.session_state.y_col = y_col
|
201 |
+
|
202 |
+
# Model selection and forecasting
|
203 |
+
st.sidebar.subheader("Model Selection and Forecasting")
|
204 |
+
model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
205 |
+
horizon = st.sidebar.number_input("Forecast horizon", value=12)
|
206 |
+
|
207 |
+
df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
|
208 |
+
df['unique_id']=1
|
209 |
+
df = df[['unique_id','ds','y']]
|
210 |
+
|
211 |
+
# Determine frequency of data
|
212 |
frequency = determine_frequency(df)
|
213 |
+
st.sidebar.write(f"Detected frequency: {frequency}")
|
214 |
+
|
215 |
+
|
216 |
+
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
217 |
+
forecast_results = {}
|
218 |
+
|
219 |
|
220 |
+
|
221 |
if st.sidebar.button("Submit"):
|
222 |
+
start_time = time.time() # Start timing
|
223 |
+
if model_choice == "NHITS":
|
224 |
+
forecast_results['NHITS'] = generate_forecast(nhits_model, df)
|
225 |
+
elif model_choice == "TimesNet":
|
226 |
+
forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
|
227 |
+
elif model_choice == "LSTM":
|
228 |
+
forecast_results['LSTM'] = generate_forecast(lstm_model, df)
|
229 |
+
elif model_choice == "TFT":
|
230 |
+
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
231 |
+
|
232 |
+
st.session_state.forecast_results = forecast_results
|
233 |
+
for model_name, forecast_df in forecast_results.items():
|
234 |
+
plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')
|
235 |
+
|
236 |
+
end_time = time.time() # End timing
|
237 |
+
time_taken = end_time - start_time
|
238 |
+
st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
|
239 |
+
|
240 |
+
if 'forecast_results' in st.session_state:
|
241 |
+
forecast_results = st.session_state.forecast_results
|
242 |
+
|
243 |
+
st.markdown('You can download Input and Forecast Data below')
|
244 |
+
tab_insample, tab_forecast = st.tabs(
|
245 |
+
["Input data", "Forecast"]
|
246 |
+
)
|
247 |
+
|
248 |
+
with tab_insample:
|
249 |
+
df_grid = df.drop(columns="unique_id")
|
250 |
+
st.write(df_grid)
|
251 |
+
# grid_table = AgGrid(
|
252 |
+
# df_grid,
|
253 |
+
# theme="alpine",
|
254 |
+
# )
|
255 |
+
|
256 |
+
with tab_forecast:
|
257 |
+
if model_choice in forecast_results:
|
258 |
+
df_grid = forecast_results[model_choice]
|
259 |
+
st.write(df_grid)
|
260 |
+
# grid_table = AgGrid(
|
261 |
+
# df_grid,
|
262 |
+
# theme="alpine",
|
263 |
+
# )
|
264 |
|
265 |
+
def personalized_forecasting():
|
266 |
+
st.title('Personalized Forecasting')
|
267 |
+
st.subheader("Coming soon. Stay tuned")
|
268 |
|
269 |
pg = st.navigation({
|
270 |
"Neuralforecast": [
|
271 |
# Load pages from functions
|
272 |
st.Page(transfer_learning_forecasting, title="Zero-shot Forecasting", default=True, icon=":material/query_stats:"),
|
273 |
+
st.Page(personalized_forecasting, title="Personalized Forecasting", default=True, icon=":material/robots:")
|
274 |
],
|
275 |
})
|
276 |
|