tbdavid2019 committed on
Commit
d83f194
·
1 Parent(s): 39eedf9

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +283 -0
  2. readme.md +64 -0
  3. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import aiohttp
3
+ import asyncio
4
+ import requests
5
+ from bs4 import BeautifulSoup
6
+ import pandas as pd
7
+ import numpy as np
8
+ from sklearn.preprocessing import MinMaxScaler
9
+ from tensorflow.keras.models import Sequential
10
+ from tensorflow.keras.layers import LSTM, Dense, Dropout
11
+ from tensorflow.keras.optimizers import Adam
12
+ from datetime import datetime, timedelta
13
+ import plotly.graph_objs as go
14
+ import plotly.io as pio
15
+ import yfinance as yf
16
+ import logging
17
+ import tempfile
18
+ import os
19
+ import matplotlib as mpl
20
+ import matplotlib.font_manager as fm
21
+
22
# Logging setup: INFO level, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
25
+
26
# Font setup
def setup_font():
    """Download a font file and make it matplotlib's default font family.

    The file is fetched from a fixed Google Drive URL, written to a temporary
    ``.ttf`` file, registered with matplotlib's font manager, and the
    'Taipei Sans TC Beta' family is then selected.  If any step fails the
    error is logged and the 'SimHei' family is selected instead, so the
    function never raises.
    """
    url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
    try:
        # The timeout keeps start-up from hanging on a stalled connection;
        # raise_for_status() stops an HTTP error page from being saved and
        # registered as if it were a font file.
        response_font = requests.get(url_font, timeout=30)
        response_font.raise_for_status()

        # delete=False: the file must outlive the `with` block so that
        # matplotlib can read it by path afterwards.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
            tmp_file.write(response_font.content)
            tmp_file_path = tmp_file.name

        fm.fontManager.addfont(tmp_file_path)
        mpl.rc('font', family='Taipei Sans TC Beta')
    except Exception as e:
        logging.error(f"字體設置失敗: {str(e)}")
        # Fall back to an alternative font family.
        mpl.rc('font', family='SimHei')
42
+
43
# HTTP request settings: browser-like headers sent with every scraping request.
headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    'Accept-Language': 'zh-TW,zh;q=0.9,en-US;q=0.8,en;q=0.7',
    # 'br' is deliberately not advertised: aiohttp can only decode Brotli
    # responses when an optional Brotli package is installed, and this
    # project's requirements do not include one.
    'Accept-Encoding': 'gzip, deflate',
    'Connection': 'keep-alive',
    'Upgrade-Insecure-Requests': '1',
}
51
+
52
async def fetch_stock_categories():
    """Scrape the Yahoo Taiwan stock "class" page into a category mapping.

    Returns:
        A dict mapping each main category name to a list of
        ``{'類股': sub-category name, '網址': absolute URL}`` dicts, in the
        order they appear on the page.  If anything fails, the error is
        logged and an empty dict is returned.
    """
    try:
        url = "https://tw.stock.yahoo.com/class/"
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                page_html = await response.text()

        soup = BeautifulSoup(page_html, 'html.parser')
        category_dict = {}
        for section in soup.find_all('div', class_='C($c-link-text)'):
            heading = section.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
            if not heading:
                # Matching <div> without a category heading: nothing to record.
                continue

            main_name = heading.text.strip()
            links = section.find_all(
                'a',
                class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)',
            )
            for link in links:
                # A main category only gets an entry once it has at least
                # one sub-category link.
                category_dict.setdefault(main_name, []).append({
                    '類股': link.text.strip(),
                    '網址': "https://tw.stock.yahoo.com" + link['href'],
                })

        return category_dict
    except Exception as e:
        logging.error(f"獲取股票類別失敗: {str(e)}")
        return {}
85
+
86
# Stock prediction model
class StockPredictor:
    """Next-step forecaster built from a small stacked LSTM.

    The selected feature columns are min-max scaled, and the network is
    trained to map the scaled feature vector of one row to the scaled
    feature vector of the following row.  All model inputs and outputs are
    in the scaler's normalised space; callers that need real units must
    apply ``self.scaler.inverse_transform`` themselves.
    """

    def __init__(self):
        self.model = None  # created by train()
        self.scaler = MinMaxScaler()

    def prepare_data(self, df, selected_features):
        """Fit the scaler and build (row -> next row) training pairs.

        Args:
            df: DataFrame containing the columns named in selected_features.
            selected_features: list of column names to use.

        Returns:
            Tuple ``(X, y)`` where ``X`` has shape
            ``(rows - 1, 1, n_features)`` and ``y`` has shape
            ``(rows - 1, n_features)``.

        Raises:
            ValueError: if fewer than two rows are available, since no
                (current, next) pair can be formed.
        """
        scaled_data = np.asarray(self.scaler.fit_transform(df[selected_features]))
        if len(scaled_data) < 2:
            raise ValueError("資料筆數不足,至少需要 2 筆資料才能訓練模型")

        # Row i is the input, row i + 1 is its target.
        X = scaled_data[:-1].reshape(-1, 1, len(selected_features))
        y = scaled_data[1:]
        return X, y

    def build_model(self, input_shape):
        """Build and compile the LSTM network.

        Args:
            input_shape: ``(timesteps, n_features)`` of one sample.

        Returns:
            A compiled Sequential model whose output width equals
            ``input_shape[1]``.
        """
        model = Sequential([
            LSTM(100, activation='relu', input_shape=input_shape, return_sequences=True),
            Dropout(0.2),
            LSTM(50, activation='relu'),
            Dropout(0.2),
            Dense(input_shape[1]),
        ])
        model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
        return model

    def train(self, df, selected_features):
        """Fit a fresh model on df and return the training history object."""
        X, y = self.prepare_data(df, selected_features)
        self.model = self.build_model((1, X.shape[2]))
        history = self.model.fit(
            X, y,
            epochs=50,
            batch_size=32,
            validation_split=0.2,
            verbose=0,
        )
        return history

    def predict(self, last_data, n_days):
        """Roll the model forward n_days steps from last_data.

        Args:
            last_data: 1-D array holding one scaled feature vector.
            n_days: number of successive steps to predict.

        Returns:
            Array of shape ``(n_days, n_features)`` in scaled space.  Each
            step feeds the previous prediction back in as the next input.
        """
        predictions = []
        current_data = last_data.copy()

        for _ in range(n_days):
            next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
            predictions.append(next_day[0])
            current_data = next_day

        return np.array(predictions)
136
+
137
# Gradio interface callbacks
async def update_stocks(category):
    """Return a Dropdown update listing the sub-categories of category.

    Args:
        category: main category name selected in the UI (may be None).

    Returns:
        A ``gr.update`` that replaces the target dropdown's choices with the
        sub-category names of ``category`` and clears its current value.
        The choices are empty when ``category`` is missing or unknown.
    """
    # A bare list returned to a Dropdown output is treated as its *value*,
    # not its choices, so the choices must be sent through gr.update.
    if not category or category not in category_dict:
        return gr.update(choices=[], value=None)
    names = [item['類股'] for item in category_dict[category]]
    return gr.update(choices=names, value=None)
142
+
143
async def get_stock_items(url):
    """Scrape one sub-category page and list the stocks on it.

    Args:
        url: address of the sub-category page to fetch.

    Returns:
        A dict mapping a display label (stock name followed by the code
        without its suffix after the first '.') to the full stock code as
        shown on the page.  If anything fails, the error is logged and an
        empty dict is returned.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                page_html = await response.text()

        soup = BeautifulSoup(page_html, 'html.parser')
        stocks_dict = {}
        for entry in soup.find_all('li', class_='List(n)'):
            name_tag = entry.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
            code_tag = entry.find('span', class_='Fz(14px) C(#979ba7) Ell')
            if not (name_tag and code_tag):
                # List item that is not a stock row.
                continue

            full_code = code_tag.text.strip()
            short_code = full_code.split('.')[0]
            label = f"{name_tag.text.strip()}{short_code}"
            stocks_dict[label] = full_code

        return stocks_dict
    except Exception as e:
        logging.error(f"獲取股票項目失敗: {str(e)}")
        return {}
165
+
166
async def predict_stock(category, stock, stock_item, period, selected_features):
    """Train a model on one stock's history and plot a 5-step forecast.

    Args:
        category: selected main category name.
        stock: selected sub-category name.
        stock_item: selected stock display label (a key of the dict returned
            by get_stock_items).
        period: history range passed to ``yf.download`` (e.g. "1y").
        selected_features: list of price/volume column names to model.

    Returns:
        Tuple ``(plot_update, status_message)``: a ``gr.update`` carrying the
        Plotly chart as an HTML fragment (or ``None`` on failure) and a
        user-facing status string.  Errors are logged and reported through
        the status string; the function does not raise.
    """
    if not all([category, stock, stock_item]):
        return gr.update(value=None), "請選擇產業類別、類股和股票"
    if not selected_features:
        # Without features there is nothing to scale, train on or plot.
        return gr.update(value=None), "請選擇至少一個預測特徵"

    try:
        url = next((item['網址'] for item in category_dict.get(category, [])
                    if item['類股'] == stock), None)
        if not url:
            return gr.update(value=None), "無法獲取類股網址"

        stock_items = await get_stock_items(url)
        stock_code = stock_items.get(stock_item, "")

        if not stock_code:
            return gr.update(value=None), "無法獲取股票代碼"

        # Download the price history for the user-selected time range.
        df = yf.download(stock_code, period=period)
        if df.empty:
            raise ValueError("無法獲取股票數據")

        # NOTE(review): some yfinance versions return two-level columns
        # (field, ticker) even for a single ticker; keep only the field
        # level so each feature selects exactly one column.
        if isinstance(df.columns, pd.MultiIndex):
            df.columns = df.columns.get_level_values(0)

        missing = [f for f in selected_features if f not in df.columns]
        if missing:
            raise ValueError(f"股票數據缺少欄位: {', '.join(missing)}")

        # Train and forecast.
        predictor = StockPredictor()
        predictor.train(df, selected_features)

        last_data = predictor.scaler.transform(df.iloc[-1:][selected_features])
        scaled_predictions = predictor.predict(last_data[0], 5)
        # The model works in the scaler's normalised space; convert back to
        # real units so the forecast lines up with the last actual value
        # plotted as the first point.
        predictions = predictor.scaler.inverse_transform(scaled_predictions)

        # Date labels: today plus the next five calendar days.
        dates = [datetime.now() + timedelta(days=i) for i in range(6)]
        date_labels = [d.strftime('%m/%d') for d in dates]

        # Draw the chart with Plotly: last actual value followed by forecasts.
        fig = go.Figure()
        for i, feature in enumerate(selected_features):
            fig.add_trace(go.Scatter(
                x=date_labels,
                y=np.hstack([df[feature].iloc[-1], predictions[:, i]]),
                mode='lines+markers',
                name=f'預測{feature}'
            ))

        fig.update_layout(
            title=f'{stock_item} 股價預測 (未來5天)',
            xaxis_title='日期',
            yaxis_title='股價',
            template='plotly_dark'
        )

        return gr.update(value=pio.to_html(fig, full_html=False)), "預測成功"

    except Exception as e:
        logging.error(f"預測過程發生錯誤: {str(e)}")
        return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
220
+
221
# Initialisation: configure the plotting font, then scrape the category tree
# once at import time so the UI can be built from it.
setup_font()
category_dict = asyncio.run(fetch_stock_categories())
categories = list(category_dict)
225
+
226
# Gradio interface
async def update_stock_items(category, stock):
    """Return a Dropdown update listing the stocks of one sub-category.

    Looks up the page URL of ``stock`` (a sub-category name) inside
    ``category`` and scrapes it with get_stock_items.  The resulting display
    labels become the dropdown choices; they are the same keys that
    predict_stock later uses to resolve the stock code.  The choices are
    empty when the selection is incomplete or unknown.
    """
    url = next((item['網址'] for item in category_dict.get(category, [])
                if item['類股'] == stock), None)
    if not url:
        return gr.update(choices=[], value=None)
    stock_items = await get_stock_items(url)
    return gr.update(choices=list(stock_items), value=None)


with gr.Blocks() as demo:
    gr.Markdown("# 台股預測系統")
    with gr.Row():
        with gr.Column():
            category_dropdown = gr.Dropdown(
                choices=categories,
                label="產業類別",
                value=None
            )
            stock_dropdown = gr.Dropdown(
                choices=[],
                label="類股",
                value=None
            )
            stock_item_dropdown = gr.Dropdown(
                choices=[],
                label="股票",
                value=None
            )
            period_dropdown = gr.Dropdown(
                choices=["1y", "6mo", "3mo", "1mo"],
                label="抓取時間範圍",
                value="1y"
            )
            features_checkbox = gr.CheckboxGroup(
                choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
                label="選擇要用於預測的特徵",
                value=['Open', 'Close']
            )
            predict_button = gr.Button("開始預測", variant="primary")
            status_output = gr.Textbox(label="狀態", interactive=False)

    with gr.Row():
        stock_plot = gr.HTML(label="股價預測圖")

    # Event wiring
    # Category chosen -> refresh the sub-category dropdown.
    category_dropdown.change(
        update_stocks,
        inputs=[category_dropdown],
        outputs=[stock_dropdown]
    )

    # Sub-category chosen -> refresh the stock dropdown with the stocks of
    # that sub-category.  Wiring update_stocks here would fill it with
    # sub-category names, which predict_stock can never resolve to a code.
    stock_dropdown.change(
        update_stock_items,
        inputs=[category_dropdown, stock_dropdown],
        outputs=[stock_item_dropdown]
    )

    predict_button.click(
        predict_stock,
        inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox],
        outputs=[stock_plot, status_output]
    )
280
+
281
# Launch the app only when executed as a script; no public share link.
if __name__ == "__main__":
    demo.launch(share=False)
readme.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stock Prediction Application
2
+
3
+ This project is a stock prediction application developed using Gradio, TensorFlow, and Plotly, with integration of Yahoo Finance for stock data fetching. It allows users to predict future stock prices by selecting various stock features, periods, and categories in the Taiwanese market. The application is designed to be deployed on Hugging Face Spaces, providing a user-friendly interface for non-technical users.
4
+
5
+ ### Features
6
+ - **Up-to-date Stock Data**: Download recent historical stock data directly from Yahoo Finance (via `yfinance`) for the selected time range.
7
+ - **Customizable Prediction Features**: Users can select different features (e.g., Open, High, Low, Close, Volume) for prediction.
8
+ - **Dynamic Charting**: The stock prices are displayed with interactive charts using Plotly.
9
+ - **Flexible Data Range**: Users can select different data ranges (e.g., 1 year, 6 months, 3 months, 1 month).
10
+ - **Taiwanese Stock Categories**: Extract and analyze Taiwanese stock categories to help users gain insights.
11
+
12
+ ### Installation
13
+ To run this project locally, you need to have Python installed and the required dependencies. You can install the dependencies using the following command:
14
+
15
+ ```sh
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ ### Run the Application
20
+ To launch the application, run the following command:
21
+
22
+ ```sh
23
+ python app.py
24
+ ```
25
+
26
+ ### Deploy to Hugging Face Spaces
27
+ This application can be deployed to Hugging Face Spaces by pushing the project files to your Hugging Face repository.
28
+
29
+ ### License
30
+ This project is open-sourced under the MIT license.
31
+
32
+ ---
33
+
34
+ # 股票預測應用程序
35
+
36
+ 這個項目是一個使用 Gradio、TensorFlow 和 Plotly 開發的股票預測應用程序,集成了 Yahoo Finance 以抓取股票數據。它允許用戶通過選擇各種股票特徵、時間範圍和台股類別來預測未來的股票價格。該應用旨在部署到 Hugging Face Spaces,為非技術用戶提供友好的界面。
37
+
38
+ ### 功能特點
39
+ - **最新股票數據**:依所選時間範圍,直接從 Yahoo Finance(透過 `yfinance`)下載近期的歷史股價數據。
40
+ - **可自定義預測特徵**:用戶可以選擇不同的特徵(如 開盤價、最高價、最低價、收盤價、成交量)進行預測。
41
+ - **動態圖表顯示**:使用 Plotly 提供互動式的股價圖表顯示。
42
+ - **靈活的數據範圍選擇**:用戶可以選擇不同的數據範圍(如 1年、半年、3個月、1個月)。
43
+ - **台股類別分析**:提取並分析台灣股票類別,幫助用戶獲得更多見解。
44
+
45
+ ### 安裝
46
+ 要在本地運行此項目,您需要安裝 Python 和必要的依賴項。您可以使用以下命令安裝依賴項:
47
+
48
+ ```sh
49
+ pip install -r requirements.txt
50
+ ```
51
+
52
+ ### 運行應用程序
53
+ 運行以下命令以啟動應用程序:
54
+
55
+ ```sh
56
+ python app.py
57
+ ```
58
+
59
+ ### 部署到 Hugging Face Spaces
60
+ 您可以將這個應用程序部署到 Hugging Face Spaces,只需將項目文件推送到您的 Hugging Face 倉庫即可。
61
+
62
+ ### 授權協議
63
+ 此項目以 MIT 許可證開源。
64
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=3.0.18
2
+ aiohttp>=3.8.1
3
+ requests>=2.31.0
4
+ beautifulsoup4>=4.12.2
5
+ pandas>=1.5.3
6
+ numpy>=1.24.0
7
+ scikit-learn>=1.2.2
8
+ tensorflow>=2.13.0
9
+ plotly>=5.15.0
10
+ yfinance>=0.2.26
11
+ matplotlib>=3.7.2