tbdavid2019 commited on
Commit
a2bf6d0
·
1 Parent(s): 6171f23
Files changed (1) hide show
  1. app.py +9 -43
app.py CHANGED
@@ -9,7 +9,6 @@ from tensorflow.keras.layers import LSTM, Dense, Dropout
9
  from tensorflow.keras.optimizers import Adam
10
  from datetime import datetime, timedelta
11
  import plotly.graph_objs as go
12
- import plotly.io as pio
13
  import yfinance as yf
14
  import logging
15
  import tempfile
@@ -17,7 +16,7 @@ import os
17
  import matplotlib as mpl
18
  import matplotlib.font_manager as fm
19
 
20
- # 設置日志
21
  logging.basicConfig(level=logging.INFO,
22
  format='%(asctime)s - %(levelname)s - %(message)s')
23
 
@@ -161,54 +160,21 @@ def get_stock_items(url):
161
  logging.error(f"獲取股票項目失敗: {str(e)}")
162
  return {}
163
 
164
- def update_category(category):
165
- stocks = update_stocks(category)
166
- return {
167
- stock_dropdown: gr.update(choices=stocks, value=None),
168
- stock_item_dropdown: gr.update(choices=[], value=None),
169
- stock_plot: gr.update(value=None),
170
- status_output: gr.update(value="")
171
- }
172
-
173
- def update_stock(category, stock):
174
- if not category or not stock:
175
- return {
176
- stock_item_dropdown: gr.update(choices=[], value=None),
177
- stock_plot: gr.update(value=None),
178
- status_output: gr.update(value="")
179
- }
180
-
181
- url = next((item['網址'] for item in category_dict.get(category, [])
182
- if item['類股'] == stock), None)
183
-
184
- if url:
185
- stock_items = get_stock_items(url)
186
- return {
187
- stock_item_dropdown: gr.update(choices=list(stock_items.keys()), value=None),
188
- stock_plot: gr.update(value=None),
189
- status_output: gr.update(value="")
190
- }
191
- return {
192
- stock_item_dropdown: gr.update(choices=[], value=None),
193
- stock_plot: gr.update(value=None),
194
- status_output: gr.update(value="")
195
- }
196
-
197
  def predict_stock(category, stock, stock_item, period, selected_features):
198
  if not all([category, stock, stock_item]):
199
- return gr.update(value=None), "請選擇產業類別、類股和股票"
200
 
201
  try:
202
  url = next((item['網址'] for item in category_dict.get(category, [])
203
  if item['類股'] == stock), None)
204
  if not url:
205
- return gr.update(value=None), "無法獲取類股網址"
206
 
207
  stock_items = get_stock_items(url)
208
  stock_code = stock_items.get(stock_item, "")
209
 
210
  if not stock_code:
211
- return gr.update(value=None), "無法獲取股票代碼"
212
 
213
  # 下載股票數據,根據用戶選擇的時間範圍
214
  df = yf.download(stock_code, period=period)
@@ -235,7 +201,7 @@ def predict_stock(category, stock, stock_item, period, selected_features):
235
  mode='lines+markers',
236
  name=f'預測{feature}'
237
  ))
238
-
239
  fig.update_layout(
240
  title=f'{stock_item} 股價預測 (未來5天)',
241
  xaxis_title='日期',
@@ -243,11 +209,11 @@ def predict_stock(category, stock, stock_item, period, selected_features):
243
  template='plotly_dark'
244
  )
245
 
246
- return gr.update(value=pio.to_html(fig, full_html=False)), "預測成功"
247
 
248
  except Exception as e:
249
  logging.error(f"預測過程發生錯誤: {str(e)}")
250
- return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
251
 
252
  # 初始化
253
  setup_font()
@@ -288,7 +254,7 @@ with gr.Blocks() as demo:
288
  status_output = gr.Textbox(label="狀態", interactive=False)
289
 
290
  with gr.Row():
291
- stock_plot = gr.HTML(label="股價預測圖")
292
 
293
  # 事件綁定
294
  category_dropdown.change(
@@ -311,4 +277,4 @@ with gr.Blocks() as demo:
311
 
312
  # 啟動應用
313
  if __name__ == "__main__":
314
- demo.launch(share=False)
 
9
  from tensorflow.keras.optimizers import Adam
10
  from datetime import datetime, timedelta
11
  import plotly.graph_objs as go
 
12
  import yfinance as yf
13
  import logging
14
  import tempfile
 
16
  import matplotlib as mpl
17
  import matplotlib.font_manager as fm
18
 
19
+ # 設置日誌
20
  logging.basicConfig(level=logging.INFO,
21
  format='%(asctime)s - %(levelname)s - %(message)s')
22
 
 
160
  logging.error(f"獲取股票項目失敗: {str(e)}")
161
  return {}
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def predict_stock(category, stock, stock_item, period, selected_features):
164
  if not all([category, stock, stock_item]):
165
+ return None, "請選擇產業類別、類股和股票"
166
 
167
  try:
168
  url = next((item['網址'] for item in category_dict.get(category, [])
169
  if item['類股'] == stock), None)
170
  if not url:
171
+ return None, "無法獲取類股網址"
172
 
173
  stock_items = get_stock_items(url)
174
  stock_code = stock_items.get(stock_item, "")
175
 
176
  if not stock_code:
177
+ return None, "無法獲取股票代碼"
178
 
179
  # 下載股票數據,根據用戶選擇的時間範圍
180
  df = yf.download(stock_code, period=period)
 
201
  mode='lines+markers',
202
  name=f'預測{feature}'
203
  ))
204
+
205
  fig.update_layout(
206
  title=f'{stock_item} 股價預測 (未來5天)',
207
  xaxis_title='日期',
 
209
  template='plotly_dark'
210
  )
211
 
212
+ return fig, "預測成功"
213
 
214
  except Exception as e:
215
  logging.error(f"預測過程發生錯誤: {str(e)}")
216
+ return None, f"預測過程發生錯誤: {str(e)}"
217
 
218
  # 初始化
219
  setup_font()
 
254
  status_output = gr.Textbox(label="狀態", interactive=False)
255
 
256
  with gr.Row():
257
+ stock_plot = gr.Plot(label="股價預測圖")
258
 
259
  # 事件綁定
260
  category_dropdown.change(
 
277
 
278
  # 啟動應用
279
  if __name__ == "__main__":
280
+ demo.launch(share=False)