tbdavid2019 commited on
Commit
40f675a
·
1 Parent(s): d83f194
Files changed (1) hide show
  1. app.py +141 -110
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import gradio as gr
2
- import aiohttp
3
- import asyncio
4
  import requests
5
  from bs4 import BeautifulSoup
6
  import pandas as pd
@@ -9,15 +7,15 @@ from sklearn.preprocessing import MinMaxScaler
9
  from tensorflow.keras.models import Sequential
10
  from tensorflow.keras.layers import LSTM, Dense, Dropout
11
  from tensorflow.keras.optimizers import Adam
12
- from datetime import datetime, timedelta
13
- import plotly.graph_objs as go
14
- import plotly.io as pio
15
- import yfinance as yf
16
- import logging
17
- import tempfile
18
- import os
19
  import matplotlib as mpl
20
  import matplotlib.font_manager as fm
 
 
 
 
 
21
 
22
  # 設置日志
23
  logging.basicConfig(level=logging.INFO,
@@ -49,55 +47,56 @@ headers = {
49
  'Upgrade-Insecure-Requests': '1'
50
  }
51
 
52
- async def fetch_stock_categories():
53
  try:
54
  url = "https://tw.stock.yahoo.com/class/"
55
- async with aiohttp.ClientSession() as session:
56
- async with session.get(url, headers=headers) as response:
57
- response_text = await response.text()
58
- soup = BeautifulSoup(response_text, 'html.parser')
59
- main_categories = soup.find_all('div', class_='C($c-link-text)')
60
-
61
- data = []
62
- for category in main_categories:
63
- main_category_name = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
64
- if main_category_name:
65
- main_category_name = main_category_name.text.strip()
66
- sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
67
-
68
- for sub_category in sub_categories:
69
- data.append({
70
- '台股': main_category_name,
71
- '類股': sub_category.text.strip(),
72
- '網址': "https://tw.stock.yahoo.com" + sub_category['href']
73
- })
74
-
75
- category_dict = {}
76
- for item in data:
77
- if item['台股'] not in category_dict:
78
- category_dict[item['台股']] = []
79
- category_dict[item['台股']].append({'類股': item['類股'], '網址': item['網址']})
80
 
81
- return category_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  except Exception as e:
83
  logging.error(f"獲取股票類別失敗: {str(e)}")
84
  return {}
85
 
86
- # 股票預測模型類別
87
  class StockPredictor:
88
  def __init__(self):
89
  self.model = None
90
  self.scaler = MinMaxScaler()
91
 
92
- def prepare_data(self, df, selected_features):
93
- scaled_data = self.scaler.fit_transform(df[selected_features])
 
94
 
95
  X, y = [], []
96
  for i in range(len(scaled_data) - 1):
97
  X.append(scaled_data[i])
98
- y.append(scaled_data[i+1])
99
 
100
- return np.array(X).reshape(-1, 1, len(selected_features)), np.array(y)
101
 
102
  def build_model(self, input_shape):
103
  model = Sequential([
@@ -105,13 +104,13 @@ class StockPredictor:
105
  Dropout(0.2),
106
  LSTM(50, activation='relu'),
107
  Dropout(0.2),
108
- Dense(input_shape[1])
109
  ])
110
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
111
  return model
112
 
113
- def train(self, df, selected_features):
114
- X, y = self.prepare_data(df, selected_features)
115
  self.model = self.build_model((1, X.shape[2]))
116
  history = self.model.fit(
117
  X, y,
@@ -130,97 +129,139 @@ class StockPredictor:
130
  next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
131
  predictions.append(next_day[0])
132
 
133
- current_data = next_day
 
 
 
134
 
135
  return np.array(predictions)
136
 
137
  # Gradio界面函數
138
- async def update_stocks(category):
139
  if not category or category not in category_dict:
140
  return []
141
  return [item['類股'] for item in category_dict[category]]
142
 
143
- async def get_stock_items(url):
144
  try:
145
- async with aiohttp.ClientSession() as session:
146
- async with session.get(url, headers=headers) as response:
147
- response_text = await response.text()
148
- soup = BeautifulSoup(response_text, 'html.parser')
149
- stock_items = soup.find_all('li', class_='List(n)')
150
-
151
- stocks_dict = {}
152
- for item in stock_items:
153
- stock_name = item.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
154
- stock_code = item.find('span', class_='Fz(14px) C(#979ba7) Ell')
155
- if stock_name and stock_code:
156
- full_code = stock_code.text.strip()
157
- display_code = full_code.split('.')[0]
158
- display_name = f"{stock_name.text.strip()}{display_code}"
159
- stocks_dict[display_name] = full_code
160
-
161
- return stocks_dict
162
  except Exception as e:
163
  logging.error(f"獲取股票項目失敗: {str(e)}")
164
  return {}
165
 
166
- async def predict_stock(category, stock, stock_item, period, selected_features):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  if not all([category, stock, stock_item]):
168
- return gr.update(value=None), "請選擇產業類別、類股和股票"
169
 
170
  try:
171
  url = next((item['網址'] for item in category_dict.get(category, [])
172
  if item['類股'] == stock), None)
173
  if not url:
174
- return gr.update(value=None), "無法獲取類股網址"
175
 
176
- stock_items = await get_stock_items(url)
177
  stock_code = stock_items.get(stock_item, "")
178
 
179
  if not stock_code:
180
- return gr.update(value=None), "無法獲取股票代碼"
181
 
182
- # 下載股票數據,根據用戶選擇的時間範圍
183
- df = yf.download(stock_code, period=period)
184
  if df.empty:
185
  raise ValueError("無法獲取股票數據")
186
 
187
  # 預測
188
  predictor = StockPredictor()
189
- predictor.train(df, selected_features)
190
 
191
- last_data = predictor.scaler.transform(df.iloc[-1:][selected_features])
192
  predictions = predictor.predict(last_data[0], 5)
193
 
 
 
 
 
 
 
 
194
  # 創建日期指標
195
  dates = [datetime.now() + timedelta(days=i) for i in range(6)]
196
  date_labels = [d.strftime('%m/%d') for d in dates]
197
 
198
- # 用 Plotly 繪圖
199
- fig = go.Figure()
200
- for i, feature in enumerate(selected_features):
201
- fig.add_trace(go.Scatter(
202
- x=date_labels,
203
- y=np.hstack([df[feature].iloc[-1], predictions[:, i]]),
204
- mode='lines+markers',
205
- name=f'預測{feature}'
206
- ))
207
 
208
- fig.update_layout(
209
- title=f'{stock_item} 股價預測 (未來5天)',
210
- xaxis_title='日期',
211
- yaxis_title='股價',
212
- template='plotly_dark'
213
- )
 
 
 
 
 
 
 
214
 
215
- return gr.update(value=pio.to_html(fig, full_html=False)), "預測成功"
 
216
 
217
  except Exception as e:
218
  logging.error(f"預測過程發生錯誤: {str(e)}")
219
- return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
220
 
221
  # 初始化
222
  setup_font()
223
- category_dict = asyncio.run(fetch_stock_categories())
224
  categories = list(category_dict.keys())
225
 
226
  # Gradio界面
@@ -243,41 +284,31 @@ with gr.Blocks() as demo:
243
  label="股票",
244
  value=None
245
  )
246
- period_dropdown = gr.Dropdown(
247
- choices=["1y", "6mo", "3mo", "1mo"],
248
- label="抓取時間範圍",
249
- value="1y"
250
- )
251
- features_checkbox = gr.CheckboxGroup(
252
- choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
253
- label="選擇要用於預測的特徵",
254
- value=['Open', 'Close']
255
- )
256
  predict_button = gr.Button("開始預測", variant="primary")
257
- status_output = gr.Textbox(label="狀態", interactive=False)
258
 
259
  with gr.Row():
260
- stock_plot = gr.HTML(label="股價預測圖")
261
 
262
  # 事件綁定
263
  category_dropdown.change(
264
- update_stocks,
265
  inputs=[category_dropdown],
266
- outputs=[stock_dropdown]
267
  )
268
 
269
  stock_dropdown.change(
270
- update_stocks,
271
- inputs=[category_dropdown],
272
- outputs=[stock_item_dropdown]
273
  )
274
 
275
  predict_button.click(
276
  predict_stock,
277
- inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox],
278
- outputs=[stock_plot, status_output]
279
  )
280
 
281
  # 啟動應用
282
  if __name__ == "__main__":
283
  demo.launch(share=False)
 
 
1
  import gradio as gr
 
 
2
  import requests
3
  from bs4 import BeautifulSoup
4
  import pandas as pd
 
7
  from tensorflow.keras.models import Sequential
8
  from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
10
+ import matplotlib.pyplot as plt
11
+ import io
 
 
 
 
 
12
  import matplotlib as mpl
13
  import matplotlib.font_manager as fm
14
+ import tempfile
15
+ import os
16
+ import yfinance as yf
17
+ import logging
18
+ from datetime import datetime, timedelta
19
 
20
  # 設置日志
21
  logging.basicConfig(level=logging.INFO,
 
47
  'Upgrade-Insecure-Requests': '1'
48
  }
49
 
50
+ def fetch_stock_categories():
51
  try:
52
  url = "https://tw.stock.yahoo.com/class/"
53
+ response = requests.get(url, headers=headers, timeout=10)
54
+ response.raise_for_status()
55
+
56
+ soup = BeautifulSoup(response.text, 'html.parser')
57
+ main_categories = soup.find_all('div', class_='C($c-link-text)')
58
+
59
+ data = []
60
+ for category in main_categories:
61
+ main_category_name = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
62
+ if main_category_name:
63
+ main_category_name = main_category_name.text.strip()
64
+ sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ for sub_category in sub_categories:
67
+ data.append({
68
+ '台股': main_category_name,
69
+ '類股': sub_category.text.strip(),
70
+ '網址': "https://tw.stock.yahoo.com" + sub_category['href']
71
+ })
72
+
73
+ category_dict = {}
74
+ for item in data:
75
+ if item['台股'] not in category_dict:
76
+ category_dict[item['台股']] = []
77
+ category_dict[item['台股']].append({'類股': item['類股'], '網址': item['網址']})
78
+
79
+ return category_dict
80
  except Exception as e:
81
  logging.error(f"獲取股票類別失敗: {str(e)}")
82
  return {}
83
 
84
+ # 股票預測模型類別保持不變...
85
  class StockPredictor:
86
  def __init__(self):
87
  self.model = None
88
  self.scaler = MinMaxScaler()
89
 
90
+ def prepare_data(self, df):
91
+ features = ['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume']
92
+ scaled_data = self.scaler.fit_transform(df[features])
93
 
94
  X, y = [], []
95
  for i in range(len(scaled_data) - 1):
96
  X.append(scaled_data[i])
97
+ y.append(scaled_data[i+1, [0, 3]]) # Open和Close的索引
98
 
99
+ return np.array(X).reshape(-1, 1, len(features)), np.array(y)
100
 
101
  def build_model(self, input_shape):
102
  model = Sequential([
 
104
  Dropout(0.2),
105
  LSTM(50, activation='relu'),
106
  Dropout(0.2),
107
+ Dense(2)
108
  ])
109
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
110
  return model
111
 
112
+ def train(self, df):
113
+ X, y = self.prepare_data(df)
114
  self.model = self.build_model((1, X.shape[2]))
115
  history = self.model.fit(
116
  X, y,
 
129
  next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
130
  predictions.append(next_day[0])
131
 
132
+ current_data = current_data.flatten()
133
+ current_data[0] = next_day[0][0]
134
+ current_data[3] = next_day[0][1]
135
+ current_data = current_data.reshape(1, -1)
136
 
137
  return np.array(predictions)
138
 
139
  # Gradio界面函數
140
+ def update_stocks(category):
141
  if not category or category not in category_dict:
142
  return []
143
  return [item['類股'] for item in category_dict[category]]
144
 
145
+ def get_stock_items(url):
146
  try:
147
+ response = requests.get(url, headers=headers, timeout=10)
148
+ response.raise_for_status()
149
+
150
+ soup = BeautifulSoup(response.text, 'html.parser')
151
+ stock_items = soup.find_all('li', class_='List(n)')
152
+
153
+ stocks_dict = {}
154
+ for item in stock_items:
155
+ stock_name = item.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
156
+ stock_code = item.find('span', class_='Fz(14px) C(#979ba7) Ell')
157
+ if stock_name and stock_code:
158
+ full_code = stock_code.text.strip()
159
+ display_code = full_code.split('.')[0]
160
+ display_name = f"{stock_name.text.strip()}{display_code}"
161
+ stocks_dict[display_name] = full_code
162
+
163
+ return stocks_dict
164
  except Exception as e:
165
  logging.error(f"獲取股票項目失敗: {str(e)}")
166
  return {}
167
 
168
+ def update_category(category):
169
+ stocks = update_stocks(category)
170
+ return {
171
+ stock_dropdown: gr.update(choices=stocks, value=None),
172
+ stock_item_dropdown: gr.update(choices=[], value=None),
173
+ stock_plot: gr.update(value=None)
174
+ }
175
+
176
+ def update_stock(category, stock):
177
+ if not category or not stock:
178
+ return {
179
+ stock_item_dropdown: gr.update(choices=[], value=None),
180
+ stock_plot: gr.update(value=None)
181
+ }
182
+
183
+ url = next((item['網址'] for item in category_dict.get(category, [])
184
+ if item['類股'] == stock), None)
185
+
186
+ if url:
187
+ stock_items = get_stock_items(url)
188
+ return {
189
+ stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
190
+ stock_plot: gr.update(value=None)
191
+ }
192
+ return {
193
+ stock_item_dropdown: gr.update(choices=[], value=None),
194
+ stock_plot: gr.update(value=None)
195
+ }
196
+
197
+ def predict_stock(category, stock, stock_item):
198
  if not all([category, stock, stock_item]):
199
+ return gr.update(value=None)
200
 
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
  if item['類股'] == stock), None)
204
  if not url:
205
+ return gr.update(value=None)
206
 
207
+ stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
 
210
  if not stock_code:
211
+ return gr.update(value=None)
212
 
213
+ # 下載股票數據
214
+ df = yf.download(stock_code, period="1y")
215
  if df.empty:
216
  raise ValueError("無法獲取股票數據")
217
 
218
  # 預測
219
  predictor = StockPredictor()
220
+ predictor.train(df)
221
 
222
+ last_data = predictor.scaler.transform(df.iloc[-1:][['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume']])
223
  predictions = predictor.predict(last_data[0], 5)
224
 
225
+ # 反轉預測結果
226
+ last_original = df[['Open', 'Close']].iloc[-1].values
227
+ predictions_original = predictor.scaler.inverse_transform(
228
+ np.hstack([predictions, np.zeros((predictions.shape[0], 4))])
229
+ )[:, :2]
230
+ all_predictions = np.vstack([last_original, predictions_original])
231
+
232
  # 創建日期指標
233
  dates = [datetime.now() + timedelta(days=i) for i in range(6)]
234
  date_labels = [d.strftime('%m/%d') for d in dates]
235
 
236
+ # 繪圖
237
+ fig, ax = plt.subplots(figsize=(14, 7))
238
+ colors = ['#FF9999', '#66B2FF']
239
+ labels = ['預測開盤價', '預測收盤價']
 
 
 
 
 
240
 
241
+ for i, (col, label, color) in enumerate(zip(['Open', 'Close'], labels, colors)):
242
+ ax.plot(date_labels, all_predictions[:, i], label=label,
243
+ marker='o', color=color, linewidth=2)
244
+ for j, value in enumerate(all_predictions[:, i]):
245
+ ax.annotate(f'{value:.2f}', (date_labels[j], value),
246
+ textcoords="offset points", xytext=(0,10),
247
+ ha='center', va='bottom')
248
+
249
+ ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
250
+ ax.set_xlabel('日期', labelpad=10)
251
+ ax.set_ylabel('股價', labelpad=10)
252
+ ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
253
+ ax.grid(True, linestyle='--', alpha=0.7)
254
 
255
+ plt.tight_layout()
256
+ return gr.update(value=fig)
257
 
258
  except Exception as e:
259
  logging.error(f"預測過程發生錯誤: {str(e)}")
260
+ return gr.update(value=None)
261
 
262
  # 初始化
263
  setup_font()
264
+ category_dict = fetch_stock_categories()
265
  categories = list(category_dict.keys())
266
 
267
  # Gradio界面
 
284
  label="股票",
285
  value=None
286
  )
 
 
 
 
 
 
 
 
 
 
287
  predict_button = gr.Button("開始預測", variant="primary")
 
288
 
289
  with gr.Row():
290
+ stock_plot = gr.Plot(label="股價預測圖")
291
 
292
  # 事件綁定
293
  category_dropdown.change(
294
+ update_category,
295
  inputs=[category_dropdown],
296
+ outputs=[stock_dropdown, stock_item_dropdown, stock_plot]
297
  )
298
 
299
  stock_dropdown.change(
300
+ update_stock,
301
+ inputs=[category_dropdown, stock_dropdown],
302
+ outputs=[stock_item_dropdown, stock_plot]
303
  )
304
 
305
  predict_button.click(
306
  predict_stock,
307
+ inputs=[category_dropdown, stock_dropdown, stock_item_dropdown],
308
+ outputs=[stock_plot]
309
  )
310
 
311
  # 啟動應用
312
  if __name__ == "__main__":
313
  demo.launch(share=False)
314
+