tbdavid2019 commited on
Commit
47cec11
·
1 Parent(s): 9c0b92a
Files changed (1) hide show
  1. app2.py +162 -0
app2.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import requests
3
+ import yfinance as yf
4
+ from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
5
+ import gradio as gr
6
+
7
+ # Function to fetch stock data
8
+ def get_stock_data(ticker, period):
9
+ data = yf.download(ticker, period=period)
10
+ return data
11
+
12
+ # Function to prepare the data for Chronos-Bolt
13
+ def prepare_data_chronos(data):
14
+ # Reset index and prepare data
15
+ df = data.reset_index()
16
+
17
+ # Create a DataFrame in the format expected by AutoGluon TimeSeries
18
+ formatted_df = pd.DataFrame({
19
+ 'item_id': ['stock'] * len(df),
20
+ 'timestamp': pd.to_datetime(df['Date']),
21
+ 'target': df['Close'].astype('float32').values.ravel()
22
+ })
23
+
24
+ # Sort by timestamp
25
+ formatted_df = formatted_df.sort_values('timestamp')
26
+
27
+ try:
28
+ # Create a TimeSeriesDataFrame without specifying target_column
29
+ ts_df = TimeSeriesDataFrame.from_data_frame(
30
+ formatted_df,
31
+ id_column='item_id',
32
+ timestamp_column='timestamp'
33
+ )
34
+ return ts_df
35
+ except Exception as e:
36
+ print(f"Error creating TimeSeriesDataFrame: {str(e)}")
37
+ raise
38
+
39
+ # Functions to fetch stock indices
40
+ def get_tw0050_stocks():
41
+ response = requests.get('https://answerbook.david888.com/TW0050')
42
+ data = response.json()
43
+ return [f"{code}.TW" for code in data['TW0050'].keys()]
44
+
45
+ def get_sp500_stocks(limit=50):
46
+ response = requests.get('https://answerbook.david888.com/SP500')
47
+ data = response.json()
48
+ return list(data['SP500'].keys())[:limit]
49
+
50
+ def get_nasdaq_stocks(limit=50):
51
+ response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100')
52
+ data = response.json()
53
+ return list(data['stocks'].keys())[:limit]
54
+
55
+ def get_tw0051_stocks():
56
+ response = requests.get('https://answerbook.david888.com/TW0051')
57
+ data = response.json()
58
+ return [f"{code}.TW" for code in data['TW0051'].keys()]
59
+
60
+ def get_sox_stocks():
61
+ return [
62
+ "NVDA", "AVGO", "GFS", "CRUS", "ON", "ASML", "QCOM", "SWKS", "MPWR", "ADI",
63
+ "TSM", "AMD", "TXN", "QRVO", "AMKR", "MU", "ARM", "NXPI", "TER", "ENTG",
64
+ "LSCC", "COHR", "ONTO", "MTSI", "KLAC", "LRCX", "MRVL", "AMAT", "INTC", "MCHP"
65
+ ]
66
+
67
+ def get_dji_stocks():
68
+ response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES')
69
+ data = response.json()
70
+ return list(data['stocks'].keys())
71
+
72
+ # Function to get top 10 potential stocks
73
+ def get_top_10_potential_stocks(period, selected_indices):
74
+ stock_list = []
75
+ if "\u53f0\u706350" in selected_indices:
76
+ stock_list += get_tw0050_stocks()
77
+ if "\u53f0\u7063\u4e2d\u578b100" in selected_indices:
78
+ stock_list += get_tw0051_stocks()
79
+ if "S&P\u7cbe\u7c21\u724850" in selected_indices:
80
+ stock_list += get_sp500_stocks()
81
+ if "NASDAQ\u7cbe\u7c21\u724850" in selected_indices:
82
+ stock_list += get_nasdaq_stocks()
83
+ if "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX" in selected_indices:
84
+ stock_list += get_sox_stocks()
85
+ if "\u9053\u74b0DJI" in selected_indices:
86
+ stock_list += get_dji_stocks()
87
+
88
+ stock_predictions = []
89
+ prediction_length = 10
90
+
91
+ for ticker in stock_list:
92
+ try:
93
+ data = get_stock_data(ticker, period)
94
+ if data.empty:
95
+ continue
96
+
97
+ ts_data = prepare_data_chronos(data)
98
+
99
+ # Create a TimeSeriesPredictor for daily data
100
+ predictor = TimeSeriesPredictor(
101
+ prediction_length=prediction_length,
102
+ freq="1D"
103
+ )
104
+ predictor.fit(
105
+ ts_data,
106
+ hyperparameters={
107
+ "Chronos": {"model_path": "autogluon/chronos-bolt-base"}
108
+ }
109
+ )
110
+
111
+ predictions = predictor.predict(ts_data)
112
+ # Calculate potential as (prediction - last_close) / last_close
113
+ potential = (predictions.iloc[-1] - data['Close'].iloc[-1]) / data['Close'].iloc[-1]
114
+ stock_predictions.append((ticker, potential, data['Close'].iloc[-1], predictions.iloc[-1]))
115
+
116
+ except Exception as e:
117
+ print(f"Stock {ticker} error: {str(e)}")
118
+ continue
119
+
120
+ # Sort stocks by potential in descending order, take top 10
121
+ top_10_stocks = sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]
122
+ return top_10_stocks
123
+
124
+ # Gradio interface function
125
+ def stock_prediction_app(period, selected_indices):
126
+ top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
127
+ df = pd.DataFrame(top_10_stocks, columns=[
128
+ "\u80a1\u7968\u4ee3\u865f", # Ticker
129
+ "\u6f5b\u529b (\u767e\u5206\u6bd4)", # Potential
130
+ "\u73fe\u50f9", # Current Price
131
+ "\u9810\u6e2c\u50f9\u683c" # Predicted Price
132
+ ])
133
+ return df
134
+
135
+ # Define Gradio interface
136
+ inputs = [
137
+ gr.Dropdown(choices=["3mo", "6mo", "9mo", "1yr"], label="\u6642\u9593\u7bc4\u570d"),
138
+ gr.CheckboxGroup(
139
+ choices=[
140
+ "\u53f0\u706350", # 台灣50
141
+ "\u53f0\u7063\u4e2d\u578b100", # 台灣中型100
142
+ "S&P\u7cbe\u7c21\u724850", # S&P精簡版50
143
+ "NASDAQ\u7cbe\u7c21\u724850", # NASDAQ精簡版50
144
+ "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", # 費城半導體
145
+ "\u9053\u74b0DJI" # 道瓊DJI
146
+ ],
147
+ label="\u6307\u6578\u9078\u64c7",
148
+ value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"]
149
+ )
150
+ ]
151
+
152
+ outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")
153
+
154
+ app = gr.Interface(
155
+ fn=stock_prediction_app,
156
+ inputs=inputs,
157
+ outputs=outputs,
158
+ title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b"
159
+ )
160
+
161
+ if __name__ == "__main__":
162
+ app.launch()