tbdavid2019 commited on
Commit
a79dcc6
·
1 Parent(s): 8da5de8
Files changed (2) hide show
  1. app.py +110 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import yfinance as yf
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ from prophet import Prophet
6
+ from datetime import datetime, timedelta
7
+ import logging
8
+
9
+ # 設置日誌
10
+ logging.basicConfig(level=logging.INFO,
11
+ format='%(asctime)s - %(levelname)s - %(message)s')
12
+
13
+ def predict_stock_price(stock_code, period, prediction_days):
14
+ try:
15
+ # 下載股票數據
16
+ df = yf.download(stock_code, period=period)
17
+ if df.empty:
18
+ return "無法獲取股票數據", None
19
+
20
+ # 準備數據
21
+ data = df.reset_index()
22
+ data = data[['Date', 'Close']]
23
+ data.columns = ['ds', 'y']
24
+
25
+ # 訓練 Prophet 模型
26
+ model = Prophet(daily_seasonality=True)
27
+ model.fit(data)
28
+
29
+ # 創建未來日期
30
+ future = model.make_future_dataframe(periods=prediction_days)
31
+ forecast = model.predict(future)
32
+
33
+ # 繪製圖表
34
+ plt.figure(figsize=(12, 6))
35
+
36
+ # 繪製實際數據
37
+ plt.plot(data['ds'], data['y'],
38
+ label='實際收盤價',
39
+ color='blue')
40
+
41
+ # 繪製預測數據
42
+ plt.plot(forecast['ds'], forecast['yhat'],
43
+ label='預測收盤價',
44
+ color='orange',
45
+ linestyle='--')
46
+
47
+ # 添加預測區間
48
+ plt.fill_between(forecast['ds'],
49
+ forecast['yhat_lower'],
50
+ forecast['yhat_upper'],
51
+ color='orange',
52
+ alpha=0.2)
53
+
54
+ # 設置圖表格式
55
+ plt.title(f'{stock_code} 股價預測', pad=20)
56
+ plt.xlabel('日期')
57
+ plt.ylabel('股價')
58
+ plt.xticks(rotation=45)
59
+ plt.legend()
60
+ plt.grid(True, linestyle='--', alpha=0.7)
61
+ plt.tight_layout()
62
+
63
+ # 返回預測結果和圖表
64
+ return forecast.tail(prediction_days).to_string(), plt.gcf()
65
+
66
+ except Exception as e:
67
+ logging.error(f"預測過程發生錯誤: {str(e)}")
68
+ return f"預測過程發生錯誤: {str(e)}", None
69
+
70
+ # Gradio 介面
71
+ with gr.Blocks() as demo:
72
+ gr.Markdown("# 台股預測系統")
73
+
74
+ with gr.Row():
75
+ with gr.Column():
76
+ stock_input = gr.Textbox(
77
+ label="股票代碼",
78
+ placeholder="例如: 2330.TW",
79
+ value="2330.TW"
80
+ )
81
+
82
+ period_dropdown = gr.Dropdown(
83
+ choices=["1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "max"],
84
+ label="歷史數據期間",
85
+ value="1y"
86
+ )
87
+
88
+ prediction_days = gr.Slider(
89
+ minimum=5,
90
+ maximum=30,
91
+ value=5,
92
+ step=1,
93
+ label="預測天數"
94
+ )
95
+
96
+ predict_button = gr.Button("開始預測", variant="primary")
97
+
98
+ with gr.Column():
99
+ output_plot = gr.Plot(label="股價預測圖")
100
+ output_text = gr.Textbox(label="預測結果")
101
+
102
+ predict_button.click(
103
+ predict_stock_price,
104
+ inputs=[stock_input, period_dropdown, prediction_days],
105
+ outputs=[output_text, output_plot]
106
+ )
107
+
108
+ # 啟動應用
109
+ if __name__ == "__main__":
110
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+
2
+ gradio>=4.0.0
3
+ yfinance>=0.2.3
4
+ pandas>=1.3.0
5
+ numpy>=1.21.0
6
+ matplotlib>=3.4.3
7
+ prophet>=1.1.4