Spaces:
Running
Running
Commit
·
999c140
1
Parent(s):
d89db10
fix prophet form again
Browse files
app.py
CHANGED
@@ -232,7 +232,31 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
|
|
232 |
)
|
233 |
all_predictions = np.vstack([last_original, predictions_original[1:]])
|
234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
elif model_choice == "Prophet":
|
|
|
|
|
|
|
236 |
prophet_df = df.reset_index()[['Date', 'Close']]
|
237 |
prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
|
238 |
|
@@ -241,10 +265,9 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
|
|
241 |
|
242 |
future = model.make_future_dataframe(periods=5)
|
243 |
forecast = model.predict(future)
|
244 |
-
all_predictions = forecast[['ds', 'yhat']].tail(6).values
|
245 |
|
246 |
# 取出日期和預測結果
|
247 |
-
date_labels = [
|
248 |
predictions = forecast['yhat'].tail(6).values
|
249 |
|
250 |
# 繪圖
|
@@ -255,7 +278,6 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
|
|
255 |
ax.set_ylabel('股價', labelpad=10)
|
256 |
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
|
257 |
ax.grid(True, linestyle='--', alpha=0.7)
|
258 |
-
|
259 |
plt.tight_layout()
|
260 |
return gr.update(value=fig), "預測成功"
|
261 |
|
|
|
232 |
)
|
233 |
all_predictions = np.vstack([last_original, predictions_original[1:]])
|
234 |
|
235 |
+
# 創建日期索引
|
236 |
+
dates = [datetime.now() + timedelta(days=i) for i in range(6)]
|
237 |
+
date_labels = [d.strftime('%m/%d') for d in dates]
|
238 |
+
|
239 |
+
# 繪圖
|
240 |
+
fig, ax = plt.subplots(figsize=(14, 7))
|
241 |
+
for i, feature in enumerate(selected_features):
|
242 |
+
ax.plot(date_labels, all_predictions[:, i], label=f'預測{feature}', marker='o', linewidth=2)
|
243 |
+
for j, value in enumerate(all_predictions[:, i]):
|
244 |
+
ax.annotate(f'{value:.2f}', (date_labels[j], value),
|
245 |
+
textcoords="offset points", xytext=(0,10),
|
246 |
+
ha='center', va='bottom')
|
247 |
+
|
248 |
+
ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
|
249 |
+
ax.set_xlabel('日期', labelpad=10)
|
250 |
+
ax.set_ylabel('股價', labelpad=10)
|
251 |
+
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
|
252 |
+
ax.grid(True, linestyle='--', alpha=0.7)
|
253 |
+
plt.tight_layout()
|
254 |
+
return gr.update(value=fig), "預測成功"
|
255 |
+
|
256 |
elif model_choice == "Prophet":
|
257 |
+
if 'Close' not in selected_features:
|
258 |
+
return gr.update(value=None), "Prophet 模型僅支持 'Close' 特徵"
|
259 |
+
|
260 |
prophet_df = df.reset_index()[['Date', 'Close']]
|
261 |
prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
|
262 |
|
|
|
265 |
|
266 |
future = model.make_future_dataframe(periods=5)
|
267 |
forecast = model.predict(future)
|
|
|
268 |
|
269 |
# 取出日期和預測結果
|
270 |
+
date_labels = forecast['ds'].tail(6).dt.strftime('%m/%d').tolist()
|
271 |
predictions = forecast['yhat'].tail(6).values
|
272 |
|
273 |
# 繪圖
|
|
|
278 |
ax.set_ylabel('股價', labelpad=10)
|
279 |
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
|
280 |
ax.grid(True, linestyle='--', alpha=0.7)
|
|
|
281 |
plt.tight_layout()
|
282 |
return gr.update(value=fig), "預測成功"
|
283 |
|