tbdavid2019 commited on
Commit
7b04205
·
1 Parent(s): a33a42d

擴增到美股

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. app.py +63 -41
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ myenv/*
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ myenv/*
2
+ myenv/**/*
app.py CHANGED
@@ -55,58 +55,78 @@ def predict_stock(model, data, scaler, time_step=60):
55
 
56
  return predicted_prices
57
 
58
- # Function to fetch all Taiwan listed stocks (Taiwan 50 and Small 100)
59
- def get_all_taiwan_stocks():
60
  return [
61
- "2345.TW", "3017.TW", "2454.TW", "2330.TW", "2395.TW", "2379.TW", "2890.TW", "2891.TW", "2882.TW", "3008.TW",
62
- "5871.TW", "2303.TW", "2357.TW", "3661.TW", "2883.TW", "3711.TW", "2308.TW", "2885.TW", "2603.TW", "2881.TW",
63
- "4904.TW", "2887.TW", "2880.TW", "2301.TW", "1216.TW", "2884.TW", "2327.TW", "6669.TW", "5880.TW", "1303.TW",
64
- "3045.TW", "2382.TW", "2912.TW", "4938.TW", "3231.TW", "2892.TW", "2317.TW", "1590.TW", "3034.TW", "2002.TW",
65
- "2412.TW", "2207.TW", "3037.TW", "1301.TW", "1326.TW", "1101.TW", "6446.TW", "2886.TW", "6505.TW", "5876.TW",
66
- "3533.TW", "9904.TW", "2618.TW", "5522.TW", "2360.TW", "6005.TW", "3653.TW", "2368.TW", "2474.TW", "6285.TW",
67
- "1519.TW", "9958.TW", "3044.TW", "1476.TW", "2324.TW", "2059.TW", "1402.TW", "6176.TW", "1102.TW", "2353.TW",
68
- "3665.TW", "2049.TW", "2609.TW", "2478.TW", "9945.TW", "1216.TW", "8464.TW", "3702.TW", "2801.TW", "6526.TW",
69
- "2845.TW", "2834.TW", "2610.TW", "4763.TW", "4958.TW", "9941.TW", "6239.TW", "2915.TW", "9917.TW", "2376.TW",
70
- "1451.TW", "2313.TW", "3051.TW", "9914.TW", "1477.TW", "2377.TW", "2206.TW", "1504.TW", "2912.TW", "6409.TW",
71
- "1560.TW", "1503.TW", "2615.TW", "3005.TW", "2204.TW", "3532.TW", "2888.TW", "2449.TW", "6789.TW", "3481.TW",
72
- "2409.TW", "2385.TW", "3406.TW", "2352.TW", "2207.TW", "2347.TW", "6531.TW", "9910.TW", "2371.TW", "2356.TW",
73
- "2492.TW", "1718.TW", "6890.TW", "8454.TW", "2633.TW", "1802.TW", "2006.TW", "2542.TW", "1513.TW", "1907.TW",
74
- "1722.TW", "2809.TW", "1319.TW", "4137.TW", "2388.TW", "2812.TW", "2540.TW", "3035.TW", "2354.TW", "2027.TW",
75
- "1229.TW", "2105.TW", "2408.TW", "5269.TW", "2344.TW", "3443.TW", "6415.TW", "9921.TW", "3036.TW", "6592.TW",
76
- "6472.TW", "3023.TW", "6770.TW", "1795.TW", "2201.TW", "1605.TW", "8046.TW", "2312.TW", "2359.TW", "2337.TW"
 
77
  ]
78
 
79
  # Function to get top 10 potential stocks
80
- def get_top_10_potential_stocks(period):
81
- stock_list = get_all_taiwan_stocks()
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  stock_predictions = []
83
-
 
84
  for ticker in stock_list:
85
  data = get_stock_data(ticker, period)
86
- if data.empty:
 
87
  continue
88
-
89
- # Prepare data
90
- X_train, y_train, scaler = prepare_data(data)
91
-
92
- # Train model
93
- model = train_lstm_model(X_train, y_train)
94
-
95
- # Predict future prices
96
- predicted_prices = predict_stock(model, data, scaler)
97
-
98
- # Calculate the potential (e.g., last predicted price vs last actual price)
99
- potential = (predicted_prices[-1] - data['Close'].values[-1]) / data['Close'].values[-1]
100
- stock_predictions.append((ticker, potential, data['Close'].values[-1], predicted_prices[-1][0]))
101
-
 
 
 
 
102
  # Sort by potential and get top 10
103
  top_10_stocks = sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]
104
  return top_10_stocks
105
 
106
  # Gradio interface function
107
- def stock_prediction_app(period):
108
  # Get top 10 potential stocks
109
- top_10_stocks = get_top_10_potential_stocks(period)
110
 
111
  # Create a dataframe for display
112
  df = pd.DataFrame(top_10_stocks, columns=["股票代號", "潛力 (百分比)", "現價", "預測價格"])
@@ -114,9 +134,11 @@ def stock_prediction_app(period):
114
  return df
115
 
116
  # Define Gradio interface
117
- inputs = gr.Dropdown(choices=["1mo", "3mo", "6mo", "9mo", "1yr"], label="時間範圍")
 
 
 
118
  outputs = gr.Dataframe(label="潛力股推薦結果")
119
 
120
- gr.Interface(fn=stock_prediction_app, inputs=inputs, outputs=outputs, title="台股潛力股推薦系統 - LSTM模型")\
121
  .launch()
122
-
 
55
 
56
  return predicted_prices
57
 
58
+ # Function to fetch Taiwan 50 and Small 100 stocks
59
+ def get_tw0050_stocks():
60
  return [
61
+ "2330.TW", "2317.TW", "2454.TW", "2308.TW", "2881.TW", "2382.TW", "2303.TW", "2882.TW", "2891.TW", "3711.TW",
62
+ "2412.TW", "2886.TW", "2884.TW", "1216.TW", "2357.TW", "2885.TW", "2892.TW", "3034.TW", "2890.TW", "2327.TW",
63
+ "5880.TW", "2345.TW", "3231.TW", "2002.TW", "2880.TW", "3008.TW", "2883.TW", "1303.TW", "4938.TW", "2207.TW",
64
+ "2887.TW", "2379.TW", "1101.TW", "2603.TW", "2301.TW", "1301.TW", "5871.TW", "3037.TW", "3045.TW", "2912.TW",
65
+ "3017.TW", "6446.TW", "4904.TW", "3661.TW", "6669.TW", "1326.TW", "5876.TW", "2395.TW", "1590.TW", "6505.TW"
66
+ ]
67
+
68
+ def get_tw0051_stocks():
69
+ return [
70
+ "2371.TW", "3533.TW", "2618.TW", "3443.TW", "2347.TW", "3044.TW", "2834.TW", "2385.TW", "1605.TW", "2105.TW",
71
+ "6239.TW", "6176.TW", "9904.TW", "1519.TW", "9910.TW", "1513.TW", "1229.TW", "9945.TW", "2313.TW", "1477.TW",
72
+ "3665.TW", "2354.TW", "4958.TW", "8464.TW", "9921.TW", "2812.TW", "2059.TW", "1504.TW", "2542.TW", "6770.TW",
73
+ "5269.TW", "2344.TW", "3023.TW", "1503.TW", "2049.TW", "2610.TW", "2633.TW", "3036.TW", "2368.TW", "3035.TW",
74
+ "2027.TW", "9914.TW", "2408.TW", "2809.TW", "1319.TW", "2352.TW", "2337.TW", "2006.TW", "2206.TW", "4763.TW",
75
+ "3005.TW", "1907.TW", "2915.TW", "1722.TW", "6285.TW", "6472.TW", "6531.TW", "3406.TW", "9958.TW", "9941.TW",
76
+ "1795.TW", "2201.TW", "9917.TW", "2492.TW", "6890.TW", "2845.TW", "8454.TW", "8046.TW", "6789.TW", "2388.TW",
77
+ "6526.TW", "1802.TW", "5522.TW", "6592.TW", "2204.TW", "2540.TW", "2539.TW", "3532.TW"
78
  ]
79
 
80
  # Function to get top 10 potential stocks
81
+ def get_top_10_potential_stocks(period, selected_indices):
82
+ stock_list = []
83
+ if "tw0050台灣50" in selected_indices:
84
+ stock_list += get_tw0050_stocks()
85
+ if "tw0051中型100" in selected_indices:
86
+ stock_list += get_tw0051_stocks()
87
+ if "S&P" in selected_indices:
88
+ stock_list.append("^GSPC")
89
+ if "NASDAQ" in selected_indices:
90
+ stock_list.append("^IXIC")
91
+ if "費城半導體" in selected_indices:
92
+ stock_list.append("^SOX")
93
+ if "道瓊" in selected_indices:
94
+ stock_list.append("^DJI")
95
+
96
  stock_predictions = []
97
+ time_step = 60
98
+
99
  for ticker in stock_list:
100
  data = get_stock_data(ticker, period)
101
+ if data.empty or len(data) < time_step:
102
+ # 如果數據為空或不足以生成訓練樣本,則跳過該股票
103
  continue
104
+
105
+ try:
106
+ # Prepare data
107
+ X_train, y_train, scaler = prepare_data(data, time_step=time_step)
108
+
109
+ # Train model
110
+ model = train_lstm_model(X_train, y_train)
111
+
112
+ # Predict future prices
113
+ predicted_prices = predict_stock(model, data, scaler, time_step=time_step)
114
+
115
+ # Calculate the potential (e.g., last predicted price vs last actual price)
116
+ potential = (predicted_prices[-1] - data['Close'].values[-1]) / data['Close'].values[-1]
117
+ stock_predictions.append((ticker, potential, data['Close'].values[-1], predicted_prices[-1][0]))
118
+ except Exception as e:
119
+ print(f"股票 {ticker} 發生錯誤: {str(e)}")
120
+ continue
121
+
122
  # Sort by potential and get top 10
123
  top_10_stocks = sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]
124
  return top_10_stocks
125
 
126
  # Gradio interface function
127
+ def stock_prediction_app(period, selected_indices):
128
  # Get top 10 potential stocks
129
+ top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
130
 
131
  # Create a dataframe for display
132
  df = pd.DataFrame(top_10_stocks, columns=["股票代號", "潛力 (百分比)", "現價", "預測價格"])
 
134
  return df
135
 
136
  # Define Gradio interface
137
+ inputs = [
138
+ gr.Dropdown(choices=["3mo", "6mo", "9mo", "1yr"], label="時間範圍"),
139
+ gr.CheckboxGroup(choices=["tw0050台灣50", "tw0051中型100", "S&P", "NASDAQ", "費城半導體", "道瓊"], label="指數選擇")
140
+ ]
141
  outputs = gr.Dataframe(label="潛力股推薦結果")
142
 
143
+ gr.Interface(fn=stock_prediction_app, inputs=inputs, outputs=outputs, title="潛力股推薦系統 - LSTM模型")\
144
  .launch()