Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -86,7 +86,7 @@ def determine_frequency(df):
|
|
86 |
|
87 |
return freq
|
88 |
|
89 |
-
def
|
90 |
fig, ax = plt.subplots(1, 1, figsize=(20, 7))
|
91 |
plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
|
92 |
historical_col = 'y'
|
@@ -112,6 +112,61 @@ def plot_forecasts(forecast_df, train_df, title):
|
|
112 |
ax.grid()
|
113 |
st.pyplot(fig)
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
|
116 |
if freq == 'D':
|
117 |
return nhits_models['D'], timesnet_models['D'], lstm_models['D'], tft_models['D']
|
|
|
86 |
|
87 |
return freq
|
88 |
|
89 |
+
def plot_forecasts_matplotlib(forecast_df, train_df, title):
|
90 |
fig, ax = plt.subplots(1, 1, figsize=(20, 7))
|
91 |
plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
|
92 |
historical_col = 'y'
|
|
|
112 |
ax.grid()
|
113 |
st.pyplot(fig)
|
114 |
|
115 |
+
import plotly.graph_objects as go
|
116 |
+
|
117 |
+
def plot_forecasts(forecast_df, train_df, title):
|
118 |
+
# Combine historical and forecast data
|
119 |
+
plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
|
120 |
+
|
121 |
+
# Find relevant columns
|
122 |
+
historical_col = 'y'
|
123 |
+
forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
|
124 |
+
lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
|
125 |
+
hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
|
126 |
+
|
127 |
+
if forecast_col is None:
|
128 |
+
raise KeyError("No forecast column found in the data.")
|
129 |
+
|
130 |
+
# Create Plotly figure
|
131 |
+
fig = go.Figure()
|
132 |
+
|
133 |
+
# Add historical data
|
134 |
+
fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[historical_col], mode='lines', name='Historical'))
|
135 |
+
|
136 |
+
# Add forecast data
|
137 |
+
fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[forecast_col], mode='lines', name='Forecast'))
|
138 |
+
|
139 |
+
# Add confidence interval if available
|
140 |
+
if lo_col and hi_col:
|
141 |
+
fig.add_trace(go.Scatter(
|
142 |
+
x=plot_df.index,
|
143 |
+
y=plot_df[lo_col],
|
144 |
+
mode='lines',
|
145 |
+
fill='tozeroy',
|
146 |
+
fillcolor='rgba(0,100,80,0.2)',
|
147 |
+
line=dict(color='rgba(0,100,80,0.2)'),
|
148 |
+
name='90% Confidence Interval'
|
149 |
+
))
|
150 |
+
fig.add_trace(go.Scatter(
|
151 |
+
x=plot_df.index,
|
152 |
+
y=plot_df[hi_col],
|
153 |
+
mode='lines',
|
154 |
+
line=dict(color='rgba(0,100,80,0.2)'),
|
155 |
+
showlegend=False
|
156 |
+
))
|
157 |
+
|
158 |
+
# Update layout
|
159 |
+
fig.update_layout(
|
160 |
+
title=title,
|
161 |
+
xaxis_title='Timestamp [t]',
|
162 |
+
yaxis_title='Value',
|
163 |
+
template='plotly_white'
|
164 |
+
)
|
165 |
+
|
166 |
+
# Display the plot
|
167 |
+
st.plotly_chart(fig)
|
168 |
+
|
169 |
+
|
170 |
def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
|
171 |
if freq == 'D':
|
172 |
return nhits_models['D'], timesnet_models['D'], lstm_models['D'], tft_models['D']
|