leonsimon23 commited on
Commit
85178a1
·
verified ·
1 Parent(s): d86885c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -66
app.py CHANGED
@@ -5,6 +5,8 @@ import time
5
  import warnings
6
  import os
7
  import logging
 
 
8
 
9
  # 数据分析与建模
10
  from scipy import stats
@@ -40,7 +42,6 @@ if not os.path.exists(OUTPUT_DIR):
40
  # ======================== (B) 辅助函数 ========================
41
 
42
  def calculate_metrics(actual, predicted):
43
- # (此函数来自您的原始代码)
44
  metrics_df = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna()
45
  if metrics_df.empty:
46
  return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}
@@ -60,17 +61,24 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
60
  """
61
  # --- 1. 初始化 ---
62
  log_lines = ["## 🚀 数据分析流程已启动..."]
63
- figures = []
64
  final_report_text = ""
65
  report_file_path = None
66
 
67
- # 定义一个辅助函数来更新界面状态
 
 
 
 
 
 
 
 
68
  def update_ui(new_log_line=None):
69
  if new_log_line:
70
  log_lines.append(new_log_line)
71
- # 返回当前所有输出的状态
72
  # [log, gallery, final_report, download_button]
73
- return "\n\n".join(log_lines), figures, final_report_text, report_file_path
74
 
75
  yield update_ui() # 立即显示启动信息
76
 
@@ -89,95 +97,67 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
89
  ts_data = df['Value']
90
  yield update_ui()
91
 
92
- # --- 3. 平稳性检验 ---
93
  log_lines.append("### 2. 平稳性检验与差分")
94
- # (代码与您提供的一致)
95
- diff_order = 0
96
  current_data = ts_data.dropna()
97
  adf_result = adfuller(current_data)
98
- p_value = adf_result[1]
99
- if p_value < 0.05:
100
- msg = f"✅ 序列在 d=0 阶差分后达到平稳 (p={p_value:.4f})。"
101
- log_lines.append(msg)
102
  d_order = 0
103
  else:
104
  current_data_diff = current_data.diff().dropna()
105
  adf_result_diff = adfuller(current_data_diff)
106
- p_value_diff = adf_result_diff[1]
107
- if p_value_diff < 0.05:
108
- msg = f"✅ 序列在 d=1 阶差分后达到平稳 (p={p_value_diff:.4f})。"
109
- log_lines.append(msg)
110
- d_order = 1
111
- current_data = current_data_diff
112
  else:
113
- msg = f"⚠️ 1阶差分后仍未平稳 (p={p_value_diff:.4f}),将使用 d=1 继续分析。"
114
- log_lines.append(msg)
115
- d_order = 1
116
- current_data = current_data_diff
117
  ts_stationary = current_data
118
  yield update_ui()
119
 
120
  # --- 4. 白噪声检验 ---
121
  log_lines.append("### 3. 白噪声检验")
122
- # (代码与您提供的一致)
123
  lags = min(10, len(ts_stationary) // 5)
124
  lb_test_result = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)
125
- lb_p_value = lb_test_result['lb_pvalue'].iloc[0]
126
- if lb_p_value > 0.05:
127
- log_lines.append(f"⚠️ 序列可能是白噪声(p-value = {lb_p_value:.4f}),模型可能无效。")
128
  else:
129
- log_lines.append(f"✅ 通过白噪声检验 (p-value = {lb_p_value:.4f}),可以进行后续建模。")
130
- yield update_ui()
131
-
132
  # --- 5. 季节性检验与分解 ---
133
  log_lines.append("\n### 4. 季节性检验与STL分解")
134
  period = 365
135
- seasonal_enabled = len(ts_data) > 2 * 14 # 改为数据多于两周则开启季节性
136
  m_period = 7 if seasonal_enabled else 1
137
  log_lines.append(f"✅ 季节性参数设定: m={m_period}, seasonal={seasonal_enabled}")
138
 
139
  if len(ts_data) >= 2 * period:
140
- # 【关键修复】确保 seasonal 参数是奇数
141
- seasonal_period_for_stl = period if period % 2 != 0 else period + 1
142
  log_lines.append(f"✅ 准备进行STL分解,周期(period)={period},季节平滑窗口(seasonal)={seasonal_period_for_stl}。")
143
- yield update_ui() # 更新一下日志,让用户知道我们在做什么
144
-
145
  stl = STL(ts_data, period=period, seasonal=seasonal_period_for_stl)
146
  res = stl.fit()
147
-
148
- #fig, axes = plt.subplots(4, 1, figsize=(12, 8), sharex=True)
149
- # 使用 res.plot() 可以自动处理标签
150
- #res.plot(axes=axes)
151
- #fig.suptitle(f'STL 分解图 (周期={period})', fontsize=16)
152
- #plt.tight_layout(rect=[0, 0, 1, 0.96])
153
-
154
-
155
- # 【关键修复】直接调用 res.plot(),它会返回一个 Figure 对象
156
  fig = res.plot()
157
-
158
- # 调整 Figure 的大小和标题
159
  fig.set_size_inches(12, 8)
160
- fig.suptitle(f'STL 分解图 (周期={period})', fontsize=16, y=0.98) # 使用 y 参数调整标题位置
161
- plt.tight_layout() # 自动调整布局
162
-
163
- figures.append(fig)
164
  log_lines.append("✅ STL分解图已生成。")
165
  else:
166
  log_lines.append("⚠️ 数据长度不足以进行年度季节性分解。")
167
  yield update_ui()
168
 
169
-
170
-
171
  # --- 6. 混合策略回测优化窗口大小 ---
172
  log_lines.append("\n### 5. 优化训练窗口大小")
173
  log_lines.append("⏳ **此步骤计算量大,可能需要5-15分钟,请耐心等待...**")
174
  yield update_ui()
175
 
176
  def evaluate_window_hybrid(window_size, time_series, d, m, seasonal):
177
- # (此函数来自您的代码)
178
  errors = []
179
  series_values = time_series.values
180
- backtest_length = 100 # 减少回测长度以加速
181
  if len(series_values) <= window_size + backtest_length: return {'window_size': window_size, 'mae': np.inf}
182
  end_index = len(series_values)
183
  start_index = end_index - backtest_length
@@ -194,7 +174,7 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
194
  if not errors: return {'window_size': window_size, 'mae': np.inf}
195
  return {'window_size': window_size, 'mae': np.mean(np.abs(errors))}
196
 
197
- window_sizes_to_test = np.arange(70, 211, 14) # 增大步长以加速
198
  with Parallel(n_jobs=-1) as parallel:
199
  results = parallel(
200
  delayed(evaluate_window_hybrid)(ws, ts_data, d_order, m_period, seasonal_enabled) for ws in window_sizes_to_test
@@ -212,7 +192,7 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
212
  ax.plot(window_results_df.index, window_results_df['mae'], marker='o', label='MAE')
213
  ax.set_title('训练窗口大小对预测误差的影响')
214
  ax.set_xlabel('训练窗口天数'); ax.set_ylabel('误差值'); ax.legend(); ax.grid(True)
215
- figures.append(fig)
216
  yield update_ui()
217
 
218
  # --- 7 & 8. 动态滚动预测与评估 ---
@@ -226,23 +206,23 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
226
 
227
  # SARIMA 滚动
228
  sarima_rolling_preds = []
229
- for i in range(len(test_rolling_target)):
230
  train_window = ts_data.iloc[split_point_roll + i - best_window_size : split_point_roll + i]
231
  try:
232
  model = pm.auto_arima(train_window, d=d_order, m=m_period, seasonal=seasonal_enabled,
233
  stepwise=True, trace=False, error_action='ignore', suppress_warnings=True)
234
  sarima_rolling_preds.append(model.predict(n_periods=1)[0])
235
  except:
236
- sarima_rolling_preds.append(np.nan)
237
  rolling_predictions['Auto-SARIMA'] = pd.Series(sarima_rolling_preds, index=test_rolling_target.index).ffill()
238
  log_lines.append("✅ Auto-SARIMA 滚动预测完成。")
239
  yield update_ui()
240
 
241
- # Prophet 滚动 (简化策略)
242
  prophet_rolling_preds = []
243
  prophet_model = None
244
- for i, (date, value) in enumerate(test_rolling_target.items()):
245
- if i % 14 == 0 or prophet_model is None: # 每 14 天重训练
246
  train_upto_date = ts_data.loc[:date - pd.Timedelta(days=1)]
247
  prophet_train_df = train_upto_date.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
248
  prophet_model = Prophet(yearly_seasonality='auto', weekly_seasonality=seasonal_enabled, daily_seasonality=False).fit(prophet_train_df)
@@ -268,9 +248,9 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
268
  ax.plot(test_rolling_target, label='真实值 (测试集)', color='blue', linewidth=2)
269
  for model_name, preds in rolling_predictions.items():
270
  is_best = ' (最佳)' if model_name == best_rolling_model_name else ''
271
- ax.plot(preds, label=f'{model_name} 预测{is_best}', linestyle='--')
272
  ax.set_title('滚动预测结果对比'); ax.legend(); ax.grid(True)
273
- figures.append(fig)
274
  yield update_ui()
275
 
276
  # --- 10. 最终未来预测 ---
@@ -301,7 +281,7 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
301
  ax.plot(final_forecast_series, label=f'未来 {forecast_horizon} 天预测', color='red', linestyle='--')
302
  ax.fill_between(future_dates, conf_int[:, 0], conf_int[:, 1], color='red', alpha=0.2, label='95% 置信区间')
303
  ax.set_title(f'最终未来用量预测 (基于 {best_rolling_model_name})'); ax.legend(); ax.grid(True)
304
- figures.append(fig)
305
 
306
  # 生成最终报告
307
  final_report_text = f"""
@@ -328,7 +308,8 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
328
  - **预测摘要**:
329
  - 未来一周平均日用量: **{final_forecast_series.head(7).mean():.2f}**
330
  - 未来一月平均日用量: **{final_forecast_series.head(30).mean():.2f}**
331
- """
 
332
  report_file_path = os.path.join(OUTPUT_DIR, 'final_analysis_report.txt')
333
  with open(report_file_path, 'w', encoding='utf-8') as f:
334
  f.write(final_report_text)
@@ -338,7 +319,6 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
338
 
339
  except Exception as e:
340
  log_lines.append(f"\n\n❌ **分析过程中断,出现错误:**\n`{str(e)}`")
341
- import traceback
342
  log_lines.append(f"\n**Traceback:**\n```{traceback.format_exc()}```")
343
  yield update_ui()
344
 
@@ -361,7 +341,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
361
 
362
  with gr.Tabs():
363
  with gr.TabItem("📊 可视化图表", id=0):
364
- gallery_output = gr.Gallery(label="分析图表", elem_id="gallery", columns=[1], height="auto")
365
  with gr.TabItem("📝 实时分析日志", id=1):
366
  log_output = gr.Markdown("点击按钮后,分析日志将实时显示在这里...")
367
  with gr.TabItem("📋 最终报告与下载", id=2):
 
5
  import warnings
6
  import os
7
  import logging
8
+ import tempfile
9
+ import traceback
10
 
11
  # 数据分析与建模
12
  from scipy import stats
 
42
  # ======================== (B) 辅助函数 ========================
43
 
44
  def calculate_metrics(actual, predicted):
 
45
  metrics_df = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna()
46
  if metrics_df.empty:
47
  return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}
 
61
  """
62
  # --- 1. 初始化 ---
63
  log_lines = ["## 🚀 数据分析流程已启动..."]
64
+ figure_paths = []
65
  final_report_text = ""
66
  report_file_path = None
67
 
68
+ # 辅助函数,用于将Matplotlib Figure保存为临时图片文件并返回路径
69
+ def save_fig_to_path(fig):
70
+ # 使用 NamedTemporaryFile 来创建一个不会被立即删除的临时文件
71
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
72
+ fig.savefig(tmpfile.name)
73
+ figure_paths.append(tmpfile.name)
74
+ plt.close(fig) # 操作完成后关闭图形,释放内存
75
+
76
+ # 辅助函数,用于更新UI状态
77
  def update_ui(new_log_line=None):
78
  if new_log_line:
79
  log_lines.append(new_log_line)
 
80
  # [log, gallery, final_report, download_button]
81
+ return "\n\n".join(log_lines), figure_paths, final_report_text, report_file_path
82
 
83
  yield update_ui() # 立即显示启动信息
84
 
 
97
  ts_data = df['Value']
98
  yield update_ui()
99
 
100
+ # --- 3. 平稳性检验与差分 ---
101
  log_lines.append("### 2. 平稳性检验与差分")
 
 
102
  current_data = ts_data.dropna()
103
  adf_result = adfuller(current_data)
104
+ if adf_result[1] < 0.05:
105
+ log_lines.append(f"✅ 序列在 d=0 阶差分后达到平稳 (p={adf_result[1]:.4f})。")
 
 
106
  d_order = 0
107
  else:
108
  current_data_diff = current_data.diff().dropna()
109
  adf_result_diff = adfuller(current_data_diff)
110
+ d_order = 1
111
+ if adf_result_diff[1] < 0.05:
112
+ log_lines.append(f"✅ 序列在 d=1 阶差分后达到平稳 (p={adf_result_diff[1]:.4f})。")
 
 
 
113
  else:
114
+ log_lines.append(f"⚠️ 1阶差分后仍未平稳 (p={adf_result_diff[1]:.4f}),将使用 d=1 继续分析。")
115
+ current_data = current_data_diff
 
 
116
  ts_stationary = current_data
117
  yield update_ui()
118
 
119
  # --- 4. 白噪声检验 ---
120
  log_lines.append("### 3. 白噪声检验")
 
121
  lags = min(10, len(ts_stationary) // 5)
122
  lb_test_result = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)
123
+ if lb_test_result['lb_pvalue'].iloc[0] > 0.05:
124
+ log_lines.append(f"⚠️ 序列可能是白噪声(p-value = {lb_test_result['lb_pvalue'].iloc[0]:.4f}),模型可能无效。")
 
125
  else:
126
+ log_lines.append(f"✅ 通过白噪声检验 (p-value = {lb_test_result['lb_pvalue'].iloc[0]:.4f}),可以进行后续建模。")
127
+ yield update_ui()
128
+
129
  # --- 5. 季节性检验与分解 ---
130
  log_lines.append("\n### 4. 季节性检验与STL分解")
131
  period = 365
132
+ seasonal_enabled = len(ts_data) > 2 * 14 # 数据多于两周则开启周季节性
133
  m_period = 7 if seasonal_enabled else 1
134
  log_lines.append(f"✅ 季节性参数设定: m={m_period}, seasonal={seasonal_enabled}")
135
 
136
  if len(ts_data) >= 2 * period:
137
+ seasonal_period_for_stl = period if period % 2 != 0 else period + 1
 
138
  log_lines.append(f"✅ 准备进行STL分解,周期(period)={period},季节平滑窗口(seasonal)={seasonal_period_for_stl}。")
139
+ yield update_ui()
 
140
  stl = STL(ts_data, period=period, seasonal=seasonal_period_for_stl)
141
  res = stl.fit()
 
 
 
 
 
 
 
 
 
142
  fig = res.plot()
 
 
143
  fig.set_size_inches(12, 8)
144
+ fig.suptitle(f'STL 分解图 (周期={period})', fontsize=16, y=0.98)
145
+ plt.tight_layout()
146
+ save_fig_to_path(fig)
 
147
  log_lines.append("✅ STL分解图已生成。")
148
  else:
149
  log_lines.append("⚠️ 数据长度不足以进行年度季节性分解。")
150
  yield update_ui()
151
 
 
 
152
  # --- 6. 混合策略回测优化窗口大小 ---
153
  log_lines.append("\n### 5. 优化训练窗口大小")
154
  log_lines.append("⏳ **此步骤计算量大,可能需要5-15分钟,请耐心等待...**")
155
  yield update_ui()
156
 
157
  def evaluate_window_hybrid(window_size, time_series, d, m, seasonal):
 
158
  errors = []
159
  series_values = time_series.values
160
+ backtest_length = 100
161
  if len(series_values) <= window_size + backtest_length: return {'window_size': window_size, 'mae': np.inf}
162
  end_index = len(series_values)
163
  start_index = end_index - backtest_length
 
174
  if not errors: return {'window_size': window_size, 'mae': np.inf}
175
  return {'window_size': window_size, 'mae': np.mean(np.abs(errors))}
176
 
177
+ window_sizes_to_test = np.arange(70, 211, 14)
178
  with Parallel(n_jobs=-1) as parallel:
179
  results = parallel(
180
  delayed(evaluate_window_hybrid)(ws, ts_data, d_order, m_period, seasonal_enabled) for ws in window_sizes_to_test
 
192
  ax.plot(window_results_df.index, window_results_df['mae'], marker='o', label='MAE')
193
  ax.set_title('训练窗口大小对预测误差的影响')
194
  ax.set_xlabel('训练窗口天数'); ax.set_ylabel('误差值'); ax.legend(); ax.grid(True)
195
+ save_fig_to_path(fig)
196
  yield update_ui()
197
 
198
  # --- 7 & 8. 动态滚动预测与评估 ---
 
206
 
207
  # SARIMA 滚动
208
  sarima_rolling_preds = []
209
+ for i in progress.tqdm(range(len(test_rolling_target)), desc="SARIMA Rolling Forecast"):
210
  train_window = ts_data.iloc[split_point_roll + i - best_window_size : split_point_roll + i]
211
  try:
212
  model = pm.auto_arima(train_window, d=d_order, m=m_period, seasonal=seasonal_enabled,
213
  stepwise=True, trace=False, error_action='ignore', suppress_warnings=True)
214
  sarima_rolling_preds.append(model.predict(n_periods=1)[0])
215
  except:
216
+ sarima_rolling_preds.append(sarima_rolling_preds[-1] if sarima_rolling_preds else np.nan)
217
  rolling_predictions['Auto-SARIMA'] = pd.Series(sarima_rolling_preds, index=test_rolling_target.index).ffill()
218
  log_lines.append("✅ Auto-SARIMA 滚动预测完成。")
219
  yield update_ui()
220
 
221
+ # Prophet 滚动
222
  prophet_rolling_preds = []
223
  prophet_model = None
224
+ for i, (date, value) in enumerate(progress.tqdm(test_rolling_target.items(), desc="Prophet Rolling Forecast")):
225
+ if i % 14 == 0 or prophet_model is None:
226
  train_upto_date = ts_data.loc[:date - pd.Timedelta(days=1)]
227
  prophet_train_df = train_upto_date.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
228
  prophet_model = Prophet(yearly_seasonality='auto', weekly_seasonality=seasonal_enabled, daily_seasonality=False).fit(prophet_train_df)
 
248
  ax.plot(test_rolling_target, label='真实值 (测试集)', color='blue', linewidth=2)
249
  for model_name, preds in rolling_predictions.items():
250
  is_best = ' (最佳)' if model_name == best_rolling_model_name else ''
251
+ ax.plot(preds.dropna(), label=f'{model_name} 预测{is_best}', linestyle='--')
252
  ax.set_title('滚动预测结果对比'); ax.legend(); ax.grid(True)
253
+ save_fig_to_path(fig)
254
  yield update_ui()
255
 
256
  # --- 10. 最终未来预测 ---
 
281
  ax.plot(final_forecast_series, label=f'未来 {forecast_horizon} 天预测', color='red', linestyle='--')
282
  ax.fill_between(future_dates, conf_int[:, 0], conf_int[:, 1], color='red', alpha=0.2, label='95% 置信区间')
283
  ax.set_title(f'最终未来用量预测 (基于 {best_rolling_model_name})'); ax.legend(); ax.grid(True)
284
+ save_fig_to_path(fig)
285
 
286
  # 生成最终报告
287
  final_report_text = f"""
 
308
  - **预测摘要**:
309
  - 未来一周平均日用量: **{final_forecast_series.head(7).mean():.2f}**
310
  - 未来一月平均日用量: **{final_forecast_series.head(30).mean():.2f}**
311
+ """.strip()
312
+
313
  report_file_path = os.path.join(OUTPUT_DIR, 'final_analysis_report.txt')
314
  with open(report_file_path, 'w', encoding='utf-8') as f:
315
  f.write(final_report_text)
 
319
 
320
  except Exception as e:
321
  log_lines.append(f"\n\n❌ **分析过程中断,出现错误:**\n`{str(e)}`")
 
322
  log_lines.append(f"\n**Traceback:**\n```{traceback.format_exc()}```")
323
  yield update_ui()
324
 
 
341
 
342
  with gr.Tabs():
343
  with gr.TabItem("📊 可视化图表", id=0):
344
+ gallery_output = gr.Gallery(label="分析图表", elem_id="gallery", columns=[1], height="auto", object_fit="contain")
345
  with gr.TabItem("📝 实时分析日志", id=1):
346
  log_output = gr.Markdown("点击按钮后,分析日志将实时显示在这里...")
347
  with gr.TabItem("📋 最终报告与下载", id=2):