tbdavid2019 commited on
Commit
6547b21
·
1 Parent(s): 1ddba8e

加入Prophet 雙選擇

Browse files
Files changed (1) hide show
  1. app.py +45 -35
app.py CHANGED
@@ -196,7 +196,7 @@ def update_stock(category, stock):
196
  status_output: gr.update(value="")
197
  }
198
 
199
- def predict_stock(category, stock, stock_item, period, selected_features):
200
  if not all([category, stock, stock_item]):
201
  return gr.update(value=None), "請選擇產業類別、類股和股票"
202
 
@@ -208,55 +208,60 @@ def predict_stock(category, stock, stock_item, period, selected_features):
208
 
209
  stock_items = get_stock_items(url)
210
  stock_code = stock_items.get(stock_item, "")
211
-
212
  if not stock_code:
213
  return gr.update(value=None), "無法獲取股票代碼"
214
-
215
  # 下載股票數據,根據用戶選擇的時間範圍
216
  df = yf.download(stock_code, period=period)
217
  if df.empty:
218
  raise ValueError("無法獲取股票數據")
219
 
220
- # 預測
221
- predictor = StockPredictor()
222
- predictor.train(df, selected_features)
223
-
224
- last_data = predictor.scaler.transform(df[selected_features].iloc[-1:].values)
225
- predictions = predictor.predict(last_data[0], 5)
226
-
227
- # 反轉預測結果
228
- last_original = df[selected_features].iloc[-1].values
229
- predictions_original = predictor.scaler.inverse_transform(
230
- np.vstack([last_data, predictions])
231
- )
232
- all_predictions = np.vstack([last_original, predictions_original[1:]])
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  # 創建日期索引
235
- dates = [datetime.now() + timedelta(days=i) for i in range(6)]
236
  date_labels = [d.strftime('%m/%d') for d in dates]
237
-
238
  # 繪圖
239
  fig, ax = plt.subplots(figsize=(14, 7))
240
- colors = ['#FF9999', '#66B2FF']
241
- labels = [f'預測{feature}' for feature in selected_features]
242
-
243
- for i, (label, color) in enumerate(zip(labels, colors)):
244
- ax.plot(date_labels, all_predictions[:, i], label=label,
245
- marker='o', color=color, linewidth=2)
246
- for j, value in enumerate(all_predictions[:, i]):
247
- ax.annotate(f'{value:.2f}', (date_labels[j], value),
248
- textcoords="offset points", xytext=(0,10),
249
- ha='center', va='bottom')
250
-
251
  ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
252
  ax.set_xlabel('日期', labelpad=10)
253
  ax.set_ylabel('股價', labelpad=10)
254
  ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
255
  ax.grid(True, linestyle='--', alpha=0.7)
256
-
257
  plt.tight_layout()
258
  return gr.update(value=fig), "預測成功"
259
-
260
  except Exception as e:
261
  logging.error(f"預測過程發生錯誤: {str(e)}")
262
  return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
@@ -296,28 +301,33 @@ with gr.Blocks() as demo:
296
  label="選擇要用於預測的特徵",
297
  value=['Open', 'Close']
298
  )
 
 
 
 
 
299
  predict_button = gr.Button("開始預測", variant="primary")
300
  status_output = gr.Textbox(label="狀態", interactive=False)
301
 
302
  with gr.Row():
303
  stock_plot = gr.Plot(label="股價預測圖")
304
-
305
  # 事件綁定
306
  category_dropdown.change(
307
  update_category,
308
  inputs=[category_dropdown],
309
  outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
310
  )
311
-
312
  stock_dropdown.change(
313
  update_stock,
314
  inputs=[category_dropdown, stock_dropdown],
315
  outputs=[stock_item_dropdown, stock_plot, status_output]
316
  )
317
-
318
  predict_button.click(
319
  predict_stock,
320
- inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox],
321
  outputs=[stock_plot, status_output]
322
  )
323
 
 
196
  status_output: gr.update(value="")
197
  }
198
 
199
+ def predict_stock(category, stock, stock_item, period, selected_features, model_choice):
200
  if not all([category, stock, stock_item]):
201
  return gr.update(value=None), "請選擇產業類別、類股和股票"
202
 
 
208
 
209
  stock_items = get_stock_items(url)
210
  stock_code = stock_items.get(stock_item, "")
211
+
212
  if not stock_code:
213
  return gr.update(value=None), "無法獲取股票代碼"
214
+
215
  # 下載股票數據,根據用戶選擇的時間範圍
216
  df = yf.download(stock_code, period=period)
217
  if df.empty:
218
  raise ValueError("無法獲取股票數據")
219
 
220
+ # 根據模型選擇進行預測
221
+ if model_choice == "LSTM":
222
+ predictor = StockPredictor()
223
+ predictor.train(df, selected_features)
224
+ last_data = predictor.scaler.transform(df[selected_features].iloc[-1:].values)
225
+ predictions = predictor.predict(last_data[0], 5)
226
+
227
+ # 反轉預測結果
228
+ last_original = df[selected_features].iloc[-1].values
229
+ predictions_original = predictor.scaler.inverse_transform(
230
+ np.vstack([last_data, predictions])
231
+ )
232
+ all_predictions = np.vstack([last_original, predictions_original[1:]])
233
 
234
+ elif model_choice == "Prophet":
235
+ from prophet import Prophet
236
+ prophet_df = df.reset_index()[['Date', 'Close']]
237
+ prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
238
+
239
+ model = Prophet()
240
+ model.fit(prophet_df)
241
+
242
+ future = model.make_future_dataframe(periods=5)
243
+ forecast = model.predict(future)
244
+ all_predictions = forecast[['ds', 'yhat']].tail(6).values
245
+
246
+ else:
247
+ return gr.update(value=None), "未知的模型選擇"
248
+
249
  # 創建日期索引
250
+ dates = [datetime.now() + timedelta(days=i) for i in range(len(all_predictions))]
251
  date_labels = [d.strftime('%m/%d') for d in dates]
252
+
253
  # 繪圖
254
  fig, ax = plt.subplots(figsize=(14, 7))
255
+ ax.plot(date_labels, all_predictions, label="預測股價", marker='o', color='#FF9999', linewidth=2)
 
 
 
 
 
 
 
 
 
 
256
  ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
257
  ax.set_xlabel('日期', labelpad=10)
258
  ax.set_ylabel('股價', labelpad=10)
259
  ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
260
  ax.grid(True, linestyle='--', alpha=0.7)
261
+
262
  plt.tight_layout()
263
  return gr.update(value=fig), "預測成功"
264
+
265
  except Exception as e:
266
  logging.error(f"預測過程發生錯誤: {str(e)}")
267
  return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
 
301
  label="選擇要用於預測的特徵",
302
  value=['Open', 'Close']
303
  )
304
+ model_dropdown = gr.Dropdown(
305
+ choices=["LSTM", "Prophet"],
306
+ label="選擇預測模型",
307
+ value="LSTM"
308
+ )
309
  predict_button = gr.Button("開始預測", variant="primary")
310
  status_output = gr.Textbox(label="狀態", interactive=False)
311
 
312
  with gr.Row():
313
  stock_plot = gr.Plot(label="股價預測圖")
314
+
315
  # 事件綁定
316
  category_dropdown.change(
317
  update_category,
318
  inputs=[category_dropdown],
319
  outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
320
  )
321
+
322
  stock_dropdown.change(
323
  update_stock,
324
  inputs=[category_dropdown, stock_dropdown],
325
  outputs=[stock_item_dropdown, stock_plot, status_output]
326
  )
327
+
328
  predict_button.click(
329
  predict_stock,
330
+ inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox, model_dropdown],
331
  outputs=[stock_plot, status_output]
332
  )
333