tbdavid2019 committed on
Commit
359bd0c
·
1 Parent(s): f7c1877
Files changed (1) hide show
  1. app.py +86 -100
app.py CHANGED
@@ -7,7 +7,6 @@ from sklearn.preprocessing import MinMaxScaler
7
  from tensorflow.keras.models import Sequential
8
  from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
10
- from prophet import Prophet
11
  import matplotlib.pyplot as plt
12
  import io
13
  import matplotlib as mpl
@@ -17,21 +16,20 @@ import os
17
  import yfinance as yf
18
  import logging
19
  from datetime import datetime, timedelta
 
20
 
21
- # 設置日志
22
  logging.basicConfig(level=logging.INFO,
23
- format='%(asctime)s - %(levelname)s - %(message)s')
24
 
25
  # 字體設置
26
  def setup_font():
27
  try:
28
  url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
29
  response_font = requests.get(url_font)
30
-
31
  with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
32
  tmp_file.write(response_font.content)
33
  tmp_file_path = tmp_file.name
34
-
35
  fm.fontManager.addfont(tmp_file_path)
36
  mpl.rc('font', family='Taipei Sans TC Beta')
37
  except Exception as e:
@@ -53,50 +51,44 @@ def fetch_stock_categories():
53
  url = "https://tw.stock.yahoo.com/class/"
54
  response = requests.get(url, headers=headers, timeout=10)
55
  response.raise_for_status()
56
-
57
  soup = BeautifulSoup(response.text, 'html.parser')
58
  main_categories = soup.find_all('div', class_='C($c-link-text)')
59
-
60
  data = []
61
  for category in main_categories:
62
  main_category_name = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
63
  if main_category_name:
64
  main_category_name = main_category_name.text.strip()
65
  sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
66
-
67
  for sub_category in sub_categories:
68
  data.append({
69
  '台股': main_category_name,
70
  '類股': sub_category.text.strip(),
71
  '網址': "https://tw.stock.yahoo.com" + sub_category['href']
72
  })
73
-
74
  category_dict = {}
75
  for item in data:
76
  if item['台股'] not in category_dict:
77
  category_dict[item['台股']] = []
78
  category_dict[item['台股']].append({'類股': item['類股'], '網址': item['網址']})
79
-
80
  return category_dict
81
  except Exception as e:
82
  logging.error(f"獲取股票類別失敗: {str(e)}")
83
  return {}
84
 
85
- # 股票預測模型類別保持不變...
86
  class StockPredictor:
87
  def __init__(self):
88
- self.model = None
 
89
  self.scaler = MinMaxScaler()
90
-
91
- def prepare_data(self, df, features):
92
- scaled_data = self.scaler.fit_transform(df[features])
93
-
94
  X, y = [], []
95
  for i in range(len(scaled_data) - 1):
96
  X.append(scaled_data[i])
97
- y.append(scaled_data[i+1, [0, 3]]) # Open和Close的索引
98
-
99
- return np.array(X).reshape(-1, 1, len(features)), np.array(y)
100
 
101
  def build_model(self, input_shape):
102
  model = Sequential([
@@ -104,15 +96,15 @@ class StockPredictor:
104
  Dropout(0.2),
105
  LSTM(50, activation='relu'),
106
  Dropout(0.2),
107
- Dense(2)
108
  ])
109
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
110
  return model
111
 
112
- def train(self, df, features):
113
- X, y = self.prepare_data(df, features)
114
- self.model = self.build_model((1, X.shape[2]))
115
- history = self.model.fit(
116
  X, y,
117
  epochs=50,
118
  batch_size=32,
@@ -124,17 +116,26 @@ class StockPredictor:
124
  def predict(self, last_data, n_days):
125
  predictions = []
126
  current_data = last_data.copy()
127
-
128
  for _ in range(n_days):
129
- next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
130
  predictions.append(next_day[0])
131
-
132
  current_data = current_data.flatten()
133
- current_data[0] = next_day[0][0]
134
- current_data[3] = next_day[0][1]
135
  current_data = current_data.reshape(1, -1)
136
-
137
  return np.array(predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  # Gradio界面函數
140
  def update_stocks(category):
@@ -146,10 +147,8 @@ def get_stock_items(url):
146
  try:
147
  response = requests.get(url, headers=headers, timeout=10)
148
  response.raise_for_status()
149
-
150
  soup = BeautifulSoup(response.text, 'html.parser')
151
  stock_items = soup.find_all('li', class_='List(n)')
152
-
153
  stocks_dict = {}
154
  for item in stock_items:
155
  stock_name = item.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
@@ -159,7 +158,6 @@ def get_stock_items(url):
159
  display_code = full_code.split('.')[0]
160
  display_name = f"{stock_name.text.strip()}{display_code}"
161
  stocks_dict[display_name] = full_code
162
-
163
  return stocks_dict
164
  except Exception as e:
165
  logging.error(f"獲取股票項目失敗: {str(e)}")
@@ -170,108 +168,99 @@ def update_category(category):
170
  return {
171
  stock_dropdown: gr.update(choices=stocks, value=None),
172
  stock_item_dropdown: gr.update(choices=[], value=None),
173
- stock_plot: gr.update(value=None)
 
174
  }
175
 
176
  def update_stock(category, stock):
177
  if not category or not stock:
178
  return {
179
  stock_item_dropdown: gr.update(choices=[], value=None),
180
- stock_plot: gr.update(value=None)
 
181
  }
182
-
183
  url = next((item['網址'] for item in category_dict.get(category, [])
184
  if item['類股'] == stock), None)
185
-
186
  if url:
187
  stock_items = get_stock_items(url)
188
  return {
189
  stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
190
- stock_plot: gr.update(value=None)
 
191
  }
192
  return {
193
  stock_item_dropdown: gr.update(choices=[], value=None),
194
- stock_plot: gr.update(value=None)
 
195
  }
196
 
197
- def predict_stock(category, stock, stock_item, period, features, model_type):
198
  if not all([category, stock, stock_item]):
199
- return gr.update(value=None), "請選擇完整的選項"
200
-
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
- if item['類股'] == stock), None)
204
  if not url:
205
- return gr.update(value=None), "無法找到該類股的網址"
206
-
207
  stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
-
210
  if not stock_code:
211
- return gr.update(value=None), "無法找到該股票的代碼"
212
 
213
- # 下載股票數據
214
  df = yf.download(stock_code, period=period)
215
  if df.empty:
216
  raise ValueError("無法獲取股票數據")
217
 
218
  # 預測
 
219
  if model_type == "LSTM":
220
- predictor = StockPredictor()
221
- predictor.train(df, features)
222
- last_data = predictor.scaler.transform(df.iloc[-1:][features])
223
  predictions = predictor.predict(last_data[0], 5)
224
 
225
  # 反轉預測結果
226
- last_original = df[features].iloc[-1].values
227
  predictions_original = predictor.scaler.inverse_transform(
228
- np.hstack([predictions, np.zeros((predictions.shape[0], len(features) - 2))])
229
- )[:, :2]
230
- all_predictions = np.vstack([last_original, predictions_original])
231
-
232
  elif model_type == "Prophet":
233
- prophet_df = df.reset_index()[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
234
- m = Prophet()
235
- m.fit(prophet_df)
236
- future = m.make_future_dataframe(periods=5)
237
- forecast = m.predict(future)
238
- all_predictions = forecast[['ds', 'yhat']].tail(6)
239
- date_labels = all_predictions['ds'].dt.strftime('%m/%d').tolist()
240
- all_predictions = all_predictions['yhat'].values
241
-
242
  # 創建日期索引
243
  dates = [datetime.now() + timedelta(days=i) for i in range(6)]
244
  date_labels = [d.strftime('%m/%d') for d in dates]
245
-
246
  # 繪圖
247
  fig, ax = plt.subplots(figsize=(14, 7))
248
- colors = ['#FF9999', '#66B2FF']
249
- labels = ['預測開盤價', '預測收盤價']
250
-
251
- for i, (label, color) in enumerate(labels):
252
- if model_type == "Prophet":
253
- ax.plot(date_labels, all_predictions, label='預測收盤價', marker='o', color=colors[1], linewidth=2)
254
- for j, value in enumerate(all_predictions):
255
- ax.annotate(f'{value:.2f}', (date_labels[j], value),
256
- textcoords="offset points", xytext=(0,10),
257
- ha='center', va='bottom')
258
- break
259
- else:
260
- ax.plot(date_labels, all_predictions[:, i], label=label, marker='o', color=color, linewidth=2)
261
  for j, value in enumerate(all_predictions[:, i]):
262
  ax.annotate(f'{value:.2f}', (date_labels[j], value),
263
- textcoords="offset points", xytext=(0,10),
264
- ha='center', va='bottom')
265
-
 
 
 
 
 
 
 
266
  ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
267
  ax.set_xlabel('日期', labelpad=10)
268
  ax.set_ylabel('股價', labelpad=10)
269
  ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
270
  ax.grid(True, linestyle='--', alpha=0.7)
271
-
272
  plt.tight_layout()
273
  return gr.update(value=fig), "預測成功"
274
-
275
  except Exception as e:
276
  logging.error(f"預測過程發生錯誤: {str(e)}")
277
  return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
@@ -302,45 +291,42 @@ with gr.Blocks() as demo:
302
  value=None
303
  )
304
  period_dropdown = gr.Dropdown(
305
- choices=["1mo", "3mo", "6mo", "1y"],
306
  label="抓取時間範圍",
307
  value="1y"
308
  )
309
- features_checkboxes = gr.CheckboxGroup(
310
  choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
311
  label="選擇要用於預測的特徵",
312
  value=['Open', 'Close']
313
  )
314
- model_type_dropdown = gr.Dropdown(
315
  choices=["LSTM", "Prophet"],
316
- label="選擇預測模型",
317
  value="LSTM"
318
  )
319
  predict_button = gr.Button("開始預測", variant="primary")
320
-
321
- with gr.Row():
322
- stock_plot = gr.Plot(label="股價預測圖")
323
- status_textbox = gr.Textbox(label="狀態", value="")
324
-
325
  # 事件綁定
326
  category_dropdown.change(
327
  update_category,
328
  inputs=[category_dropdown],
329
- outputs=[stock_dropdown, stock_item_dropdown, stock_plot]
330
  )
331
-
332
  stock_dropdown.change(
333
  update_stock,
334
  inputs=[category_dropdown, stock_dropdown],
335
- outputs=[stock_item_dropdown, stock_plot]
336
  )
337
-
338
  predict_button.click(
339
  predict_stock,
340
- inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkboxes, model_type_dropdown],
341
- outputs=[stock_plot, status_textbox]
342
  )
343
 
344
  # 啟動應用
345
  if __name__ == "__main__":
346
- demo.launch(share=False)
 
7
  from tensorflow.keras.models import Sequential
8
  from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
 
10
  import matplotlib.pyplot as plt
11
  import io
12
  import matplotlib as mpl
 
16
  import yfinance as yf
17
  import logging
18
  from datetime import datetime, timedelta
19
+ from fbprophet import Prophet
20
 
21
+ # 設置日誌
22
  logging.basicConfig(level=logging.INFO,
23
+ format='%(asctime)s - %(levelname)s - %(message)s')
24
 
25
  # 字體設置
26
  def setup_font():
27
  try:
28
  url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
29
  response_font = requests.get(url_font)
 
30
  with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
31
  tmp_file.write(response_font.content)
32
  tmp_file_path = tmp_file.name
 
33
  fm.fontManager.addfont(tmp_file_path)
34
  mpl.rc('font', family='Taipei Sans TC Beta')
35
  except Exception as e:
 
51
  url = "https://tw.stock.yahoo.com/class/"
52
  response = requests.get(url, headers=headers, timeout=10)
53
  response.raise_for_status()
 
54
  soup = BeautifulSoup(response.text, 'html.parser')
55
  main_categories = soup.find_all('div', class_='C($c-link-text)')
 
56
  data = []
57
  for category in main_categories:
58
  main_category_name = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
59
  if main_category_name:
60
  main_category_name = main_category_name.text.strip()
61
  sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
 
62
  for sub_category in sub_categories:
63
  data.append({
64
  '台股': main_category_name,
65
  '類股': sub_category.text.strip(),
66
  '網址': "https://tw.stock.yahoo.com" + sub_category['href']
67
  })
 
68
  category_dict = {}
69
  for item in data:
70
  if item['台股'] not in category_dict:
71
  category_dict[item['台股']] = []
72
  category_dict[item['台股']].append({'類股': item['類股'], '網址': item['網址']})
 
73
  return category_dict
74
  except Exception as e:
75
  logging.error(f"獲取股票類別失敗: {str(e)}")
76
  return {}
77
 
78
+ # 股票預測模型類別
79
  class StockPredictor:
80
  def __init__(self):
81
+ self.lstm_model = None
82
+ self.prophet_model = None
83
  self.scaler = MinMaxScaler()
84
+
85
def prepare_data(self, df, selected_features):
    """Scale the selected columns and build one-step-ahead training pairs.

    The scaler stored on ``self.scaler`` is (re)fitted on the selected
    columns. Every scaled row except the last becomes one sample, and the
    row that follows it becomes that sample's target.

    Args:
        df: DataFrame containing the columns named in ``selected_features``.
        selected_features: Column names to scale and learn from.

    Returns:
        A ``(X, y)`` tuple. ``X`` has shape ``(rows - 1, 1, n_features)``
        and ``y`` has shape ``(rows - 1, n_features)``.

    Raises:
        ValueError: If ``df`` has fewer than two rows, so no
            (sample, next-row) pair can be formed.
    """
    scaled = np.asarray(self.scaler.fit_transform(df[selected_features]))
    if len(scaled) < 2:
        # Fail early with a readable message instead of handing empty
        # arrays to the model.
        raise ValueError("資料筆數不足:至少需要兩筆資料才能訓練模型")

    # Row i is the input, row i + 1 is the target.
    samples = scaled[:-1].reshape(-1, 1, len(selected_features))
    targets = scaled[1:]
    return samples, targets
 
92
 
93
  def build_model(self, input_shape):
94
  model = Sequential([
 
96
  Dropout(0.2),
97
  LSTM(50, activation='relu'),
98
  Dropout(0.2),
99
+ Dense(input_shape[1])
100
  ])
101
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
102
  return model
103
 
104
+ def train(self, df, selected_features):
105
+ X, y = self.prepare_data(df, selected_features)
106
+ self.lstm_model = self.build_model((1, X.shape[2]))
107
+ history = self.lstm_model.fit(
108
  X, y,
109
  epochs=50,
110
  batch_size=32,
 
116
  def predict(self, last_data, n_days):
117
  predictions = []
118
  current_data = last_data.copy()
 
119
  for _ in range(n_days):
120
+ next_day = self.lstm_model.predict(current_data.reshape(1, 1, -1), verbose=0)
121
  predictions.append(next_day[0])
 
122
  current_data = current_data.flatten()
123
+ current_data[:len(next_day[0])] = next_day[0]
 
124
  current_data = current_data.reshape(1, -1)
 
125
  return np.array(predictions)
126
+
127
def train_prophet(self, df, target_column='Close'):
    """Fit a fresh Prophet model on one column of ``df``.

    The index of ``df`` is reset, and the resulting ``'Date'`` column and
    ``target_column`` are renamed to ``'ds'`` and ``'y'`` before fitting.
    The model is stored on ``self.prophet_model``.

    Args:
        df: DataFrame whose reset index yields a ``'Date'`` column.
        target_column: Name of the column to forecast.
    """
    column_map = {'Date': 'ds', target_column: 'y'}
    history = df.reset_index()
    history = history[['Date', target_column]].rename(columns=column_map)

    model = Prophet()
    self.prophet_model = model
    model.fit(history)
131
+
132
def predict_prophet(self, df, days=5):
    """Forecast ``days`` future periods with the trained Prophet model.

    Args:
        df: Not read by this method; kept for the existing call signature.
        days: Number of future periods to request and return.

    Returns:
        The last ``days`` rows of the forecast, restricted to the
        ``'ds'`` and ``'yhat'`` columns.

    Raises:
        ValueError: If ``train_prophet`` has not been called yet.
    """
    model = self.prophet_model
    if model is None:
        raise ValueError("Prophet model has not been trained yet.")

    forecast = model.predict(model.make_future_dataframe(periods=days))
    return forecast.loc[:, ['ds', 'yhat']].tail(days)
139
 
140
  # Gradio界面函數
141
  def update_stocks(category):
 
147
  try:
148
  response = requests.get(url, headers=headers, timeout=10)
149
  response.raise_for_status()
 
150
  soup = BeautifulSoup(response.text, 'html.parser')
151
  stock_items = soup.find_all('li', class_='List(n)')
 
152
  stocks_dict = {}
153
  for item in stock_items:
154
  stock_name = item.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
 
158
  display_code = full_code.split('.')[0]
159
  display_name = f"{stock_name.text.strip()}{display_code}"
160
  stocks_dict[display_name] = full_code
 
161
  return stocks_dict
162
  except Exception as e:
163
  logging.error(f"獲取股票項目失敗: {str(e)}")
 
168
  return {
169
  stock_dropdown: gr.update(choices=stocks, value=None),
170
  stock_item_dropdown: gr.update(choices=[], value=None),
171
+ stock_plot: gr.update(value=None),
172
+ status_output: gr.update(value="")
173
  }
174
 
175
  def update_stock(category, stock):
176
  if not category or not stock:
177
  return {
178
  stock_item_dropdown: gr.update(choices=[], value=None),
179
+ stock_plot: gr.update(value=None),
180
+ status_output: gr.update(value="")
181
  }
 
182
  url = next((item['網址'] for item in category_dict.get(category, [])
183
  if item['類股'] == stock), None)
 
184
  if url:
185
  stock_items = get_stock_items(url)
186
  return {
187
  stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
188
+ stock_plot: gr.update(value=None),
189
+ status_output: gr.update(value="")
190
  }
191
  return {
192
  stock_item_dropdown: gr.update(choices=[], value=None),
193
+ stock_plot: gr.update(value=None),
194
+ status_output: gr.update(value="")
195
  }
196
 
197
def predict_stock(category, stock, stock_item, period, selected_features, model_type):
    """Forecast the next five days for one stock and plot the result.

    Args:
        category: Selected top-level category name.
        stock: Selected sub-category name.
        stock_item: Display name of the stock chosen in the dropdown.
        period: History window passed to ``yf.download`` (e.g. ``"1y"``).
        selected_features: Column names to model. The LSTM path uses all of
            them; the Prophet path forecasts only the first one.
        model_type: ``"LSTM"`` or ``"Prophet"``.

    Returns:
        ``(plot_update, status_message)``. The plot update carries the
        figure on success and ``None`` otherwise; the message describes the
        outcome.
    """
    if not all([category, stock, stock_item]):
        return gr.update(value=None), "請選擇產業類別、類股和股票"
    # Both model paths index into selected_features, so an empty selection
    # cannot be forecast.
    if not selected_features:
        return gr.update(value=None), "請至少選擇一個預測特徵"

    palette = ['#FF9999', '#66B2FF', '#99CC99', '#FFCC66', '#C299FF', '#66CCCC']
    fig = None
    try:
        url = next((item['網址'] for item in category_dict.get(category, [])
                    if item['類股'] == stock), None)
        if not url:
            return gr.update(value=None), "無法獲取類股網址"

        stock_items = get_stock_items(url)
        stock_code = stock_items.get(stock_item, "")
        if not stock_code:
            return gr.update(value=None), "無法獲取股票代碼"

        # Download price history for the user-selected time range.
        df = yf.download(stock_code, period=period)
        if df.empty:
            raise ValueError("無法獲取股票數據")

        predictor = StockPredictor()
        if model_type == "LSTM":
            predictor.train(df, selected_features)
            last_data = predictor.scaler.transform(df[selected_features].iloc[-1:].values)
            predictions = predictor.predict(last_data[0], 5)

            # Undo the scaling, then put the last observed row in front so
            # the chart starts from real data.
            last_original = df[selected_features].iloc[-1].values
            predictions_original = predictor.scaler.inverse_transform(
                np.vstack([last_data, predictions])
            )
            all_predictions = np.vstack([last_original, predictions_original[1:]])

            dates = [datetime.now() + timedelta(days=i)
                     for i in range(len(all_predictions))]
            date_labels = [d.strftime('%m/%d') for d in dates]
            # One line per selected feature; the palette wraps around so no
            # feature is silently dropped.
            series = [
                (f'預測{feature}', palette[i % len(palette)], all_predictions[:, i])
                for i, feature in enumerate(selected_features)
            ]
        elif model_type == "Prophet":
            # Prophet forecasts a single column: the first selected feature.
            predictor.train_prophet(df, target_column=selected_features[0])
            predictions = predictor.predict_prophet(df, days=5)
            all_predictions = predictions['yhat'].values
            # Label each point with the forecast's own date so x and y
            # always have the same length.
            date_labels = predictions['ds'].dt.strftime('%m/%d').tolist()
            series = [('預測', palette[0], all_predictions)]
        else:
            return gr.update(value=None), f"不支援的模型類型: {model_type}"

        # Plot every series and annotate each point with its value.
        fig, ax = plt.subplots(figsize=(14, 7))
        for label, color, values in series:
            ax.plot(date_labels, values, label=label,
                    marker='o', color=color, linewidth=2)
            for date_label, value in zip(date_labels, values):
                ax.annotate(f'{value:.2f}', (date_label, value),
                            textcoords="offset points", xytext=(0, 10),
                            ha='center', va='bottom')

        ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
        ax.set_xlabel('日期', labelpad=10)
        ax.set_ylabel('股價', labelpad=10)
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
        ax.grid(True, linestyle='--', alpha=0.7)

        plt.tight_layout()
        return gr.update(value=fig), "預測成功"

    except Exception as e:
        # A figure created before the failure would otherwise stay
        # registered with pyplot.
        if fig is not None:
            plt.close(fig)
        logging.exception("預測過程發生錯誤: %s", e)
        return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
 
291
  value=None
292
  )
293
  period_dropdown = gr.Dropdown(
294
+ choices=["1y", "6mo", "3mo", "1mo"],
295
  label="抓取時間範圍",
296
  value="1y"
297
  )
298
+ features_checkbox = gr.CheckboxGroup(
299
  choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
300
  label="選擇要用於預測的特徵",
301
  value=['Open', 'Close']
302
  )
303
+ model_type_radio = gr.Radio(
304
  choices=["LSTM", "Prophet"],
305
+ label="選擇模型類型",
306
  value="LSTM"
307
  )
308
  predict_button = gr.Button("開始預測", variant="primary")
309
+ status_output = gr.Textbox(label="狀態", interactive=False)
310
+ with gr.Row():
311
+ stock_plot = gr.Plot(label="股價預測圖")
312
+
 
313
  # 事件綁定
314
  category_dropdown.change(
315
  update_category,
316
  inputs=[category_dropdown],
317
+ outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
318
  )
 
319
  stock_dropdown.change(
320
  update_stock,
321
  inputs=[category_dropdown, stock_dropdown],
322
+ outputs=[stock_item_dropdown, stock_plot, status_output]
323
  )
 
324
  predict_button.click(
325
  predict_stock,
326
+ inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox, model_type_radio],
327
+ outputs=[stock_plot, status_output]
328
  )
329
 
330
  # 啟動應用
331
  if __name__ == "__main__":
332
+ demo.launch(share=False)