File size: 6,100 Bytes
1e2adc2
 
a45959e
1e2adc2
 
 
 
 
 
 
 
 
0091ee4
1e2adc2
ead29e0
fc08f05
ff03388
ead29e0
 
 
 
9c0b92a
ead29e0
ff03388
ead29e0
 
ff03388
ead29e0
9c0b92a
843b7e5
6d62996
843b7e5
1d22072
6d62996
ead29e0
 
 
 
ff03388
 
426f91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06892b5
fc08f05
1e2adc2
 
 
 
 
fc08f05
1e2adc2
ead29e0
 
3c0fe99
 
ead29e0
 
fc08f05
 
 
 
 
 
 
1e2adc2
3c0fe99
53beba1
3c0fe99
53beba1
 
3c0fe99
53beba1
 
 
 
 
 
fc08f05
1e2adc2
 
 
3c0fe99
 
 
 
 
 
 
 
1e2adc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import gradio as gr

# Function to fetch stock data
def get_stock_data(ticker, period):
    """Download price history for *ticker* from Yahoo Finance via yfinance.

    Args:
        ticker: Yahoo Finance symbol, e.g. "AAPL" or "2330.TW".
        period: Look-back window. Strings that yfinance accepts ("3mo",
            "6mo", "1y", ...) are passed through unchanged. Two aliases that
            are not in yfinance's documented period list are translated:
            "1yr" becomes "1y", and "9mo" becomes an explicit start date
            nine months before today.

    Returns:
        The DataFrame returned by ``yf.download`` (callers check ``.empty``
        to detect a failed download).
    """
    # yfinance spells one year "1y"; accept the "1yr" spelling as well.
    period = {"1yr": "1y"}.get(period, period)
    if period == "9mo":
        # yfinance has no "9mo" period, so request the same window by start date.
        start = pd.Timestamp.today().normalize() - pd.DateOffset(months=9)
        return yf.download(ticker, start=start.strftime("%Y-%m-%d"))
    return yf.download(ticker, period=period)

# Function to prepare the data for Chronos-Bolt

def prepare_data_chronos(data):
    """Convert a price DataFrame into a single-series TimeSeriesDataFrame.

    Args:
        data: Price history such as ``get_stock_data`` returns. After
            ``reset_index()`` it must expose a 'Date' column, and it must
            have a 'Close' column.

    Returns:
        A TimeSeriesDataFrame with one item ('stock'), timestamps taken from
        'Date', and a float32 'target' column holding the closing prices,
        ordered by timestamp.

    Raises:
        Whatever ``TimeSeriesDataFrame.from_data_frame`` raises; the error is
        printed first and then re-raised unchanged.
    """
    # Turn the date index into an ordinary column.
    flat = data.reset_index()
    n_rows = len(flat)

    # Long format: one row per (item_id, timestamp) with a 'target' value.
    # ravel() flattens 'Close' in case it comes back as a one-column
    # DataFrame rather than a Series.
    columns = {
        'item_id': ['stock'] * n_rows,
        'timestamp': pd.to_datetime(flat['Date']),
        'target': flat['Close'].astype('float32').values.ravel(),
    }
    long_df = pd.DataFrame(columns).sort_values('timestamp')

    try:
        return TimeSeriesDataFrame.from_data_frame(
            long_df,
            id_column='item_id',
            timestamp_column='timestamp',
        )
    except Exception as exc:
        print(f"Error creating TimeSeriesDataFrame: {str(exc)}")
        raise


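# Superseded earlier version of prepare_data_chronos (it used a 'value' column instead of 'target'); kept for reference only.
# def prepare_data_chronos(data):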
#     # 直接使用收盤價序列
#     series = pd.Series(
#         data['Close'].values,
#         index=data.index,
#         name='value'
#     )
    
#     # 創建基本的時間序列數據框
#     df = pd.DataFrame({
#         'timestamp': series.index,
#         'value': series.values,
#         'item_id': ['stock'] * len(series)
#     })
    
#     return TimeSeriesDataFrame(df)


# Functions that return the constituent tickers of each supported index
def get_tw0050_stocks():
    """Return the Taiwan 50 constituents as Yahoo Finance symbols ("<code>.TW").

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if the service does not respond within the timeout.
    """
    # requests has no default timeout; without one an unresponsive server
    # would block the Gradio callback indefinitely.
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=30)
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0050'].keys()]

def get_sp500_stocks(limit=50):
    """Return up to *limit* tickers from the service's 'SP500' mapping.

    Args:
        limit: Maximum number of tickers to return (the first keys of the
            'SP500' mapping, in the order the service lists them).

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if the service does not respond within the timeout.
    """
    # requests has no default timeout; without one an unresponsive server
    # would block the Gradio callback indefinitely.
    response = requests.get('https://answerbook.david888.com/SP500', timeout=30)
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'].keys())[:limit]

def get_nasdaq_stocks(limit=50):
    """Return up to *limit* tickers from the NASDAQ100 endpoint's 'stocks' mapping.

    Args:
        limit: Maximum number of tickers to return (the first keys of the
            'stocks' mapping, in the order the service lists them).

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if the service does not respond within the timeout.
    """
    # NOTE(review): plain HTTP to a hard-coded IP address; the response is
    # unauthenticated and can break if the host moves. Consider HTTPS / a hostname.
    # requests has no default timeout; set one so a dead host cannot hang the app.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=30)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'].keys())[:limit]

def get_tw0051_stocks():
    """Return the TW0051 (Taiwan mid-cap 100) constituents as "<code>.TW" symbols.

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if the service does not respond within the timeout.
    """
    # requests has no default timeout; without one an unresponsive server
    # would block the Gradio callback indefinitely.
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=30)
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0051'].keys()]

def get_sox_stocks():
    """Return the hard-coded list of 30 SOX index tickers.

    A fresh list is built on every call, so callers may extend it freely.
    """
    symbols = (
        "NVDA AVGO GFS CRUS ON ASML QCOM SWKS MPWR ADI "
        "TSM AMD TXN QRVO AMKR MU ARM NXPI TER ENTG "
        "LSCC COHR ONTO MTSI KLAC LRCX MRVL AMAT INTC MCHP"
    )
    return symbols.split()

def get_dji_stocks():
    """Return every ticker in the DOWJONES endpoint's 'stocks' mapping.

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if the service does not respond within the timeout.
    """
    # NOTE(review): plain HTTP to a hard-coded IP address; the response is
    # unauthenticated and can break if the host moves. Consider HTTPS / a hostname.
    # requests has no default timeout; set one so a dead host cannot hang the app.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=30)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'].keys())

# Function to get top 10 potential stocks
def get_top_10_potential_stocks(period, selected_indices):
    """Rank the constituents of the selected indices by forecast upside.

    For every ticker, fits a TimeSeriesPredictor using the Chronos-Bolt model
    on its closing prices over *period*. It then forecasts the next two
    business days and measures how far the highest mean forecast lies above
    the last close.

    Args:
        period: Look-back window passed to ``get_stock_data``.
        selected_indices: Index labels chosen in the UI's CheckboxGroup.

    Returns:
        Up to 10 ``(ticker, potential, last_close, highest_forecast)`` tuples
        of plain ``str``/``float`` values, sorted by ``potential`` in
        descending order. ``potential`` is a fraction (0.05 means +5%).
        Tickers whose download or prediction fails are printed and skipped.
    """
    prediction_length = 2
    stock_predictions = []

    for ticker in _collect_tickers(selected_indices):
        try:
            result = _forecast_upside(ticker, period, prediction_length)
        except Exception as e:
            print(f"Stock {ticker} error: {str(e)}")
            continue
        if result is None:
            continue
        potential, last_actual, highest_pred = result
        # Plain Python types keep the Gradio DataFrame serialisation simple.
        stock_predictions.append(
            (str(ticker), float(potential), float(last_actual), float(highest_pred))
        )

    stock_predictions.sort(key=lambda row: row[1], reverse=True)
    return stock_predictions[:10]


def _collect_tickers(selected_indices):
    """Return the de-duplicated tickers of every selected index, in a fixed index order."""
    fetchers = (
        ("\u53f0\u706350", get_tw0050_stocks),
        ("\u53f0\u7063\u4e2d\u578b100", get_tw0051_stocks),
        ("S&P\u7cbe\u7c21\u724850", get_sp500_stocks),
        ("NASDAQ\u7cbe\u7c21\u724850", get_nasdaq_stocks),
        ("\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", get_sox_stocks),
        ("\u9053\u74b0DJI", get_dji_stocks),
    )
    tickers = []
    for label, fetch in fetchers:
        if label not in selected_indices:
            continue
        try:
            tickers += fetch()
        except Exception as e:
            # One unreachable index service should not abort the whole run.
            print(f"Index {label} error: {str(e)}")
    # The selected indices can overlap (a SOX name may also be in another
    # list). Keep the first occurrence so each stock is forecast and listed once.
    return list(dict.fromkeys(tickers))


def _forecast_upside(ticker, period, prediction_length):
    """Return (potential, last_close, highest_mean_forecast) for one ticker, or None if unusable."""
    data = get_stock_data(ticker, period)
    if data.empty:
        return None

    # Flatten 'Close' so this works whether it is a Series or a one-column
    # DataFrame, and ignore NaN rows so the reference price is a real close.
    closes = pd.Series(data['Close'].to_numpy().ravel()).dropna()
    if closes.empty:
        return None

    ts_data = prepare_data_chronos(data)
    # NOTE(review): each predictor writes its artifacts to AutoGluon's default
    # output directory; consider passing path= and cleaning up afterwards.
    predictor = TimeSeriesPredictor(
        prediction_length=prediction_length,
        # Prices exist only on trading days. A business-day grid keeps the
        # forecast horizon on weekdays instead of Saturday/Sunday.
        freq="B",
        target="target",
    )
    predictor.fit(
        ts_data,
        hyperparameters={
            "Chronos": {"model_path": "autogluon/chronos-bolt-base"}
        },
    )
    predictions = predictor.predict(ts_data)

    last_actual = float(closes.iloc[-1])
    # predict() returns a "mean" point forecast plus quantile columns. Only the
    # mean is the forecast itself; taking the max over all columns would pick
    # the upper quantile and reward volatility rather than expected growth.
    highest_pred = float(predictions["mean"].max())
    if pd.isna(highest_pred):
        return None
    potential = (highest_pred - last_actual) / last_actual
    return potential, last_actual, highest_pred

# Gradio interface function
def stock_prediction_app(period, selected_indices):
    """Gradio callback: build the top-10 recommendation table.

    Args:
        period: Look-back window selected in the dropdown.
        selected_indices: Index labels ticked in the CheckboxGroup.

    Returns:
        A DataFrame with columns ticker / potential (percent) / current
        price / predicted price. The potential is shown in percent, matching
        its column header.
    """
    top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
    columns = ["\u80a1\u7968\u4ee3\u865f", "\u6f5b\u529b (\u767e\u5206\u6bd4)", "\u73fe\u50f9", "\u9810\u6e2c\u50f9\u683c"]
    df = pd.DataFrame(top_10_stocks, columns=columns)
    if not df.empty:
        # The header says "percent", but the ranking returns a fraction
        # (0.05 == +5%); convert so the displayed number matches the label.
        df[columns[1]] = (df[columns[1]] * 100).round(2)
    return df

# Define Gradio interface
# NOTE(review): yfinance's documented periods are 1d, 5d, 1mo, 3mo, 6mo, 1y,
# 2y, 5y, 10y, ytd, max; "9mo" and "1yr" are not among them, so make sure
# get_stock_data translates them.
inputs = [
    # Preselect a window so the callback never receives period=None when the
    # user submits without touching the dropdown.
    gr.Dropdown(
        choices=["3mo", "6mo", "9mo", "1yr"],
        label="\u6642\u9593\u7bc4\u570d",
        value="6mo",
    ),
    gr.CheckboxGroup(
        choices=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100", "S&P\u7cbe\u7c21\u724850", "NASDAQ\u7cbe\u7c21\u724850", "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", "\u9053\u74b0DJI"],
        label="\u6307\u6578\u9078\u64c7",
        value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"],
    ),
]
outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")

gr.Interface(fn=stock_prediction_app, inputs=inputs, outputs=outputs, title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b").launch()