leonsimon23 commited on
Commit
22e5b2e
·
verified ·
1 Parent(s): 7721301

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +362 -0
app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ======================== (A) 导入库和配置环境 ========================
2
+ import pandas as pd
3
+ import numpy as np
4
+ import time
5
+ import warnings
6
+ import os
7
+ import logging
8
+
9
+ # 数据分析与建模
10
+ from scipy import stats
11
+ from statsmodels.tsa.stattools import adfuller
12
+ from statsmodels.tsa.seasonal import STL
13
+ from statsmodels.stats.diagnostic import acorr_ljungbox
14
+ from prophet import Prophet
15
+ import pmdarima as pm
16
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
17
+ from joblib import Parallel, delayed
18
+
19
+ # 可视化
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+
23
+ # Web UI
24
+ import gradio as gr
25
+
26
# --- Global settings ---
# Silence Python warnings globally and keep the Prophet / cmdstanpy loggers
# down to errors only.
warnings.filterwarnings("ignore")
logging.getLogger('prophet').setLevel(logging.ERROR)
logging.getLogger('cmdstanpy').setLevel(logging.ERROR)

# Chinese-capable font for matplotlib (the font installed through
# packages.txt), and plain ASCII minus signs on the axes.
plt.rcParams.update({
    'font.sans-serif': ['WenQuanYi Zen Hei'],
    'axes.unicode_minus': False,
})
34
+
35
# --- Output directory ---
# Folder that receives the generated report file.
OUTPUT_DIR = 'outputs'
# exist_ok=True replaces the separate exists() check, so there is no window
# between "check" and "create" in which another process can create the folder
# and make makedirs() raise.
os.makedirs(OUTPUT_DIR, exist_ok=True)
39
+
40
+ # ======================== (B) 辅助函数 ========================
41
+
42
def calculate_metrics(actual, predicted):
    """Compute MAE, RMSE, MAPE and sMAPE between two aligned sequences.

    The two inputs are paired through a DataFrame and every pair containing a
    missing value is dropped before any metric is computed.

    Args:
        actual: Observed values (list, array or Series).
        predicted: Forecast values, aligned with ``actual``.

    Returns:
        dict with keys ``'MAE'``, ``'RMSE'``, ``'MAPE'`` and ``'sMAPE'``
        (MAPE and sMAPE in percent). All four are NaN when no complete pair
        remains after dropping missing values.
    """
    paired = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna()
    if paired.empty:
        return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}

    y_true = paired['actual'].to_numpy(dtype=float)
    y_pred = paired['predicted'].to_numpy(dtype=float)
    abs_err = np.abs(y_true - y_pred)

    mae = float(np.mean(abs_err))
    rmse = float(np.sqrt(np.mean(abs_err ** 2)))

    # MAPE: a zero actual is replaced by a tiny constant to avoid dividing by 0.
    safe_true = np.where(y_true == 0, 1e-6, y_true)
    mape = float(np.mean(np.abs((y_true - y_pred) / safe_true)) * 100)

    # sMAPE: when actual and forecast are both 0 the term is 0/0. Count it as
    # a perfect hit (0) instead of letting a NaN poison the whole mean.
    denom = np.abs(y_true) + np.abs(y_pred)
    smape_terms = np.divide(abs_err, denom, out=np.zeros_like(abs_err), where=denom != 0)
    smape = float(200 * np.mean(smape_terms))

    return {'MAE': mae, 'RMSE': rmse, 'MAPE': mape, 'sMAPE': smape}
54
+
55
+ # ======================== (C) 主分析函数 (Gradio核心) ========================
56
+
57
def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
    """Run the whole forecasting pipeline and stream progress to the Gradio UI.

    The function is a generator: after every stage it yields the current
    state of the four UI outputs so the page refreshes while the analysis is
    still running. Stages: data cleaning, ADF stationarity test, Ljung-Box
    white-noise test, STL decomposition, training-window search, rolling
    one-step forecasts (auto-ARIMA and Prophet), and a final 90-day forecast
    with a written report.

    Args:
        progress: Gradio progress tracker (not referenced in the body).

    Yields:
        tuple ``(log_markdown, figure_paths, report_markdown, report_path)``:
        the accumulated log, the list of PNG files for the gallery, the final
        report text (empty until the end) and the report file path (``None``
        until the end).

    Any exception raised by a stage is caught, appended to the log together
    with its traceback, and yielded instead of being propagated.
    """
    # --- 1. Initialisation ---
    log_lines = ["## 🚀 数据分析流程已启动..."]
    figures = []  # paths of the PNG files shown in the gallery
    final_report_text = ""
    report_file_path = None

    def update_ui(new_log_line=None):
        """Return the current [log, gallery, final_report, download] outputs."""
        if new_log_line:
            log_lines.append(new_log_line)
        return "\n\n".join(log_lines), figures, final_report_text, report_file_path

    def add_figure(fig, filename):
        """Save *fig* as a PNG in OUTPUT_DIR, queue it for the gallery, close it.

        The gallery is fed image files rather than live Figure objects, and
        closing the figure releases its memory instead of leaving it
        registered with pyplot for the lifetime of the process.
        """
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        image_path = os.path.join(OUTPUT_DIR, filename)
        fig.savefig(image_path, dpi=100, bbox_inches='tight')
        plt.close(fig)
        figures.append(image_path)

    yield update_ui()  # show the start-up message immediately

    try:
        # --- 2. Data cleaning ---
        log_lines.append("### 1. 数据清洗")
        df = pd.read_excel('gmqrkl.xlsx')
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.drop_duplicates(subset=['Date']).sort_values('Date').reset_index(drop=True)
        log_lines.append(f"✅ 数据读取并去重成功,共 {len(df)} 行。")

        # Zeros are treated as missing and filled by linear interpolation.
        # Assign the result back: an in-place call on df['Value'] acts on a
        # temporary and is not written back under pandas Copy-on-Write.
        df['Value'] = (
            df['Value']
            .replace(0, np.nan)
            .interpolate(method='linear', limit_direction='both')
        )
        log_lines.append("✅ 零值替换与线性插值完成。")
        df.set_index('Date', inplace=True)
        ts_data = df['Value']
        yield update_ui()

        # --- 3. Stationarity test (ADF), differencing once if needed ---
        log_lines.append("### 2. 平稳性检验与差分")
        current_data = ts_data.dropna()
        p_value = adfuller(current_data)[1]
        if p_value < 0.05:
            log_lines.append(f"✅ 序列在 d=0 阶差分后达到平稳 (p={p_value:.4f})。")
            d_order = 0
        else:
            current_data = current_data.diff().dropna()
            p_value_diff = adfuller(current_data)[1]
            if p_value_diff < 0.05:
                log_lines.append(f"✅ 序列在 d=1 阶差分后达到平稳 (p={p_value_diff:.4f})。")
            else:
                log_lines.append(f"⚠️ 1阶差分后仍未平稳 (p={p_value_diff:.4f}),将使用 d=1 继续分析。")
            # d=1 is used in both cases; higher orders are never tried.
            d_order = 1
        ts_stationary = current_data
        yield update_ui()

        # --- 4. White-noise test (Ljung-Box) ---
        log_lines.append("### 3. 白噪声检验")
        # At least one lag, so a very short series does not request lag 0.
        lags = max(1, min(10, len(ts_stationary) // 5))
        lb_test_result = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)
        lb_p_value = lb_test_result['lb_pvalue'].iloc[0]
        if lb_p_value > 0.05:
            log_lines.append(f"⚠️ 序列可能是白噪声(p-value = {lb_p_value:.4f}),模型可能无效。")
        else:
            log_lines.append(f"✅ 通过白噪声检验 (p-value = {lb_p_value:.4f}),可以进行后续建模。")
        yield update_ui()

        # --- 5. Seasonality settings and STL decomposition ---
        log_lines.append("### 4. 季节性检验与STL分解")
        period = 365
        # Simplified rule: weekly seasonality is enabled once there are more
        # than two weeks of data.
        seasonal_enabled = len(ts_data) > 2 * 7
        m_period = 7 if seasonal_enabled else 1
        log_lines.append(f"✅ 季节性参数设定: m={m_period}, seasonal={seasonal_enabled}")

        if len(ts_data) >= 2 * period:
            # STL requires an odd seasonal smoother length; period + 1 is even
            # for period = 365, so use the next odd value above the period.
            seasonal_smoother = period + 1 if period % 2 == 0 else period + 2
            res = STL(ts_data, period=period, seasonal=seasonal_smoother).fit()
            # DecomposeResult.plot() creates and returns its own figure.
            fig = res.plot()
            fig.set_size_inches(12, 8)
            fig.suptitle('STL 分解图 (周期=365)', fontsize=16)
            fig.tight_layout(rect=[0, 0, 1, 0.96])
            add_figure(fig, 'stl_decomposition.png')
            log_lines.append("✅ STL分解图已生成。")
        else:
            log_lines.append("⚠️ 数据长度不足以进行年度季节性分解。")
        yield update_ui()

        # --- 6. Back-test to pick the training window size ---
        log_lines.append("\n### 5. 优化训练窗口大小")
        log_lines.append("⏳ **此步骤计算量大,可能需要5-15分钟,请耐心等待...**")
        yield update_ui()

        def evaluate_window_hybrid(window_size, time_series, d, m, seasonal):
            """Back-test one window size; return its one-step-ahead MAE.

            Walks over the last ``backtest_length`` points, fitting auto-ARIMA
            on the ``window_size`` points before each one. Returns ``inf``
            when the series is too short or every fit failed.
            """
            errors = []
            series_values = time_series.values
            backtest_length = 100  # shortened back-test to save time
            if len(series_values) <= window_size + backtest_length:
                return {'window_size': window_size, 'mae': np.inf}
            end_index = len(series_values)
            start_index = end_index - backtest_length
            for i in range(start_index, end_index):
                train_window = pd.Series(
                    series_values[i - window_size:i],
                    index=time_series.index[i - window_size:i],
                )
                test_point = series_values[i]
                use_seasonal = seasonal and (len(train_window) >= 2 * m)
                try:
                    model = pm.auto_arima(
                        train_window, d=d, seasonal=use_seasonal, m=m, max_p=2, max_q=2,
                        stepwise=True, trace=False, error_action='ignore', suppress_warnings=True,
                    )
                    # predict() may return a date-indexed Series; take the
                    # first value positionally rather than by label 0.
                    forecast = np.asarray(model.predict(n_periods=1))[0]
                    errors.append(test_point - forecast)
                except Exception:
                    # Best effort: a failed fit just skips this test point.
                    continue
            if not errors:
                return {'window_size': window_size, 'mae': np.inf}
            return {'window_size': window_size, 'mae': np.mean(np.abs(errors))}

        window_sizes_to_test = np.arange(70, 211, 14)  # coarse step to save time
        with Parallel(n_jobs=-1) as parallel:
            results = parallel(
                delayed(evaluate_window_hybrid)(ws, ts_data, d_order, m_period, seasonal_enabled)
                for ws in window_sizes_to_test
            )

        window_results_df = pd.DataFrame(results).sort_values('mae').set_index('window_size')
        if not window_results_df.empty and np.isfinite(window_results_df['mae'].min()):
            best_window_size = int(window_results_df['mae'].idxmin())
            log_lines.append(f"✅ **窗口优化完成!** 基于MAE,最佳训练窗口大小为: **{best_window_size}** 天。")
        else:
            best_window_size = 90
            log_lines.append(f"⚠️ 窗口优化失败,使用默认窗口大小: {best_window_size} 天。")

        # Plot in window-size order; the frame above is sorted by MAE, which
        # would make the connecting line zigzag across the x axis.
        window_plot_df = window_results_df.sort_index()
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(window_plot_df.index, window_plot_df['mae'], marker='o', label='MAE')
        ax.set_title('训练窗口大小对预测误差的影响')
        ax.set_xlabel('训练窗口天数')
        ax.set_ylabel('误差值')
        ax.legend()
        ax.grid(True)
        add_figure(fig, 'window_optimization.png')
        yield update_ui()

        # --- 7 & 8. Rolling one-step forecasts and evaluation ---
        log_lines.append("\n### 6. 动态滚动预测与评估")
        log_lines.append("⏳ **此步骤同样耗时,正在进行模型滚动预测...**")
        yield update_ui()

        split_point_roll = int(len(ts_data) * 0.8)
        test_rolling_target = ts_data.iloc[split_point_roll:]
        rolling_predictions = {}

        # Auto-SARIMA: refit on the trailing window before every test point.
        sarima_rolling_preds = []
        for i in range(len(test_rolling_target)):
            window_end = split_point_roll + i
            # Clamp at 0: a negative start would be read as "from the end"
            # and produce an empty or wrong training window.
            window_start = max(0, window_end - best_window_size)
            train_window = ts_data.iloc[window_start:window_end]
            try:
                model = pm.auto_arima(
                    train_window, d=d_order, m=m_period, seasonal=seasonal_enabled,
                    stepwise=True, trace=False, error_action='ignore', suppress_warnings=True,
                )
                sarima_rolling_preds.append(np.asarray(model.predict(n_periods=1))[0])
            except Exception:
                # A failed fit leaves a gap that is forward-filled below.
                sarima_rolling_preds.append(np.nan)
        rolling_predictions['Auto-SARIMA'] = pd.Series(
            sarima_rolling_preds, index=test_rolling_target.index
        ).ffill()
        log_lines.append("✅ Auto-SARIMA 滚动预测完成。")
        yield update_ui()

        # Prophet: simplified strategy, refit only every 14 test points.
        prophet_rolling_preds = []
        prophet_model = None
        for i, date in enumerate(test_rolling_target.index):
            if i % 14 == 0 or prophet_model is None:
                train_upto_date = ts_data.loc[:date - pd.Timedelta(days=1)]
                prophet_train_df = train_upto_date.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
                prophet_model = Prophet(
                    yearly_seasonality='auto',
                    weekly_seasonality=seasonal_enabled,
                    daily_seasonality=False,
                ).fit(prophet_train_df)
            forecast = prophet_model.predict(pd.DataFrame({'ds': [date]}))
            prophet_rolling_preds.append(forecast['yhat'].iloc[0])
        rolling_predictions['Prophet'] = pd.Series(prophet_rolling_preds, index=test_rolling_target.index)
        log_lines.append("✅ Prophet 滚动预测完成。")

        # Evaluation: rank the models by MAE on the rolling test set.
        rolling_evaluation_results = {
            name: calculate_metrics(test_rolling_target, preds)
            for name, preds in rolling_predictions.items()
        }
        rolling_evaluation_df = pd.DataFrame(rolling_evaluation_results).T.sort_values(by='MAE')
        best_rolling_model_name = rolling_evaluation_df.index[0]
        log_lines.append("\n**滚动预测性能对比:**")
        log_lines.append(f"```\n{rolling_evaluation_df.to_markdown()}\n```")
        log_lines.append(f"\n==> ✅ 最佳滚动预测模型是: **{best_rolling_model_name}**")
        yield update_ui()

        # --- 9. Rolling forecast chart ---
        log_lines.append("\n### 7. 生成结果图表")
        fig, ax = plt.subplots(figsize=(15, 8))
        ax.plot(ts_data, label='历史数据', color='gray', alpha=0.5)
        ax.plot(test_rolling_target, label='真实值 (测试集)', color='blue', linewidth=2)
        for model_name, preds in rolling_predictions.items():
            is_best = ' (最佳)' if model_name == best_rolling_model_name else ''
            ax.plot(preds, label=f'{model_name} 预测{is_best}', linestyle='--')
        ax.set_title('滚动预测结果对比')
        ax.legend()
        ax.grid(True)
        add_figure(fig, 'rolling_forecast.png')
        yield update_ui()

        # --- 10. Final forecast with the best model ---
        forecast_horizon = 90
        log_lines.append(f"\n### 8. 使用最佳模型 `{best_rolling_model_name}` 预测未来 {forecast_horizon} 天")

        if 'Auto-SARIMA' in best_rolling_model_name:
            final_train_data = ts_data.iloc[-best_window_size:]
            final_model = pm.auto_arima(
                final_train_data, d=d_order, m=m_period, seasonal=seasonal_enabled,
                stepwise=True, trace=False, error_action='ignore', suppress_warnings=True,
            )
            final_forecast_values, conf_int = final_model.predict(
                n_periods=forecast_horizon, return_conf_int=True
            )
        else:  # Prophet
            final_train_data = ts_data
            final_prophet_train_df = final_train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
            final_model = Prophet(
                yearly_seasonality='auto',
                weekly_seasonality=seasonal_enabled,
                daily_seasonality=False,
            ).fit(final_prophet_train_df)
            future_df = final_model.make_future_dataframe(periods=forecast_horizon, freq='D')
            forecast_obj = final_model.predict(future_df)
            final_forecast_values = forecast_obj['yhat'].iloc[-forecast_horizon:].values
            conf_int = np.column_stack((
                forecast_obj['yhat_lower'].iloc[-forecast_horizon:].values,
                forecast_obj['yhat_upper'].iloc[-forecast_horizon:].values,
            ))

        # Strip any index the model attached to its output: pd.Series(series,
        # index=...) would reindex by label and could turn every value into NaN.
        final_forecast_values = np.asarray(final_forecast_values, dtype=float)
        conf_int = np.asarray(conf_int, dtype=float)

        future_dates = pd.date_range(start=ts_data.index[-1] + pd.Timedelta(days=1), periods=forecast_horizon)
        final_forecast_series = pd.Series(final_forecast_values, index=future_dates)

        # Final forecast chart.
        fig, ax = plt.subplots(figsize=(15, 8))
        ax.plot(ts_data.tail(365), label='近期历史数据', color='blue')
        ax.plot(final_forecast_series, label=f'未来 {forecast_horizon} 天预测', color='red', linestyle='--')
        ax.fill_between(future_dates, conf_int[:, 0], conf_int[:, 1], color='red', alpha=0.2, label='95% 置信区间')
        ax.set_title(f'最终未来用量预测 (基于 {best_rolling_model_name})')
        ax.legend()
        ax.grid(True)
        add_figure(fig, 'final_forecast.png')

        # Final report, assembled line by line so the Markdown has no stray
        # leading indentation.
        report_lines = [
            "",
            "# 药品用量预测分析报告",
            "",
            "## 1. 数据概览",
            f"- **数据时间范围**: {ts_data.index.min().strftime('%Y-%m-%d')} to {ts_data.index.max().strftime('%Y-%m-%d')}",
            f"- **总数据点**: {len(ts_data)}",
            f"- **平均用量**: {ts_data.mean():.2f}",
            "",
            "## 2. 分析与建模参数",
            f"- **平稳性差分阶数 (d)**: {d_order}",
            f"- **季节性周期 (m)**: {m_period}",
            f"- **最佳训练窗口**: {best_window_size} 天",
            "",
            "## 3. 模型评估 (基于动态滚动预测)",
            "通过在历史数据上进行滚动预测,我们能更真实地评估模型在实际应用中的表现。",
            "",
            rolling_evaluation_df.to_markdown(),
            "",
            "## 4. 最终结论与未来预测",
            f"- **最佳模型**: **{best_rolling_model_name}** 被选为最终预测模型,因为它在滚动预测中表现最佳(MAE最低)。",
            f"- **未来预测**: 已使用 `{best_rolling_model_name}` 模型对未来 **{forecast_horizon}** 天的用量进行预测。",
            "- **预测摘要**:",
            f"    - 未来一周平均日用量: **{final_forecast_series.head(7).mean():.2f}**",
            f"    - 未来一月平均日用量: **{final_forecast_series.head(30).mean():.2f}**",
        ]
        final_report_text = "\n".join(report_lines) + "\n"

        os.makedirs(OUTPUT_DIR, exist_ok=True)
        report_file_path = os.path.join(OUTPUT_DIR, 'final_analysis_report.txt')
        with open(report_file_path, 'w', encoding='utf-8') as f:
            f.write(final_report_text)

        log_lines.append("\n## 🎉 全部分析流程完成!请查看最终报告和图表。")
        yield update_ui()

    except Exception as e:
        # Top-level boundary: show the failure in the UI log instead of
        # letting the generator die without any output.
        import traceback
        log_lines.append(f"\n\n❌ **分析过程中断,出现错误:**\n`{str(e)}`")
        log_lines.append(f"\n**Traceback:**\n```{traceback.format_exc()}```")
        yield update_ui()
326
+
327
+
328
+ # ======================== (D) Gradio 界面构建 ========================
329
+
330
# Page layout: an intro banner, one start button, three result tabs, and the
# wiring that streams run_full_analysis() output into those tabs.
_HIDE_FOOTER_CSS = "footer {display: none !important}"

with gr.Blocks(theme=gr.themes.Soft(), css=_HIDE_FOOTER_CSS) as demo:
    gr.Markdown(
        """
        # 📈 智能时序预测与分析平台 (动态回测版)
        欢迎使用!请确保您的数据文件 `gmqrkl.xlsx` 已上传至本 Space 的文件库中。
        然后,点击下方按钮,启动包含 **窗口优化** 和 **动态滚动预测** 的完整分析流程。
        **注意:完整流程计算量大,可能需要10-20分钟。请耐心等待,并观察下方日志区的实时进度。**
        """
    )

    start_button = gr.Button("🚀 点击这里,开始完整分析", variant="primary")

    gr.Markdown("---")

    with gr.Tabs():
        # Tab 0: charts produced along the way.
        with gr.TabItem("📊 可视化图表", id=0):
            gallery_output = gr.Gallery(
                label="分析图表",
                elem_id="gallery",
                columns=[1],
                height="auto",
            )
        # Tab 1: the log that grows with each yielded update.
        with gr.TabItem("📝 实时分析日志", id=1):
            log_output = gr.Markdown("点击按钮后,分析日志将实时显示在这里...")
        # Tab 2: final report text plus the downloadable file.
        with gr.TabItem("📋 最终报告与下载", id=2):
            final_report_output = gr.Markdown("分析完成后,最终报告将显示在这里。")
            download_output = gr.File(label="下载报告文件")

    # Order must match the tuple yielded by run_full_analysis:
    # (log, gallery, final report, download file).
    analysis_outputs = [log_output, gallery_output, final_report_output, download_output]
    start_button.click(fn=run_full_analysis, inputs=None, outputs=analysis_outputs)

    gr.Markdown("<p style='text-align: center; font-size: 12px; color: grey;'>Powered by Gradio and Hugging Face Spaces.</p>")

if __name__ == "__main__":
    demo.launch()