tbdavid2019 commited on
Commit
999c140
·
1 Parent(s): d89db10

fix prophet form again

Browse files
Files changed (1) hide show
  1. app.py +25 -3
app.py CHANGED
@@ -232,7 +232,31 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
232
  )
233
  all_predictions = np.vstack([last_original, predictions_original[1:]])
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  elif model_choice == "Prophet":
 
 
 
236
  prophet_df = df.reset_index()[['Date', 'Close']]
237
  prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
238
 
@@ -241,10 +265,9 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
241
 
242
  future = model.make_future_dataframe(periods=5)
243
  forecast = model.predict(future)
244
- all_predictions = forecast[['ds', 'yhat']].tail(6).values
245
 
246
  # 取出日期和預測結果
247
- date_labels = [d.strftime('%m/%d') for d in forecast['ds'].tail(6)]
248
  predictions = forecast['yhat'].tail(6).values
249
 
250
  # 繪圖
@@ -255,7 +278,6 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
255
  ax.set_ylabel('股價', labelpad=10)
256
  ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
257
  ax.grid(True, linestyle='--', alpha=0.7)
258
-
259
  plt.tight_layout()
260
  return gr.update(value=fig), "預測成功"
261
 
 
232
  )
233
  all_predictions = np.vstack([last_original, predictions_original[1:]])
234
 
235
+ # 創建日期索引
236
+ dates = [datetime.now() + timedelta(days=i) for i in range(6)]
237
+ date_labels = [d.strftime('%m/%d') for d in dates]
238
+
239
+ # 繪圖
240
+ fig, ax = plt.subplots(figsize=(14, 7))
241
+ for i, feature in enumerate(selected_features):
242
+ ax.plot(date_labels, all_predictions[:, i], label=f'預測{feature}', marker='o', linewidth=2)
243
+ for j, value in enumerate(all_predictions[:, i]):
244
+ ax.annotate(f'{value:.2f}', (date_labels[j], value),
245
+ textcoords="offset points", xytext=(0,10),
246
+ ha='center', va='bottom')
247
+
248
+ ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
249
+ ax.set_xlabel('日期', labelpad=10)
250
+ ax.set_ylabel('股價', labelpad=10)
251
+ ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
252
+ ax.grid(True, linestyle='--', alpha=0.7)
253
+ plt.tight_layout()
254
+ return gr.update(value=fig), "預測成功"
255
+
256
  elif model_choice == "Prophet":
257
+ if 'Close' not in selected_features:
258
+ return gr.update(value=None), "Prophet 模型僅支持 'Close' 特徵"
259
+
260
  prophet_df = df.reset_index()[['Date', 'Close']]
261
  prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
262
 
 
265
 
266
  future = model.make_future_dataframe(periods=5)
267
  forecast = model.predict(future)
 
268
 
269
  # 取出日期和預測結果
270
+ date_labels = forecast['ds'].tail(6).dt.strftime('%m/%d').tolist()
271
  predictions = forecast['yhat'].tail(6).values
272
 
273
  # 繪圖
 
278
  ax.set_ylabel('股價', labelpad=10)
279
  ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
280
  ax.grid(True, linestyle='--', alpha=0.7)
 
281
  plt.tight_layout()
282
  return gr.update(value=fig), "預測成功"
283