Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,8 @@ import time
|
|
5 |
import warnings
|
6 |
import os
|
7 |
import logging
|
|
|
|
|
8 |
|
9 |
# 数据分析与建模
|
10 |
from scipy import stats
|
@@ -40,7 +42,6 @@ if not os.path.exists(OUTPUT_DIR):
|
|
40 |
# ======================== (B) 辅助函数 ========================
|
41 |
|
42 |
def calculate_metrics(actual, predicted):
|
43 |
-
# (此函数来自您的原始代码)
|
44 |
metrics_df = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna()
|
45 |
if metrics_df.empty:
|
46 |
return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}
|
@@ -60,17 +61,24 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
60 |
"""
|
61 |
# --- 1. 初始化 ---
|
62 |
log_lines = ["## 🚀 数据分析流程已启动..."]
|
63 |
-
|
64 |
final_report_text = ""
|
65 |
report_file_path = None
|
66 |
|
67 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def update_ui(new_log_line=None):
|
69 |
if new_log_line:
|
70 |
log_lines.append(new_log_line)
|
71 |
-
# 返回当前所有输出的状态
|
72 |
# [log, gallery, final_report, download_button]
|
73 |
-
return "\n\n".join(log_lines),
|
74 |
|
75 |
yield update_ui() # 立即显示启动信息
|
76 |
|
@@ -89,95 +97,67 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
89 |
ts_data = df['Value']
|
90 |
yield update_ui()
|
91 |
|
92 |
-
# --- 3.
|
93 |
log_lines.append("### 2. 平稳性检验与差分")
|
94 |
-
# (代码与您提供的一致)
|
95 |
-
diff_order = 0
|
96 |
current_data = ts_data.dropna()
|
97 |
adf_result = adfuller(current_data)
|
98 |
-
|
99 |
-
|
100 |
-
msg = f"✅ 序列在 d=0 阶差分后达到平稳 (p={p_value:.4f})。"
|
101 |
-
log_lines.append(msg)
|
102 |
d_order = 0
|
103 |
else:
|
104 |
current_data_diff = current_data.diff().dropna()
|
105 |
adf_result_diff = adfuller(current_data_diff)
|
106 |
-
|
107 |
-
if
|
108 |
-
|
109 |
-
log_lines.append(msg)
|
110 |
-
d_order = 1
|
111 |
-
current_data = current_data_diff
|
112 |
else:
|
113 |
-
|
114 |
-
|
115 |
-
d_order = 1
|
116 |
-
current_data = current_data_diff
|
117 |
ts_stationary = current_data
|
118 |
yield update_ui()
|
119 |
|
120 |
# --- 4. 白噪声检验 ---
|
121 |
log_lines.append("### 3. 白噪声检验")
|
122 |
-
# (代码与您提供的一致)
|
123 |
lags = min(10, len(ts_stationary) // 5)
|
124 |
lb_test_result = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)
|
125 |
-
|
126 |
-
|
127 |
-
log_lines.append(f"⚠️ 序列可能是白噪声(p-value = {lb_p_value:.4f}),模型可能无效。")
|
128 |
else:
|
129 |
-
log_lines.append(f"✅ 通过白噪声检验 (p-value = {
|
130 |
-
yield update_ui()
|
131 |
-
|
132 |
# --- 5. 季节性检验与分解 ---
|
133 |
log_lines.append("\n### 4. 季节性检验与STL分解")
|
134 |
period = 365
|
135 |
-
seasonal_enabled = len(ts_data) > 2 * 14 #
|
136 |
m_period = 7 if seasonal_enabled else 1
|
137 |
log_lines.append(f"✅ 季节性参数设定: m={m_period}, seasonal={seasonal_enabled}")
|
138 |
|
139 |
if len(ts_data) >= 2 * period:
|
140 |
-
|
141 |
-
seasonal_period_for_stl = period if period % 2 != 0 else period + 1
|
142 |
log_lines.append(f"✅ 准备进行STL分解,周期(period)={period},季节平滑窗口(seasonal)={seasonal_period_for_stl}。")
|
143 |
-
yield update_ui()
|
144 |
-
|
145 |
stl = STL(ts_data, period=period, seasonal=seasonal_period_for_stl)
|
146 |
res = stl.fit()
|
147 |
-
|
148 |
-
#fig, axes = plt.subplots(4, 1, figsize=(12, 8), sharex=True)
|
149 |
-
# 使用 res.plot() 可以自动处理标签
|
150 |
-
#res.plot(axes=axes)
|
151 |
-
#fig.suptitle(f'STL 分解图 (周期={period})', fontsize=16)
|
152 |
-
#plt.tight_layout(rect=[0, 0, 1, 0.96])
|
153 |
-
|
154 |
-
|
155 |
-
# 【关键修复】直接调用 res.plot(),它会返回一个 Figure 对象
|
156 |
fig = res.plot()
|
157 |
-
|
158 |
-
# 调整 Figure 的大小和标题
|
159 |
fig.set_size_inches(12, 8)
|
160 |
-
fig.suptitle(f'STL 分解图 (周期={period})', fontsize=16, y=0.98)
|
161 |
-
plt.tight_layout()
|
162 |
-
|
163 |
-
figures.append(fig)
|
164 |
log_lines.append("✅ STL分解图已生成。")
|
165 |
else:
|
166 |
log_lines.append("⚠️ 数据长度不足以进行年度季节性分解。")
|
167 |
yield update_ui()
|
168 |
|
169 |
-
|
170 |
-
|
171 |
# --- 6. 混合策略回测优化窗口大小 ---
|
172 |
log_lines.append("\n### 5. 优化训练窗口大小")
|
173 |
log_lines.append("⏳ **此步骤计算量大,可能需要5-15分钟,请耐心等待...**")
|
174 |
yield update_ui()
|
175 |
|
176 |
def evaluate_window_hybrid(window_size, time_series, d, m, seasonal):
|
177 |
-
# (此函数来自您的代码)
|
178 |
errors = []
|
179 |
series_values = time_series.values
|
180 |
-
backtest_length = 100
|
181 |
if len(series_values) <= window_size + backtest_length: return {'window_size': window_size, 'mae': np.inf}
|
182 |
end_index = len(series_values)
|
183 |
start_index = end_index - backtest_length
|
@@ -194,7 +174,7 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
194 |
if not errors: return {'window_size': window_size, 'mae': np.inf}
|
195 |
return {'window_size': window_size, 'mae': np.mean(np.abs(errors))}
|
196 |
|
197 |
-
window_sizes_to_test = np.arange(70, 211, 14)
|
198 |
with Parallel(n_jobs=-1) as parallel:
|
199 |
results = parallel(
|
200 |
delayed(evaluate_window_hybrid)(ws, ts_data, d_order, m_period, seasonal_enabled) for ws in window_sizes_to_test
|
@@ -212,7 +192,7 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
212 |
ax.plot(window_results_df.index, window_results_df['mae'], marker='o', label='MAE')
|
213 |
ax.set_title('训练窗口大小对预测误差的影响')
|
214 |
ax.set_xlabel('训练窗口天数'); ax.set_ylabel('误差值'); ax.legend(); ax.grid(True)
|
215 |
-
|
216 |
yield update_ui()
|
217 |
|
218 |
# --- 7 & 8. 动态滚动预测与评估 ---
|
@@ -226,23 +206,23 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
226 |
|
227 |
# SARIMA 滚动
|
228 |
sarima_rolling_preds = []
|
229 |
-
for i in range(len(test_rolling_target)):
|
230 |
train_window = ts_data.iloc[split_point_roll + i - best_window_size : split_point_roll + i]
|
231 |
try:
|
232 |
model = pm.auto_arima(train_window, d=d_order, m=m_period, seasonal=seasonal_enabled,
|
233 |
stepwise=True, trace=False, error_action='ignore', suppress_warnings=True)
|
234 |
sarima_rolling_preds.append(model.predict(n_periods=1)[0])
|
235 |
except:
|
236 |
-
sarima_rolling_preds.append(np.nan)
|
237 |
rolling_predictions['Auto-SARIMA'] = pd.Series(sarima_rolling_preds, index=test_rolling_target.index).ffill()
|
238 |
log_lines.append("✅ Auto-SARIMA 滚动预测完成。")
|
239 |
yield update_ui()
|
240 |
|
241 |
-
# Prophet 滚动
|
242 |
prophet_rolling_preds = []
|
243 |
prophet_model = None
|
244 |
-
for i, (date, value) in enumerate(test_rolling_target.items()):
|
245 |
-
if i % 14 == 0 or prophet_model is None:
|
246 |
train_upto_date = ts_data.loc[:date - pd.Timedelta(days=1)]
|
247 |
prophet_train_df = train_upto_date.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
|
248 |
prophet_model = Prophet(yearly_seasonality='auto', weekly_seasonality=seasonal_enabled, daily_seasonality=False).fit(prophet_train_df)
|
@@ -268,9 +248,9 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
268 |
ax.plot(test_rolling_target, label='真实值 (测试集)', color='blue', linewidth=2)
|
269 |
for model_name, preds in rolling_predictions.items():
|
270 |
is_best = ' (最佳)' if model_name == best_rolling_model_name else ''
|
271 |
-
ax.plot(preds, label=f'{model_name} 预测{is_best}', linestyle='--')
|
272 |
ax.set_title('滚动预测结果对比'); ax.legend(); ax.grid(True)
|
273 |
-
|
274 |
yield update_ui()
|
275 |
|
276 |
# --- 10. 最终未来预测 ---
|
@@ -301,7 +281,7 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
301 |
ax.plot(final_forecast_series, label=f'未来 {forecast_horizon} 天预测', color='red', linestyle='--')
|
302 |
ax.fill_between(future_dates, conf_int[:, 0], conf_int[:, 1], color='red', alpha=0.2, label='95% 置信区间')
|
303 |
ax.set_title(f'最终未来用量预测 (基于 {best_rolling_model_name})'); ax.legend(); ax.grid(True)
|
304 |
-
|
305 |
|
306 |
# 生成最终报告
|
307 |
final_report_text = f"""
|
@@ -328,7 +308,8 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
328 |
- **预测摘要**:
|
329 |
- 未来一周平均日用量: **{final_forecast_series.head(7).mean():.2f}**
|
330 |
- 未来一月平均日用量: **{final_forecast_series.head(30).mean():.2f}**
|
331 |
-
"""
|
|
|
332 |
report_file_path = os.path.join(OUTPUT_DIR, 'final_analysis_report.txt')
|
333 |
with open(report_file_path, 'w', encoding='utf-8') as f:
|
334 |
f.write(final_report_text)
|
@@ -338,7 +319,6 @@ def run_full_analysis(progress=gr.Progress(track_tqdm=True)):
|
|
338 |
|
339 |
except Exception as e:
|
340 |
log_lines.append(f"\n\n❌ **分析过程中断,出现错误:**\n`{str(e)}`")
|
341 |
-
import traceback
|
342 |
log_lines.append(f"\n**Traceback:**\n```{traceback.format_exc()}```")
|
343 |
yield update_ui()
|
344 |
|
@@ -361,7 +341,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
|
|
361 |
|
362 |
with gr.Tabs():
|
363 |
with gr.TabItem("📊 可视化图表", id=0):
|
364 |
-
gallery_output = gr.Gallery(label="分析图表", elem_id="gallery", columns=[1], height="auto")
|
365 |
with gr.TabItem("📝 实时分析日志", id=1):
|
366 |
log_output = gr.Markdown("点击按钮后,分析日志将实时显示在这里...")
|
367 |
with gr.TabItem("📋 最终报告与下载", id=2):
|
|
|
5 |
import warnings
|
6 |
import os
|
7 |
import logging
|
8 |
+
import tempfile
|
9 |
+
import traceback
|
10 |
|
11 |
# 数据分析与建模
|
12 |
from scipy import stats
|
|
|
42 |
# ======================== (B) 辅助函数 ========================
|
43 |
|
44 |
def calculate_metrics(actual, predicted):
|
|
|
45 |
metrics_df = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna()
|
46 |
if metrics_df.empty:
|
47 |
return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}
|
|
|
61 |
"""
|
62 |
# --- 1. 初始化 ---
|
63 |
log_lines = ["## 🚀 数据分析流程已启动..."]
|
64 |
+
figure_paths = []
|
65 |
final_report_text = ""
|
66 |
report_file_path = None
|
67 |
|
68 |
+
# 辅助函数,用于将Matplotlib Figure保存为临时图片文件并返回路径
|
69 |
+
def save_fig_to_path(fig):
|
70 |
+
# 使用 NamedTemporaryFile 来创建一个不会被立即删除的临时文件
|
71 |
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
|
72 |
+
fig.savefig(tmpfile.name)
|
73 |
+
figure_paths.append(tmpfile.name)
|
74 |
+
plt.close(fig) # 操作完成后关闭图形,释放内存
|
75 |
+
|
76 |
+
# 辅助函数,用于更新UI状态
|
77 |
def update_ui(new_log_line=None):
|
78 |
if new_log_line:
|
79 |
log_lines.append(new_log_line)
|
|
|
80 |
# [log, gallery, final_report, download_button]
|
81 |
+
return "\n\n".join(log_lines), figure_paths, final_report_text, report_file_path
|
82 |
|
83 |
yield update_ui() # 立即显示启动信息
|
84 |
|
|
|
97 |
ts_data = df['Value']
|
98 |
yield update_ui()
|
99 |
|
100 |
+
# --- 3. 平稳性检验与差分 ---
|
101 |
log_lines.append("### 2. 平稳性检验与差分")
|
|
|
|
|
102 |
current_data = ts_data.dropna()
|
103 |
adf_result = adfuller(current_data)
|
104 |
+
if adf_result[1] < 0.05:
|
105 |
+
log_lines.append(f"✅ 序列在 d=0 阶差分后达到平稳 (p={adf_result[1]:.4f})。")
|
|
|
|
|
106 |
d_order = 0
|
107 |
else:
|
108 |
current_data_diff = current_data.diff().dropna()
|
109 |
adf_result_diff = adfuller(current_data_diff)
|
110 |
+
d_order = 1
|
111 |
+
if adf_result_diff[1] < 0.05:
|
112 |
+
log_lines.append(f"✅ 序列在 d=1 阶差分后达到平稳 (p={adf_result_diff[1]:.4f})。")
|
|
|
|
|
|
|
113 |
else:
|
114 |
+
log_lines.append(f"⚠️ 1阶差分后仍未平稳 (p={adf_result_diff[1]:.4f}),将使用 d=1 继续分析。")
|
115 |
+
current_data = current_data_diff
|
|
|
|
|
116 |
ts_stationary = current_data
|
117 |
yield update_ui()
|
118 |
|
119 |
# --- 4. 白噪声检验 ---
|
120 |
log_lines.append("### 3. 白噪声检验")
|
|
|
121 |
lags = min(10, len(ts_stationary) // 5)
|
122 |
lb_test_result = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)
|
123 |
+
if lb_test_result['lb_pvalue'].iloc[0] > 0.05:
|
124 |
+
log_lines.append(f"⚠️ 序列可能是白噪声(p-value = {lb_test_result['lb_pvalue'].iloc[0]:.4f}),模型可能无效。")
|
|
|
125 |
else:
|
126 |
+
log_lines.append(f"✅ 通过白噪声检验 (p-value = {lb_test_result['lb_pvalue'].iloc[0]:.4f}),可以进行后续建模。")
|
127 |
+
yield update_ui()
|
128 |
+
|
129 |
# --- 5. 季节性检验与分解 ---
|
130 |
log_lines.append("\n### 4. 季节性检验与STL分解")
|
131 |
period = 365
|
132 |
+
seasonal_enabled = len(ts_data) > 2 * 14 # 数据多于两周则开启周季节性
|
133 |
m_period = 7 if seasonal_enabled else 1
|
134 |
log_lines.append(f"✅ 季节性参数设定: m={m_period}, seasonal={seasonal_enabled}")
|
135 |
|
136 |
if len(ts_data) >= 2 * period:
|
137 |
+
seasonal_period_for_stl = period if period % 2 != 0 else period + 1
|
|
|
138 |
log_lines.append(f"✅ 准备进行STL分解,周期(period)={period},季节平滑窗口(seasonal)={seasonal_period_for_stl}。")
|
139 |
+
yield update_ui()
|
|
|
140 |
stl = STL(ts_data, period=period, seasonal=seasonal_period_for_stl)
|
141 |
res = stl.fit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
fig = res.plot()
|
|
|
|
|
143 |
fig.set_size_inches(12, 8)
|
144 |
+
fig.suptitle(f'STL 分解图 (周期={period})', fontsize=16, y=0.98)
|
145 |
+
plt.tight_layout()
|
146 |
+
save_fig_to_path(fig)
|
|
|
147 |
log_lines.append("✅ STL分解图已生成。")
|
148 |
else:
|
149 |
log_lines.append("⚠️ 数据长度不足以进行年度季节性分解。")
|
150 |
yield update_ui()
|
151 |
|
|
|
|
|
152 |
# --- 6. 混合策略回测优化窗口大小 ---
|
153 |
log_lines.append("\n### 5. 优化训练窗口大小")
|
154 |
log_lines.append("⏳ **此步骤计算量大,可能需要5-15分钟,请耐心等待...**")
|
155 |
yield update_ui()
|
156 |
|
157 |
def evaluate_window_hybrid(window_size, time_series, d, m, seasonal):
|
|
|
158 |
errors = []
|
159 |
series_values = time_series.values
|
160 |
+
backtest_length = 100
|
161 |
if len(series_values) <= window_size + backtest_length: return {'window_size': window_size, 'mae': np.inf}
|
162 |
end_index = len(series_values)
|
163 |
start_index = end_index - backtest_length
|
|
|
174 |
if not errors: return {'window_size': window_size, 'mae': np.inf}
|
175 |
return {'window_size': window_size, 'mae': np.mean(np.abs(errors))}
|
176 |
|
177 |
+
window_sizes_to_test = np.arange(70, 211, 14)
|
178 |
with Parallel(n_jobs=-1) as parallel:
|
179 |
results = parallel(
|
180 |
delayed(evaluate_window_hybrid)(ws, ts_data, d_order, m_period, seasonal_enabled) for ws in window_sizes_to_test
|
|
|
192 |
ax.plot(window_results_df.index, window_results_df['mae'], marker='o', label='MAE')
|
193 |
ax.set_title('训练窗口大小对预测误差的影响')
|
194 |
ax.set_xlabel('训练窗口天数'); ax.set_ylabel('误差值'); ax.legend(); ax.grid(True)
|
195 |
+
save_fig_to_path(fig)
|
196 |
yield update_ui()
|
197 |
|
198 |
# --- 7 & 8. 动态滚动预测与评估 ---
|
|
|
206 |
|
207 |
# SARIMA 滚动
|
208 |
sarima_rolling_preds = []
|
209 |
+
for i in progress.tqdm(range(len(test_rolling_target)), desc="SARIMA Rolling Forecast"):
|
210 |
train_window = ts_data.iloc[split_point_roll + i - best_window_size : split_point_roll + i]
|
211 |
try:
|
212 |
model = pm.auto_arima(train_window, d=d_order, m=m_period, seasonal=seasonal_enabled,
|
213 |
stepwise=True, trace=False, error_action='ignore', suppress_warnings=True)
|
214 |
sarima_rolling_preds.append(model.predict(n_periods=1)[0])
|
215 |
except:
|
216 |
+
sarima_rolling_preds.append(sarima_rolling_preds[-1] if sarima_rolling_preds else np.nan)
|
217 |
rolling_predictions['Auto-SARIMA'] = pd.Series(sarima_rolling_preds, index=test_rolling_target.index).ffill()
|
218 |
log_lines.append("✅ Auto-SARIMA 滚动预测完成。")
|
219 |
yield update_ui()
|
220 |
|
221 |
+
# Prophet 滚动
|
222 |
prophet_rolling_preds = []
|
223 |
prophet_model = None
|
224 |
+
for i, (date, value) in enumerate(progress.tqdm(test_rolling_target.items(), desc="Prophet Rolling Forecast")):
|
225 |
+
if i % 14 == 0 or prophet_model is None:
|
226 |
train_upto_date = ts_data.loc[:date - pd.Timedelta(days=1)]
|
227 |
prophet_train_df = train_upto_date.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
|
228 |
prophet_model = Prophet(yearly_seasonality='auto', weekly_seasonality=seasonal_enabled, daily_seasonality=False).fit(prophet_train_df)
|
|
|
248 |
ax.plot(test_rolling_target, label='真实值 (测试集)', color='blue', linewidth=2)
|
249 |
for model_name, preds in rolling_predictions.items():
|
250 |
is_best = ' (最佳)' if model_name == best_rolling_model_name else ''
|
251 |
+
ax.plot(preds.dropna(), label=f'{model_name} 预测{is_best}', linestyle='--')
|
252 |
ax.set_title('滚动预测结果对比'); ax.legend(); ax.grid(True)
|
253 |
+
save_fig_to_path(fig)
|
254 |
yield update_ui()
|
255 |
|
256 |
# --- 10. 最终未来预测 ---
|
|
|
281 |
ax.plot(final_forecast_series, label=f'未来 {forecast_horizon} 天预测', color='red', linestyle='--')
|
282 |
ax.fill_between(future_dates, conf_int[:, 0], conf_int[:, 1], color='red', alpha=0.2, label='95% 置信区间')
|
283 |
ax.set_title(f'最终未来用量预测 (基于 {best_rolling_model_name})'); ax.legend(); ax.grid(True)
|
284 |
+
save_fig_to_path(fig)
|
285 |
|
286 |
# 生成最终报告
|
287 |
final_report_text = f"""
|
|
|
308 |
- **预测摘要**:
|
309 |
- 未来一周平均日用量: **{final_forecast_series.head(7).mean():.2f}**
|
310 |
- 未来一月平均日用量: **{final_forecast_series.head(30).mean():.2f}**
|
311 |
+
""".strip()
|
312 |
+
|
313 |
report_file_path = os.path.join(OUTPUT_DIR, 'final_analysis_report.txt')
|
314 |
with open(report_file_path, 'w', encoding='utf-8') as f:
|
315 |
f.write(final_report_text)
|
|
|
319 |
|
320 |
except Exception as e:
|
321 |
log_lines.append(f"\n\n❌ **分析过程中断,出现错误:**\n`{str(e)}`")
|
|
|
322 |
log_lines.append(f"\n**Traceback:**\n```{traceback.format_exc()}```")
|
323 |
yield update_ui()
|
324 |
|
|
|
341 |
|
342 |
with gr.Tabs():
|
343 |
with gr.TabItem("📊 可视化图表", id=0):
|
344 |
+
gallery_output = gr.Gallery(label="分析图表", elem_id="gallery", columns=[1], height="auto", object_fit="contain")
|
345 |
with gr.TabItem("📝 实时分析日志", id=1):
|
346 |
log_output = gr.Markdown("点击按钮后,分析日志将实时显示在这里...")
|
347 |
with gr.TabItem("📋 最终报告与下载", id=2):
|