tbdavid2019 commited on
Commit
851795f
·
1 Parent(s): 40f675a
Files changed (1) hide show
  1. app.py +64 -63
app.py CHANGED
@@ -7,15 +7,15 @@ from sklearn.preprocessing import MinMaxScaler
7
  from tensorflow.keras.models import Sequential
8
  from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
10
- import matplotlib.pyplot as plt
11
- import io
12
- import matplotlib as mpl
13
- import matplotlib.font_manager as fm
14
- import tempfile
15
- import os
16
  import yfinance as yf
17
  import logging
18
- from datetime import datetime, timedelta
 
 
 
19
 
20
  # 設置日志
21
  logging.basicConfig(level=logging.INFO,
@@ -81,22 +81,21 @@ def fetch_stock_categories():
81
  logging.error(f"獲取股票類別失敗: {str(e)}")
82
  return {}
83
 
84
- # 股票預測模型類別保持不變...
85
  class StockPredictor:
86
  def __init__(self):
87
  self.model = None
88
  self.scaler = MinMaxScaler()
89
 
90
- def prepare_data(self, df):
91
- features = ['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume']
92
- scaled_data = self.scaler.fit_transform(df[features])
93
 
94
  X, y = [], []
95
  for i in range(len(scaled_data) - 1):
96
  X.append(scaled_data[i])
97
- y.append(scaled_data[i+1, [0, 3]]) # Open和Close的索引
98
 
99
- return np.array(X).reshape(-1, 1, len(features)), np.array(y)
100
 
101
  def build_model(self, input_shape):
102
  model = Sequential([
@@ -104,13 +103,13 @@ class StockPredictor:
104
  Dropout(0.2),
105
  LSTM(50, activation='relu'),
106
  Dropout(0.2),
107
- Dense(2)
108
  ])
109
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
110
  return model
111
 
112
- def train(self, df):
113
- X, y = self.prepare_data(df)
114
  self.model = self.build_model((1, X.shape[2]))
115
  history = self.model.fit(
116
  X, y,
@@ -129,10 +128,7 @@ class StockPredictor:
129
  next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
130
  predictions.append(next_day[0])
131
 
132
- current_data = current_data.flatten()
133
- current_data[0] = next_day[0][0]
134
- current_data[3] = next_day[0][1]
135
- current_data = current_data.reshape(1, -1)
136
 
137
  return np.array(predictions)
138
 
@@ -170,14 +166,16 @@ def update_category(category):
170
  return {
171
  stock_dropdown: gr.update(choices=stocks, value=None),
172
  stock_item_dropdown: gr.update(choices=[], value=None),
173
- stock_plot: gr.update(value=None)
 
174
  }
175
 
176
  def update_stock(category, stock):
177
  if not category or not stock:
178
  return {
179
  stock_item_dropdown: gr.update(choices=[], value=None),
180
- stock_plot: gr.update(value=None)
 
181
  }
182
 
183
  url = next((item['網址'] for item in category_dict.get(category, [])
@@ -187,77 +185,69 @@ def update_stock(category, stock):
187
  stock_items = get_stock_items(url)
188
  return {
189
  stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
190
- stock_plot: gr.update(value=None)
 
191
  }
192
  return {
193
  stock_item_dropdown: gr.update(choices=[], value=None),
194
- stock_plot: gr.update(value=None)
 
195
  }
196
 
197
- def predict_stock(category, stock, stock_item):
198
  if not all([category, stock, stock_item]):
199
- return gr.update(value=None)
200
 
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
  if item['類股'] == stock), None)
204
  if not url:
205
- return gr.update(value=None)
206
 
207
  stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
 
210
  if not stock_code:
211
- return gr.update(value=None)
212
 
213
- # 下載股票數據
214
  df = yf.download(stock_code, period="1y")
215
  if df.empty:
216
  raise ValueError("無法獲取股票數據")
217
 
218
  # 預測
219
  predictor = StockPredictor()
220
- predictor.train(df)
221
 
222
- last_data = predictor.scaler.transform(df.iloc[-1:][['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume']])
223
  predictions = predictor.predict(last_data[0], 5)
224
 
225
- # 反轉預測結果
226
- last_original = df[['Open', 'Close']].iloc[-1].values
227
- predictions_original = predictor.scaler.inverse_transform(
228
- np.hstack([predictions, np.zeros((predictions.shape[0], 4))])
229
- )[:, :2]
230
- all_predictions = np.vstack([last_original, predictions_original])
231
-
232
  # 創建日期指標
233
  dates = [datetime.now() + timedelta(days=i) for i in range(6)]
234
  date_labels = [d.strftime('%m/%d') for d in dates]
235
 
236
- # 繪圖
237
- fig, ax = plt.subplots(figsize=(14, 7))
238
- colors = ['#FF9999', '#66B2FF']
239
- labels = ['預測開盤價', '預測收盤價']
240
-
241
- for i, (col, label, color) in enumerate(zip(['Open', 'Close'], labels, colors)):
242
- ax.plot(date_labels, all_predictions[:, i], label=label,
243
- marker='o', color=color, linewidth=2)
244
- for j, value in enumerate(all_predictions[:, i]):
245
- ax.annotate(f'{value:.2f}', (date_labels[j], value),
246
- textcoords="offset points", xytext=(0,10),
247
- ha='center', va='bottom')
248
 
249
- ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
250
- ax.set_xlabel('日期', labelpad=10)
251
- ax.set_ylabel('股價', labelpad=10)
252
- ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
253
- ax.grid(True, linestyle='--', alpha=0.7)
 
254
 
255
- plt.tight_layout()
256
- return gr.update(value=fig)
257
 
258
  except Exception as e:
259
  logging.error(f"預測過程發生錯誤: {str(e)}")
260
- return gr.update(value=None)
261
 
262
  # 初始化
263
  setup_font()
@@ -284,28 +274,39 @@ with gr.Blocks() as demo:
284
  label="股票",
285
  value=None
286
  )
 
 
 
 
 
 
 
 
 
 
287
  predict_button = gr.Button("開始預測", variant="primary")
 
288
 
289
  with gr.Row():
290
- stock_plot = gr.Plot(label="股價預測圖")
291
 
292
  # 事件綁定
293
  category_dropdown.change(
294
  update_category,
295
  inputs=[category_dropdown],
296
- outputs=[stock_dropdown, stock_item_dropdown, stock_plot]
297
  )
298
 
299
  stock_dropdown.change(
300
  update_stock,
301
  inputs=[category_dropdown, stock_dropdown],
302
- outputs=[stock_item_dropdown, stock_plot]
303
  )
304
 
305
  predict_button.click(
306
  predict_stock,
307
- inputs=[category_dropdown, stock_dropdown, stock_item_dropdown],
308
- outputs=[stock_plot]
309
  )
310
 
311
  # 啟動應用
 
7
  from tensorflow.keras.models import Sequential
8
  from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
10
+ from datetime import datetime, timedelta
11
+ import plotly.graph_objs as go
12
+ import plotly.io as pio
 
 
 
13
  import yfinance as yf
14
  import logging
15
+ import tempfile
16
+ import os
17
+ import matplotlib as mpl
18
+ import matplotlib.font_manager as fm
19
 
20
  # 設置日志
21
  logging.basicConfig(level=logging.INFO,
 
81
  logging.error(f"獲取股票類別失敗: {str(e)}")
82
  return {}
83
 
84
+ # 股票預測模型類別
85
  class StockPredictor:
86
  def __init__(self):
87
  self.model = None
88
  self.scaler = MinMaxScaler()
89
 
90
+ def prepare_data(self, df, selected_features):
91
+ scaled_data = self.scaler.fit_transform(df[selected_features])
 
92
 
93
  X, y = [], []
94
  for i in range(len(scaled_data) - 1):
95
  X.append(scaled_data[i])
96
+ y.append(scaled_data[i+1])
97
 
98
+ return np.array(X).reshape(-1, 1, len(selected_features)), np.array(y)
99
 
100
  def build_model(self, input_shape):
101
  model = Sequential([
 
103
  Dropout(0.2),
104
  LSTM(50, activation='relu'),
105
  Dropout(0.2),
106
+ Dense(input_shape[1])
107
  ])
108
  model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
109
  return model
110
 
111
+ def train(self, df, selected_features):
112
+ X, y = self.prepare_data(df, selected_features)
113
  self.model = self.build_model((1, X.shape[2]))
114
  history = self.model.fit(
115
  X, y,
 
128
  next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
129
  predictions.append(next_day[0])
130
 
131
+ current_data = next_day
 
 
 
132
 
133
  return np.array(predictions)
134
 
 
166
  return {
167
  stock_dropdown: gr.update(choices=stocks, value=None),
168
  stock_item_dropdown: gr.update(choices=[], value=None),
169
+ stock_plot: gr.update(value=None),
170
+ status_output: gr.update(value="")
171
  }
172
 
173
  def update_stock(category, stock):
174
  if not category or not stock:
175
  return {
176
  stock_item_dropdown: gr.update(choices=[], value=None),
177
+ stock_plot: gr.update(value=None),
178
+ status_output: gr.update(value="")
179
  }
180
 
181
  url = next((item['網址'] for item in category_dict.get(category, [])
 
185
  stock_items = get_stock_items(url)
186
  return {
187
  stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
188
+ stock_plot: gr.update(value=None),
189
+ status_output: gr.update(value="")
190
  }
191
  return {
192
  stock_item_dropdown: gr.update(choices=[], value=None),
193
+ stock_plot: gr.update(value=None),
194
+ status_output: gr.update(value="")
195
  }
196
 
197
+ def predict_stock(category, stock, stock_item, selected_features):
198
  if not all([category, stock, stock_item]):
199
+ return gr.update(value=None), "請選擇產業類別、類股和股票"
200
 
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
  if item['類股'] == stock), None)
204
  if not url:
205
+ return gr.update(value=None), "無法獲取類股網址"
206
 
207
  stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
 
210
  if not stock_code:
211
+ return gr.update(value=None), "無法獲取股票代碼"
212
 
213
+ # 下載股票數據,根據用戶選擇的時間範圍
214
  df = yf.download(stock_code, period="1y")
215
  if df.empty:
216
  raise ValueError("無法獲取股票數據")
217
 
218
  # 預測
219
  predictor = StockPredictor()
220
+ predictor.train(df, selected_features)
221
 
222
+ last_data = predictor.scaler.transform(df.iloc[-1:][selected_features])
223
  predictions = predictor.predict(last_data[0], 5)
224
 
 
 
 
 
 
 
 
225
  # 創建日期指標
226
  dates = [datetime.now() + timedelta(days=i) for i in range(6)]
227
  date_labels = [d.strftime('%m/%d') for d in dates]
228
 
229
+ # 用 Plotly 繪圖
230
+ fig = go.Figure()
231
+ for i, feature in enumerate(selected_features):
232
+ fig.add_trace(go.Scatter(
233
+ x=date_labels,
234
+ y=np.hstack([df[feature].iloc[-1], predictions[:, i]]),
235
+ mode='lines+markers',
236
+ name=f'預測{feature}'
237
+ ))
 
 
 
238
 
239
+ fig.update_layout(
240
+ title=f'{stock_item} 股價預測 (未來5天)',
241
+ xaxis_title='日期',
242
+ yaxis_title='股價',
243
+ template='plotly_dark'
244
+ )
245
 
246
+ return gr.update(value=pio.to_html(fig, full_html=False)), "預測成功"
 
247
 
248
  except Exception as e:
249
  logging.error(f"預測過程發生錯誤: {str(e)}")
250
+ return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
251
 
252
  # 初始化
253
  setup_font()
 
274
  label="股票",
275
  value=None
276
  )
277
+ period_dropdown = gr.Dropdown(
278
+ choices=["1y", "6mo", "3mo", "1mo"],
279
+ label="抓取時間範圍",
280
+ value="1y"
281
+ )
282
+ features_checkbox = gr.CheckboxGroup(
283
+ choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
284
+ label="選擇要用於預測的特徵",
285
+ value=['Open', 'Close']
286
+ )
287
  predict_button = gr.Button("開始預測", variant="primary")
288
+ status_output = gr.Textbox(label="狀態", interactive=False)
289
 
290
  with gr.Row():
291
+ stock_plot = gr.HTML(label="股價預測圖")
292
 
293
  # 事件綁定
294
  category_dropdown.change(
295
  update_category,
296
  inputs=[category_dropdown],
297
+ outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
298
  )
299
 
300
  stock_dropdown.change(
301
  update_stock,
302
  inputs=[category_dropdown, stock_dropdown],
303
+ outputs=[stock_item_dropdown, stock_plot, status_output]
304
  )
305
 
306
  predict_button.click(
307
  predict_stock,
308
+ inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, features_checkbox],
309
+ outputs=[stock_plot, status_output]
310
  )
311
 
312
  # 啟動應用