tbdavid2019 commited on
Commit
73cc4bb
·
1 Parent(s): 999c140
Files changed (1) hide show
  1. app.py +81 -100
app.py CHANGED
@@ -7,6 +7,7 @@ from sklearn.preprocessing import MinMaxScaler
7
  from tensorflow.keras.models import Sequential
8
  from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
 
10
  import matplotlib.pyplot as plt
11
  import io
12
  import matplotlib as mpl
@@ -16,9 +17,8 @@ import os
16
  import yfinance as yf
17
  import logging
18
  from datetime import datetime, timedelta
19
- from prophet import Prophet
20
 
21
- # 設置日誌
22
  logging.basicConfig(level=logging.INFO,
23
  format='%(asctime)s - %(levelname)s - %(message)s')
24
 
@@ -82,21 +82,21 @@ def fetch_stock_categories():
82
  logging.error(f"獲取股票類別失敗: {str(e)}")
83
  return {}
84
 
85
- # 股票預測模型類別
86
  class StockPredictor:
87
  def __init__(self):
88
  self.model = None
89
  self.scaler = MinMaxScaler()
90
 
91
- def prepare_data(self, df, selected_features):
92
- scaled_data = self.scaler.fit_transform(df[selected_features])
93
 
94
  X, y = [], []
95
  for i in range(len(scaled_data) - 1):
96
  X.append(scaled_data[i])
97
- y.append(scaled_data[i+1])
98
 
99
- return np.array(X).reshape(-1, 1, len(selected_features)), np.array(y)
100
 
101
  def build_model(self, input_shape):
102
  model = Sequential([
@@ -104,13 +104,13 @@ class StockPredictor:
104
  Dropout(0.2),
105
  LSTM(50, activation='relu'),
106
  Dropout(0.2),
107
- Dense(input_shape[1])
108
  ])
109
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
110
  return model
111
 
112
- def train(self, df, selected_features):
113
- X, y = self.prepare_data(df, selected_features)
114
  self.model = self.build_model((1, X.shape[2]))
115
  history = self.model.fit(
116
  X, y,
@@ -130,7 +130,8 @@ class StockPredictor:
130
  predictions.append(next_day[0])
131
 
132
  current_data = current_data.flatten()
133
- current_data[:len(next_day[0])] = next_day[0]
 
134
  current_data = current_data.reshape(1, -1)
135
 
136
  return np.array(predictions)
@@ -169,16 +170,14 @@ def update_category(category):
169
  return {
170
  stock_dropdown: gr.update(choices=stocks, value=None),
171
  stock_item_dropdown: gr.update(choices=[], value=None),
172
- stock_plot: gr.update(value=None),
173
- status_output: gr.update(value="")
174
  }
175
 
176
  def update_stock(category, stock):
177
  if not category or not stock:
178
  return {
179
  stock_item_dropdown: gr.update(choices=[], value=None),
180
- stock_plot: gr.update(value=None),
181
- status_output: gr.update(value="")
182
  }
183
 
184
  url = next((item['網址'] for item in category_dict.get(category, [])
@@ -188,105 +187,87 @@ def update_stock(category, stock):
188
  stock_items = get_stock_items(url)
189
  return {
190
  stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
191
- stock_plot: gr.update(value=None),
192
- status_output: gr.update(value="")
193
  }
194
  return {
195
  stock_item_dropdown: gr.update(choices=[], value=None),
196
- stock_plot: gr.update(value=None),
197
- status_output: gr.update(value="")
198
  }
199
 
200
- def predict_stock(category, stock, stock_item, period, selected_features, model_choice):
201
  if not all([category, stock, stock_item]):
202
- return gr.update(value=None), "請選擇產業類別、類股和股票"
203
 
204
  try:
205
  url = next((item['網址'] for item in category_dict.get(category, [])
206
  if item['類股'] == stock), None)
207
  if not url:
208
- return gr.update(value=None), "無法獲取類股網址"
209
 
210
  stock_items = get_stock_items(url)
211
  stock_code = stock_items.get(stock_item, "")
212
-
213
  if not stock_code:
214
- return gr.update(value=None), "無法獲取股票代碼"
215
-
216
- # 下載股票數據,根據用戶選擇的時間範圍
217
- df = yf.download(stock_code, period=period)
218
  if df.empty:
219
  raise ValueError("無法獲取股票數據")
220
 
221
- # 根據模型選擇進行預測
222
- if model_choice == "LSTM":
223
  predictor = StockPredictor()
224
- predictor.train(df, selected_features)
225
- last_data = predictor.scaler.transform(df[selected_features].iloc[-1:].values)
226
  predictions = predictor.predict(last_data[0], 5)
227
-
228
  # 反轉預測結果
229
- last_original = df[selected_features].iloc[-1].values
230
  predictions_original = predictor.scaler.inverse_transform(
231
- np.vstack([last_data, predictions])
232
- )
233
- all_predictions = np.vstack([last_original, predictions_original[1:]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
- # 創建日期索引
236
- dates = [datetime.now() + timedelta(days=i) for i in range(6)]
237
- date_labels = [d.strftime('%m/%d') for d in dates]
238
-
239
- # 繪圖
240
- fig, ax = plt.subplots(figsize=(14, 7))
241
- for i, feature in enumerate(selected_features):
242
- ax.plot(date_labels, all_predictions[:, i], label=f'預測{feature}', marker='o', linewidth=2)
243
- for j, value in enumerate(all_predictions[:, i]):
244
- ax.annotate(f'{value:.2f}', (date_labels[j], value),
245
- textcoords="offset points", xytext=(0,10),
246
- ha='center', va='bottom')
247
-
248
- ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
249
- ax.set_xlabel('日期', labelpad=10)
250
- ax.set_ylabel('股價', labelpad=10)
251
- ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
252
- ax.grid(True, linestyle='--', alpha=0.7)
253
- plt.tight_layout()
254
- return gr.update(value=fig), "預測成功"
255
-
256
- elif model_choice == "Prophet":
257
- if 'Close' not in selected_features:
258
- return gr.update(value=None), "Prophet 模型僅支持 'Close' 特徵"
259
-
260
- prophet_df = df.reset_index()[['Date', 'Close']]
261
- prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
262
-
263
- model = Prophet()
264
- model.fit(prophet_df)
265
-
266
- future = model.make_future_dataframe(periods=5)
267
- forecast = model.predict(future)
268
-
269
- # 取出日期和預測結果
270
- date_labels = forecast['ds'].tail(6).dt.strftime('%m/%d').tolist()
271
- predictions = forecast['yhat'].tail(6).values
272
-
273
- # 繪圖
274
- fig, ax = plt.subplots(figsize=(14, 7))
275
- ax.plot(date_labels, predictions, label="預測股價", marker='o', color='#FF9999', linewidth=2)
276
- ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
277
- ax.set_xlabel('日期', labelpad=10)
278
- ax.set_ylabel('股價', labelpad=10)
279
- ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
280
- ax.grid(True, linestyle='--', alpha=0.7)
281
- plt.tight_layout()
282
- return gr.update(value=fig), "預測成功"
283
-
284
- else:
285
- return gr.update(value=None), "未知的模型選擇"
286
-
287
  except Exception as e:
288
  logging.error(f"預測過程發生錯誤: {str(e)}")
289
- return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
290
 
291
  # 初始化
292
  setup_font()
@@ -314,43 +295,43 @@ with gr.Blocks() as demo:
314
  value=None
315
  )
316
  period_dropdown = gr.Dropdown(
317
- choices=["1y", "6mo", "3mo", "1mo"],
318
  label="抓取時間範圍",
319
  value="1y"
320
  )
321
- features_checkbox = gr.CheckboxGroup(
322
  choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
323
  label="選擇要用於預測的特徵",
324
  value=['Open', 'Close']
325
  )
326
- model_dropdown = gr.Dropdown(
327
  choices=["LSTM", "Prophet"],
328
  label="選擇預測模型",
329
  value="LSTM"
330
  )
331
  predict_button = gr.Button("開始預測", variant="primary")
332
- status_output = gr.Textbox(label="狀態", interactive=False)
333
 
334
  with gr.Row():
335
  stock_plot = gr.Plot(label="股價預測圖")
336
-
 
337
  # 事件綁定
338
  category_dropdown.change(
339
  update_category,
340
  inputs=[category_dropdown],
341
- outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
342
  )
343
-
344
  stock_dropdown.change(
345
  update_stock,
346
  inputs=[category_dropdown, stock_dropdown],
347
- outputs=[stock_item_dropdown, stock_plot, status_output]
348
  )
349
-
350
  predict_button.click(
351
  predict_stock,
352
- inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox, model_dropdown],
353
- outputs=[stock_plot, status_output]
354
  )
355
 
356
  # 啟動應用
 
7
  from tensorflow.keras.models import Sequential
8
  from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
10
+ from prophet import Prophet
11
  import matplotlib.pyplot as plt
12
  import io
13
  import matplotlib as mpl
 
17
  import yfinance as yf
18
  import logging
19
  from datetime import datetime, timedelta
 
20
 
21
+ # 設置日志
22
  logging.basicConfig(level=logging.INFO,
23
  format='%(asctime)s - %(levelname)s - %(message)s')
24
 
 
82
  logging.error(f"獲取股票類別失敗: {str(e)}")
83
  return {}
84
 
85
+ # 股票預測模型類別保持不變...
86
  class StockPredictor:
87
  def __init__(self):
88
  self.model = None
89
  self.scaler = MinMaxScaler()
90
 
91
+ def prepare_data(self, df, features):
92
+ scaled_data = self.scaler.fit_transform(df[features])
93
 
94
  X, y = [], []
95
  for i in range(len(scaled_data) - 1):
96
  X.append(scaled_data[i])
97
+ y.append(scaled_data[i+1, [0, 3]]) # Open和Close的索引
98
 
99
+ return np.array(X).reshape(-1, 1, len(features)), np.array(y)
100
 
101
  def build_model(self, input_shape):
102
  model = Sequential([
 
104
  Dropout(0.2),
105
  LSTM(50, activation='relu'),
106
  Dropout(0.2),
107
+ Dense(2)
108
  ])
109
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
110
  return model
111
 
112
+ def train(self, df, features):
113
+ X, y = self.prepare_data(df, features)
114
  self.model = self.build_model((1, X.shape[2]))
115
  history = self.model.fit(
116
  X, y,
 
130
  predictions.append(next_day[0])
131
 
132
  current_data = current_data.flatten()
133
+ current_data[0] = next_day[0][0]
134
+ current_data[3] = next_day[0][1]
135
  current_data = current_data.reshape(1, -1)
136
 
137
  return np.array(predictions)
 
170
  return {
171
  stock_dropdown: gr.update(choices=stocks, value=None),
172
  stock_item_dropdown: gr.update(choices=[], value=None),
173
+ stock_plot: gr.update(value=None)
 
174
  }
175
 
176
  def update_stock(category, stock):
177
  if not category or not stock:
178
  return {
179
  stock_item_dropdown: gr.update(choices=[], value=None),
180
+ stock_plot: gr.update(value=None)
 
181
  }
182
 
183
  url = next((item['網址'] for item in category_dict.get(category, [])
 
187
  stock_items = get_stock_items(url)
188
  return {
189
  stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
190
+ stock_plot: gr.update(value=None)
 
191
  }
192
  return {
193
  stock_item_dropdown: gr.update(choices=[], value=None),
194
+ stock_plot: gr.update(value=None)
 
195
  }
196
 
197
+ def predict_stock(category, stock, stock_item, features, model_type):
198
  if not all([category, stock, stock_item]):
199
+ return gr.update(value=None)
200
 
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
  if item['類股'] == stock), None)
204
  if not url:
205
+ return gr.update(value=None)
206
 
207
  stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
+
210
  if not stock_code:
211
+ return gr.update(value=None)
212
+
213
+ # 下載股票數據
214
+ df = yf.download(stock_code, period="1y")
215
  if df.empty:
216
  raise ValueError("無法獲取股票數據")
217
 
218
+ # 預測
219
+ if model_type == "LSTM":
220
  predictor = StockPredictor()
221
+ predictor.train(df, features)
222
+ last_data = predictor.scaler.transform(df.iloc[-1:][features])
223
  predictions = predictor.predict(last_data[0], 5)
224
+
225
  # 反轉預測結果
226
+ last_original = df[features].iloc[-1].values
227
  predictions_original = predictor.scaler.inverse_transform(
228
+ np.hstack([predictions, np.zeros((predictions.shape[0], len(features) - 2))])
229
+ )[:, :2]
230
+ all_predictions = np.vstack([last_original, predictions_original])
231
+
232
+ elif model_type == "Prophet":
233
+ prophet_df = df.reset_index()[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
234
+ m = Prophet()
235
+ m.fit(prophet_df)
236
+ future = m.make_future_dataframe(periods=5)
237
+ forecast = m.predict(future)
238
+ all_predictions = forecast[['ds', 'yhat']].tail(6)
239
+ date_labels = all_predictions['ds'].dt.strftime('%m/%d').tolist()
240
+ all_predictions = all_predictions['yhat'].values
241
+
242
+ # 創建日期索引
243
+ dates = [datetime.now() + timedelta(days=i) for i in range(6)]
244
+ date_labels = [d.strftime('%m/%d') for d in dates]
245
+
246
+ # 繪圖
247
+ fig, ax = plt.subplots(figsize=(14, 7))
248
+ colors = ['#FF9999', '#66B2FF']
249
+ labels = ['預測開盤價', '預測收盤價']
250
+
251
+ for i, (label, color) in enumerate(zip(labels, colors)):
252
+ ax.plot(date_labels, all_predictions if model_type == "Prophet" else all_predictions[:, i],
253
+ label=label, marker='o', color=color, linewidth=2)
254
+ for j, value in enumerate(all_predictions if model_type == "Prophet" else all_predictions[:, i]):
255
+ ax.annotate(f'{value:.2f}', (date_labels[j], value),
256
+ textcoords="offset points", xytext=(0,10),
257
+ ha='center', va='bottom')
258
+
259
+ ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
260
+ ax.set_xlabel('日期', labelpad=10)
261
+ ax.set_ylabel('股價', labelpad=10)
262
+ ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
263
+ ax.grid(True, linestyle='--', alpha=0.7)
264
+
265
+ plt.tight_layout()
266
+ return gr.update(value=fig)
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  except Exception as e:
269
  logging.error(f"預測過程發生錯誤: {str(e)}")
270
+ return gr.update(value=None)
271
 
272
  # 初始化
273
  setup_font()
 
295
  value=None
296
  )
297
  period_dropdown = gr.Dropdown(
298
+ choices=["1mo", "3mo", "6mo", "1y"],
299
  label="抓取時間範圍",
300
  value="1y"
301
  )
302
+ features_checkboxes = gr.CheckboxGroup(
303
  choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
304
  label="選擇要用於預測的特徵",
305
  value=['Open', 'Close']
306
  )
307
+ model_type_dropdown = gr.Dropdown(
308
  choices=["LSTM", "Prophet"],
309
  label="選擇預測模型",
310
  value="LSTM"
311
  )
312
  predict_button = gr.Button("開始預測", variant="primary")
 
313
 
314
  with gr.Row():
315
  stock_plot = gr.Plot(label="股價預測圖")
316
+ status_textbox = gr.Textbox(label="狀態", value="")
317
+
318
  # 事件綁定
319
  category_dropdown.change(
320
  update_category,
321
  inputs=[category_dropdown],
322
+ outputs=[stock_dropdown, stock_item_dropdown, stock_plot]
323
  )
324
+
325
  stock_dropdown.change(
326
  update_stock,
327
  inputs=[category_dropdown, stock_dropdown],
328
+ outputs=[stock_item_dropdown, stock_plot]
329
  )
330
+
331
  predict_button.click(
332
  predict_stock,
333
+ inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, features_checkboxes, model_type_dropdown],
334
+ outputs=[stock_plot, status_textbox]
335
  )
336
 
337
  # 啟動應用