tbdavid2019 commited on
Commit
d89db10
·
1 Parent(s): fb219db

fix prophet

Browse files
Files changed (1) hide show
  1. app.py +17 -18
app.py CHANGED
@@ -16,7 +16,7 @@ import os
16
  import yfinance as yf
17
  import logging
18
  from datetime import datetime, timedelta
19
-
20
 
21
  # 設置日誌
22
  logging.basicConfig(level=logging.INFO,
@@ -233,7 +233,6 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
233
  all_predictions = np.vstack([last_original, predictions_original[1:]])
234
 
235
  elif model_choice == "Prophet":
236
- from prophet import Prophet
237
  prophet_df = df.reset_index()[['Date', 'Close']]
238
  prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
239
 
@@ -243,26 +242,26 @@ def predict_stock(category, stock, stock_item, period, selected_features, model_
243
  future = model.make_future_dataframe(periods=5)
244
  forecast = model.predict(future)
245
  all_predictions = forecast[['ds', 'yhat']].tail(6).values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  else:
248
  return gr.update(value=None), "未知的模型選擇"
249
 
250
- # 創建日期索引
251
- dates = [datetime.now() + timedelta(days=i) for i in range(len(all_predictions))]
252
- date_labels = [d.strftime('%m/%d') for d in dates]
253
-
254
- # 繪圖
255
- fig, ax = plt.subplots(figsize=(14, 7))
256
- ax.plot(date_labels, all_predictions, label="預測股價", marker='o', color='#FF9999', linewidth=2)
257
- ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
258
- ax.set_xlabel('日期', labelpad=10)
259
- ax.set_ylabel('股價', labelpad=10)
260
- ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
261
- ax.grid(True, linestyle='--', alpha=0.7)
262
-
263
- plt.tight_layout()
264
- return gr.update(value=fig), "預測成功"
265
-
266
  except Exception as e:
267
  logging.error(f"預測過程發生錯誤: {str(e)}")
268
  return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
 
16
  import yfinance as yf
17
  import logging
18
  from datetime import datetime, timedelta
19
+ from prophet import Prophet
20
 
21
  # 設置日誌
22
  logging.basicConfig(level=logging.INFO,
 
233
  all_predictions = np.vstack([last_original, predictions_original[1:]])
234
 
235
  elif model_choice == "Prophet":
 
236
  prophet_df = df.reset_index()[['Date', 'Close']]
237
  prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
238
 
 
242
  future = model.make_future_dataframe(periods=5)
243
  forecast = model.predict(future)
244
  all_predictions = forecast[['ds', 'yhat']].tail(6).values
245
+
246
+ # 取出日期和預測結果
247
+ date_labels = [d.strftime('%m/%d') for d in forecast['ds'].tail(6)]
248
+ predictions = forecast['yhat'].tail(6).values
249
+
250
+ # 繪圖
251
+ fig, ax = plt.subplots(figsize=(14, 7))
252
+ ax.plot(date_labels, predictions, label="預測股價", marker='o', color='#FF9999', linewidth=2)
253
+ ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
254
+ ax.set_xlabel('日期', labelpad=10)
255
+ ax.set_ylabel('股價', labelpad=10)
256
+ ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
257
+ ax.grid(True, linestyle='--', alpha=0.7)
258
+
259
+ plt.tight_layout()
260
+ return gr.update(value=fig), "預測成功"
261
 
262
  else:
263
  return gr.update(value=None), "未知的模型選擇"
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  except Exception as e:
266
  logging.error(f"預測過程發生錯誤: {str(e)}")
267
  return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"