Spaces:
Running
Running
Commit
·
d83f194
1
Parent(s):
39eedf9
Initial commit
Browse files- app.py +283 -0
- readme.md +64 -0
- requirements.txt +11 -0
app.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import aiohttp
|
3 |
+
import asyncio
|
4 |
+
import requests
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
from sklearn.preprocessing import MinMaxScaler
|
9 |
+
from tensorflow.keras.models import Sequential
|
10 |
+
from tensorflow.keras.layers import LSTM, Dense, Dropout
|
11 |
+
from tensorflow.keras.optimizers import Adam
|
12 |
+
from datetime import datetime, timedelta
|
13 |
+
import plotly.graph_objs as go
|
14 |
+
import plotly.io as pio
|
15 |
+
import yfinance as yf
|
16 |
+
import logging
|
17 |
+
import tempfile
|
18 |
+
import os
|
19 |
+
import matplotlib as mpl
|
20 |
+
import matplotlib.font_manager as fm
|
21 |
+
|
22 |
+
# 設置日志
|
23 |
+
logging.basicConfig(level=logging.INFO,
|
24 |
+
format='%(asctime)s - %(levelname)s - %(message)s')
|
25 |
+
|
26 |
+
# 字體設置
|
27 |
+
def setup_font():
|
28 |
+
try:
|
29 |
+
url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
|
30 |
+
response_font = requests.get(url_font)
|
31 |
+
|
32 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
|
33 |
+
tmp_file.write(response_font.content)
|
34 |
+
tmp_file_path = tmp_file.name
|
35 |
+
|
36 |
+
fm.fontManager.addfont(tmp_file_path)
|
37 |
+
mpl.rc('font', family='Taipei Sans TC Beta')
|
38 |
+
except Exception as e:
|
39 |
+
logging.error(f"字體設置失敗: {str(e)}")
|
40 |
+
# 使用備用字體
|
41 |
+
mpl.rc('font', family='SimHei')
|
42 |
+
|
43 |
+
# 網路請求設置
|
44 |
+
headers = {
|
45 |
+
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
46 |
+
'Accept-Language': 'zh-TW,zh;q=0.9,en-US;q=0.8,en;q=0.7',
|
47 |
+
'Accept-Encoding': 'gzip, deflate, br',
|
48 |
+
'Connection': 'keep-alive',
|
49 |
+
'Upgrade-Insecure-Requests': '1'
|
50 |
+
}
|
51 |
+
|
52 |
+
async def fetch_stock_categories():
|
53 |
+
try:
|
54 |
+
url = "https://tw.stock.yahoo.com/class/"
|
55 |
+
async with aiohttp.ClientSession() as session:
|
56 |
+
async with session.get(url, headers=headers) as response:
|
57 |
+
response_text = await response.text()
|
58 |
+
soup = BeautifulSoup(response_text, 'html.parser')
|
59 |
+
main_categories = soup.find_all('div', class_='C($c-link-text)')
|
60 |
+
|
61 |
+
data = []
|
62 |
+
for category in main_categories:
|
63 |
+
main_category_name = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
|
64 |
+
if main_category_name:
|
65 |
+
main_category_name = main_category_name.text.strip()
|
66 |
+
sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
|
67 |
+
|
68 |
+
for sub_category in sub_categories:
|
69 |
+
data.append({
|
70 |
+
'台股': main_category_name,
|
71 |
+
'類股': sub_category.text.strip(),
|
72 |
+
'網址': "https://tw.stock.yahoo.com" + sub_category['href']
|
73 |
+
})
|
74 |
+
|
75 |
+
category_dict = {}
|
76 |
+
for item in data:
|
77 |
+
if item['台股'] not in category_dict:
|
78 |
+
category_dict[item['台股']] = []
|
79 |
+
category_dict[item['台股']].append({'類股': item['類股'], '網址': item['網址']})
|
80 |
+
|
81 |
+
return category_dict
|
82 |
+
except Exception as e:
|
83 |
+
logging.error(f"獲取股票類別失敗: {str(e)}")
|
84 |
+
return {}
|
85 |
+
|
86 |
+
# 股票預測模型類別
|
87 |
+
class StockPredictor:
|
88 |
+
def __init__(self):
|
89 |
+
self.model = None
|
90 |
+
self.scaler = MinMaxScaler()
|
91 |
+
|
92 |
+
def prepare_data(self, df, selected_features):
|
93 |
+
scaled_data = self.scaler.fit_transform(df[selected_features])
|
94 |
+
|
95 |
+
X, y = [], []
|
96 |
+
for i in range(len(scaled_data) - 1):
|
97 |
+
X.append(scaled_data[i])
|
98 |
+
y.append(scaled_data[i+1])
|
99 |
+
|
100 |
+
return np.array(X).reshape(-1, 1, len(selected_features)), np.array(y)
|
101 |
+
|
102 |
+
def build_model(self, input_shape):
|
103 |
+
model = Sequential([
|
104 |
+
LSTM(100, activation='relu', input_shape=input_shape, return_sequences=True),
|
105 |
+
Dropout(0.2),
|
106 |
+
LSTM(50, activation='relu'),
|
107 |
+
Dropout(0.2),
|
108 |
+
Dense(input_shape[1])
|
109 |
+
])
|
110 |
+
model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
|
111 |
+
return model
|
112 |
+
|
113 |
+
def train(self, df, selected_features):
|
114 |
+
X, y = self.prepare_data(df, selected_features)
|
115 |
+
self.model = self.build_model((1, X.shape[2]))
|
116 |
+
history = self.model.fit(
|
117 |
+
X, y,
|
118 |
+
epochs=50,
|
119 |
+
batch_size=32,
|
120 |
+
validation_split=0.2,
|
121 |
+
verbose=0
|
122 |
+
)
|
123 |
+
return history
|
124 |
+
|
125 |
+
def predict(self, last_data, n_days):
|
126 |
+
predictions = []
|
127 |
+
current_data = last_data.copy()
|
128 |
+
|
129 |
+
for _ in range(n_days):
|
130 |
+
next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
|
131 |
+
predictions.append(next_day[0])
|
132 |
+
|
133 |
+
current_data = next_day
|
134 |
+
|
135 |
+
return np.array(predictions)
|
136 |
+
|
137 |
+
# Gradio界面函數
|
138 |
+
async def update_stocks(category):
|
139 |
+
if not category or category not in category_dict:
|
140 |
+
return []
|
141 |
+
return [item['類股'] for item in category_dict[category]]
|
142 |
+
|
143 |
+
async def get_stock_items(url):
|
144 |
+
try:
|
145 |
+
async with aiohttp.ClientSession() as session:
|
146 |
+
async with session.get(url, headers=headers) as response:
|
147 |
+
response_text = await response.text()
|
148 |
+
soup = BeautifulSoup(response_text, 'html.parser')
|
149 |
+
stock_items = soup.find_all('li', class_='List(n)')
|
150 |
+
|
151 |
+
stocks_dict = {}
|
152 |
+
for item in stock_items:
|
153 |
+
stock_name = item.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
|
154 |
+
stock_code = item.find('span', class_='Fz(14px) C(#979ba7) Ell')
|
155 |
+
if stock_name and stock_code:
|
156 |
+
full_code = stock_code.text.strip()
|
157 |
+
display_code = full_code.split('.')[0]
|
158 |
+
display_name = f"{stock_name.text.strip()}{display_code}"
|
159 |
+
stocks_dict[display_name] = full_code
|
160 |
+
|
161 |
+
return stocks_dict
|
162 |
+
except Exception as e:
|
163 |
+
logging.error(f"獲取股票項目失敗: {str(e)}")
|
164 |
+
return {}
|
165 |
+
|
166 |
+
async def predict_stock(category, stock, stock_item, period, selected_features):
|
167 |
+
if not all([category, stock, stock_item]):
|
168 |
+
return gr.update(value=None), "請選擇產業類別、類股和股票"
|
169 |
+
|
170 |
+
try:
|
171 |
+
url = next((item['網址'] for item in category_dict.get(category, [])
|
172 |
+
if item['類股'] == stock), None)
|
173 |
+
if not url:
|
174 |
+
return gr.update(value=None), "無法獲取類股網址"
|
175 |
+
|
176 |
+
stock_items = await get_stock_items(url)
|
177 |
+
stock_code = stock_items.get(stock_item, "")
|
178 |
+
|
179 |
+
if not stock_code:
|
180 |
+
return gr.update(value=None), "無法獲取股票代碼"
|
181 |
+
|
182 |
+
# 下載股票數據,根據用戶選擇的時間範圍
|
183 |
+
df = yf.download(stock_code, period=period)
|
184 |
+
if df.empty:
|
185 |
+
raise ValueError("無法獲取股票數據")
|
186 |
+
|
187 |
+
# 預測
|
188 |
+
predictor = StockPredictor()
|
189 |
+
predictor.train(df, selected_features)
|
190 |
+
|
191 |
+
last_data = predictor.scaler.transform(df.iloc[-1:][selected_features])
|
192 |
+
predictions = predictor.predict(last_data[0], 5)
|
193 |
+
|
194 |
+
# 創建日期指標
|
195 |
+
dates = [datetime.now() + timedelta(days=i) for i in range(6)]
|
196 |
+
date_labels = [d.strftime('%m/%d') for d in dates]
|
197 |
+
|
198 |
+
# 用 Plotly 繪圖
|
199 |
+
fig = go.Figure()
|
200 |
+
for i, feature in enumerate(selected_features):
|
201 |
+
fig.add_trace(go.Scatter(
|
202 |
+
x=date_labels,
|
203 |
+
y=np.hstack([df[feature].iloc[-1], predictions[:, i]]),
|
204 |
+
mode='lines+markers',
|
205 |
+
name=f'預測{feature}'
|
206 |
+
))
|
207 |
+
|
208 |
+
fig.update_layout(
|
209 |
+
title=f'{stock_item} 股價預測 (未來5天)',
|
210 |
+
xaxis_title='日期',
|
211 |
+
yaxis_title='股價',
|
212 |
+
template='plotly_dark'
|
213 |
+
)
|
214 |
+
|
215 |
+
return gr.update(value=pio.to_html(fig, full_html=False)), "預測成功"
|
216 |
+
|
217 |
+
except Exception as e:
|
218 |
+
logging.error(f"預測過程發生錯誤: {str(e)}")
|
219 |
+
return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
|
220 |
+
|
221 |
+
# 初始化
|
222 |
+
setup_font()
|
223 |
+
category_dict = asyncio.run(fetch_stock_categories())
|
224 |
+
categories = list(category_dict.keys())
|
225 |
+
|
226 |
+
# Gradio界面
|
227 |
+
with gr.Blocks() as demo:
|
228 |
+
gr.Markdown("# 台股預測系統")
|
229 |
+
with gr.Row():
|
230 |
+
with gr.Column():
|
231 |
+
category_dropdown = gr.Dropdown(
|
232 |
+
choices=categories,
|
233 |
+
label="產業類別",
|
234 |
+
value=None
|
235 |
+
)
|
236 |
+
stock_dropdown = gr.Dropdown(
|
237 |
+
choices=[],
|
238 |
+
label="類股",
|
239 |
+
value=None
|
240 |
+
)
|
241 |
+
stock_item_dropdown = gr.Dropdown(
|
242 |
+
choices=[],
|
243 |
+
label="股票",
|
244 |
+
value=None
|
245 |
+
)
|
246 |
+
period_dropdown = gr.Dropdown(
|
247 |
+
choices=["1y", "6mo", "3mo", "1mo"],
|
248 |
+
label="抓取時間範圍",
|
249 |
+
value="1y"
|
250 |
+
)
|
251 |
+
features_checkbox = gr.CheckboxGroup(
|
252 |
+
choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
|
253 |
+
label="選擇要用於預測的特徵",
|
254 |
+
value=['Open', 'Close']
|
255 |
+
)
|
256 |
+
predict_button = gr.Button("開始預測", variant="primary")
|
257 |
+
status_output = gr.Textbox(label="狀態", interactive=False)
|
258 |
+
|
259 |
+
with gr.Row():
|
260 |
+
stock_plot = gr.HTML(label="股價預測圖")
|
261 |
+
|
262 |
+
# 事件綁定
|
263 |
+
category_dropdown.change(
|
264 |
+
update_stocks,
|
265 |
+
inputs=[category_dropdown],
|
266 |
+
outputs=[stock_dropdown]
|
267 |
+
)
|
268 |
+
|
269 |
+
stock_dropdown.change(
|
270 |
+
update_stocks,
|
271 |
+
inputs=[category_dropdown],
|
272 |
+
outputs=[stock_item_dropdown]
|
273 |
+
)
|
274 |
+
|
275 |
+
predict_button.click(
|
276 |
+
predict_stock,
|
277 |
+
inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox],
|
278 |
+
outputs=[stock_plot, status_output]
|
279 |
+
)
|
280 |
+
|
281 |
+
# 啟動應用
|
282 |
+
if __name__ == "__main__":
|
283 |
+
demo.launch(share=False)
|
readme.md
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stock Prediction Application
|
2 |
+
|
3 |
+
This project is a stock prediction application developed using Gradio, TensorFlow, and Plotly, with integration of Yahoo Finance for stock data fetching. It allows users to predict future stock prices by selecting various stock features, periods, and categories in the Taiwanese market. The application is designed to be deployed on Hugging Face Spaces, providing a user-friendly interface for non-technical users.
|
4 |
+
|
5 |
+
### Features
|
6 |
+
- **Real-time Stock Data**: Retrieve stock data directly from Yahoo Finance.
|
7 |
+
- **Customizable Prediction Features**: Users can select different features (e.g., Open, High, Low, Close, Volume) for prediction.
|
8 |
+
- **Dynamic Charting**: The stock prices are displayed with interactive charts using Plotly.
|
9 |
+
- **Flexible Data Range**: Users can select different data ranges (e.g., 1 year, 6 months, 3 months, 1 month).
|
10 |
+
- **Taiwanese Stock Categories**: Extract and analyze Taiwanese stock categories to help users gain insights.
|
11 |
+
|
12 |
+
### Installation
|
13 |
+
To run this project locally, you need to have Python installed and the required dependencies. You can install the dependencies using the following command:
|
14 |
+
|
15 |
+
```sh
|
16 |
+
pip install -r requirements.txt
|
17 |
+
```
|
18 |
+
|
19 |
+
### Run the Application
|
20 |
+
To launch the application, run the following command:
|
21 |
+
|
22 |
+
```sh
|
23 |
+
python app.py
|
24 |
+
```
|
25 |
+
|
26 |
+
### Deploy to Hugging Face Spaces
|
27 |
+
This application can be deployed to Hugging Face Spaces by pushing the project files to your Hugging Face repository.
|
28 |
+
|
29 |
+
### License
|
30 |
+
This project is open-sourced under the MIT license.
|
31 |
+
|
32 |
+
---
|
33 |
+
|
34 |
+
# 股票預測應用程序
|
35 |
+
|
36 |
+
這個項目是一個使用 Gradio、TensorFlow 和 Plotly 開發的股票預測應用程序,集成了 Yahoo Finance 以抓取股票數據。它允許用戶通過選擇各種股票特徵、時間範圍和台股類別來預測未來的股票價格。該應用旨在部署到 Hugging Face Spaces,為非技術用戶提供友好的界面。
|
37 |
+
|
38 |
+
### 功能特點
|
39 |
+
- **即時股票數據**:直接從 Yahoo Finance 獲取股票數據。
|
40 |
+
- **可自定義預測特徵**:用戶可以選擇不同的特徵(如 開盤價、最高價、最低價、收盤價、成交量)進行預測。
|
41 |
+
- **動態圖表顯示**:使用 Plotly 提供互動式的股價圖表顯示。
|
42 |
+
- **靈活的數據範圍選擇**:用戶可以選擇不同的數據範圍(如 1年、半年、3個月、1個月)。
|
43 |
+
- **台股類別分析**:提取並分析台灣股票類別,幫助用戶獲得更多見解。
|
44 |
+
|
45 |
+
### 安裝
|
46 |
+
要在本地運行此項目,您需要安裝 Python 和必要的依賴項。您可以使用以下命令安裝依賴項:
|
47 |
+
|
48 |
+
```sh
|
49 |
+
pip install -r requirements.txt
|
50 |
+
```
|
51 |
+
|
52 |
+
### 運行應用程序
|
53 |
+
運行以下命令以啟動應用程序:
|
54 |
+
|
55 |
+
```sh
|
56 |
+
python app.py
|
57 |
+
```
|
58 |
+
|
59 |
+
### 部署到 Hugging Face Spaces
|
60 |
+
您可以將這個應用程序部署到 Hugging Face Spaces,只需將項目文件推送到您的 Hugging Face 倉庫即可。
|
61 |
+
|
62 |
+
### 授權協議
|
63 |
+
此項目以 MIT 許可證開源。
|
64 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio>=3.0.18
|
2 |
+
aiohttp>=3.8.1
|
3 |
+
requests>=2.31.0
|
4 |
+
beautifulsoup4>=4.12.2
|
5 |
+
pandas>=1.5.3
|
6 |
+
numpy>=1.24.0
|
7 |
+
scikit-learn>=1.2.2
|
8 |
+
tensorflow>=2.13.0
|
9 |
+
plotly>=5.15.0
|
10 |
+
yfinance>=0.2.26
|
11 |
+
matplotlib>=3.7.2
|