azrai99 commited on
Commit
91a8ee3
·
verified ·
1 Parent(s): cb9ca26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -1
app.py CHANGED
@@ -86,7 +86,7 @@ def determine_frequency(df):
86
 
87
  return freq
88
 
89
- def plot_forecasts(forecast_df, train_df, title):
90
  fig, ax = plt.subplots(1, 1, figsize=(20, 7))
91
  plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
92
  historical_col = 'y'
@@ -112,6 +112,61 @@ def plot_forecasts(forecast_df, train_df, title):
112
  ax.grid()
113
  st.pyplot(fig)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
116
  if freq == 'D':
117
  return nhits_models['D'], timesnet_models['D'], lstm_models['D'], tft_models['D']
 
86
 
87
  return freq
88
 
89
+ def plot_forecasts_matplotlib(forecast_df, train_df, title):
90
  fig, ax = plt.subplots(1, 1, figsize=(20, 7))
91
  plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
92
  historical_col = 'y'
 
112
  ax.grid()
113
  st.pyplot(fig)
114
 
115
+ import plotly.graph_objects as go
116
+
117
+ def plot_forecasts(forecast_df, train_df, title):
118
+ # Combine historical and forecast data
119
+ plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
120
+
121
+ # Find relevant columns
122
+ historical_col = 'y'
123
+ forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
124
+ lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
125
+ hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
126
+
127
+ if forecast_col is None:
128
+ raise KeyError("No forecast column found in the data.")
129
+
130
+ # Create Plotly figure
131
+ fig = go.Figure()
132
+
133
+ # Add historical data
134
+ fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[historical_col], mode='lines', name='Historical'))
135
+
136
+ # Add forecast data
137
+ fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[forecast_col], mode='lines', name='Forecast'))
138
+
139
+ # Add confidence interval if available
140
+ if lo_col and hi_col:
141
+ fig.add_trace(go.Scatter(
142
+ x=plot_df.index,
143
+ y=plot_df[lo_col],
144
+ mode='lines',
145
+ fill='tozeroy',
146
+ fillcolor='rgba(0,100,80,0.2)',
147
+ line=dict(color='rgba(0,100,80,0.2)'),
148
+ name='90% Confidence Interval'
149
+ ))
150
+ fig.add_trace(go.Scatter(
151
+ x=plot_df.index,
152
+ y=plot_df[hi_col],
153
+ mode='lines',
154
+ line=dict(color='rgba(0,100,80,0.2)'),
155
+ showlegend=False
156
+ ))
157
+
158
+ # Update layout
159
+ fig.update_layout(
160
+ title=title,
161
+ xaxis_title='Timestamp [t]',
162
+ yaxis_title='Value',
163
+ template='plotly_white'
164
+ )
165
+
166
+ # Display the plot
167
+ st.plotly_chart(fig)
168
+
169
+
170
  def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
171
  if freq == 'D':
172
  return nhits_models['D'], timesnet_models['D'], lstm_models['D'], tft_models['D']