tbdavid2019 commited on
Commit
f7c1877
·
1 Parent(s): 73cc4bb
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -194,24 +194,24 @@ def update_stock(category, stock):
194
  stock_plot: gr.update(value=None)
195
  }
196
 
197
- def predict_stock(category, stock, stock_item, features, model_type):
198
  if not all([category, stock, stock_item]):
199
- return gr.update(value=None)
200
 
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
  if item['類股'] == stock), None)
204
  if not url:
205
- return gr.update(value=None)
206
 
207
  stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
 
210
  if not stock_code:
211
- return gr.update(value=None)
212
 
213
  # 下載股票數據
214
- df = yf.download(stock_code, period="1y")
215
  if df.empty:
216
  raise ValueError("無法獲取股票數據")
217
 
@@ -248,13 +248,20 @@ def predict_stock(category, stock, stock_item, features, model_type):
248
  colors = ['#FF9999', '#66B2FF']
249
  labels = ['預測開盤價', '預測收盤價']
250
 
251
- for i, (label, color) in enumerate(zip(labels, colors)):
252
- ax.plot(date_labels, all_predictions if model_type == "Prophet" else all_predictions[:, i],
253
- label=label, marker='o', color=color, linewidth=2)
254
- for j, value in enumerate(all_predictions if model_type == "Prophet" else all_predictions[:, i]):
255
- ax.annotate(f'{value:.2f}', (date_labels[j], value),
256
- textcoords="offset points", xytext=(0,10),
257
- ha='center', va='bottom')
 
 
 
 
 
 
 
258
 
259
  ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
260
  ax.set_xlabel('日期', labelpad=10)
@@ -263,11 +270,11 @@ def predict_stock(category, stock, stock_item, features, model_type):
263
  ax.grid(True, linestyle='--', alpha=0.7)
264
 
265
  plt.tight_layout()
266
- return gr.update(value=fig)
267
 
268
  except Exception as e:
269
  logging.error(f"預測過程發生錯誤: {str(e)}")
270
- return gr.update(value=None)
271
 
272
  # 初始化
273
  setup_font()
@@ -330,7 +337,7 @@ with gr.Blocks() as demo:
330
 
331
  predict_button.click(
332
  predict_stock,
333
- inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, features_checkboxes, model_type_dropdown],
334
  outputs=[stock_plot, status_textbox]
335
  )
336
 
 
194
  stock_plot: gr.update(value=None)
195
  }
196
 
197
+ def predict_stock(category, stock, stock_item, period, features, model_type):
198
  if not all([category, stock, stock_item]):
199
+ return gr.update(value=None), "請選擇完整的選項"
200
 
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
  if item['類股'] == stock), None)
204
  if not url:
205
+ return gr.update(value=None), "無法找到該類股的網址"
206
 
207
  stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
 
210
  if not stock_code:
211
+ return gr.update(value=None), "無法找到該股票的代碼"
212
 
213
  # 下載股票數據
214
+ df = yf.download(stock_code, period=period)
215
  if df.empty:
216
  raise ValueError("無法獲取股票數據")
217
 
 
248
  colors = ['#FF9999', '#66B2FF']
249
  labels = ['預測開盤價', '預測收盤價']
250
 
251
+ for i, (label, color) in enumerate(labels):
252
+ if model_type == "Prophet":
253
+ ax.plot(date_labels, all_predictions, label='預測收盤價', marker='o', color=colors[1], linewidth=2)
254
+ for j, value in enumerate(all_predictions):
255
+ ax.annotate(f'{value:.2f}', (date_labels[j], value),
256
+ textcoords="offset points", xytext=(0,10),
257
+ ha='center', va='bottom')
258
+ break
259
+ else:
260
+ ax.plot(date_labels, all_predictions[:, i], label=label, marker='o', color=color, linewidth=2)
261
+ for j, value in enumerate(all_predictions[:, i]):
262
+ ax.annotate(f'{value:.2f}', (date_labels[j], value),
263
+ textcoords="offset points", xytext=(0,10),
264
+ ha='center', va='bottom')
265
 
266
  ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
267
  ax.set_xlabel('日期', labelpad=10)
 
270
  ax.grid(True, linestyle='--', alpha=0.7)
271
 
272
  plt.tight_layout()
273
+ return gr.update(value=fig), "預測成功"
274
 
275
  except Exception as e:
276
  logging.error(f"預測過程發生錯誤: {str(e)}")
277
+ return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
278
 
279
  # 初始化
280
  setup_font()
 
337
 
338
  predict_button.click(
339
  predict_stock,
340
+ inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkboxes, model_type_dropdown],
341
  outputs=[stock_plot, status_textbox]
342
  )
343