File size: 6,100 Bytes
1e2adc2 a45959e 1e2adc2 0091ee4 1e2adc2 ead29e0 fc08f05 ff03388 ead29e0 9c0b92a ead29e0 ff03388 ead29e0 ff03388 ead29e0 9c0b92a 843b7e5 6d62996 843b7e5 1d22072 6d62996 ead29e0 ff03388 426f91f 1e2adc2 a53f74c 1e2adc2 a53f74c 1e2adc2 a53f74c 1e2adc2 06892b5 fc08f05 1e2adc2 fc08f05 1e2adc2 ead29e0 3c0fe99 ead29e0 fc08f05 1e2adc2 3c0fe99 53beba1 3c0fe99 53beba1 3c0fe99 53beba1 fc08f05 1e2adc2 3c0fe99 1e2adc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import re
import shutil
import tempfile

import gradio as gr
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
# Period strings yfinance documents as valid for `period=`.
_YF_VALID_PERIODS = frozenset(
    {"1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"}
)


def _period_to_download_kwargs(period):
    """Translate a UI period label into keyword arguments for ``yf.download``.

    - Labels yfinance accepts are passed through as ``period``.
    - ``"<n>yr"`` becomes ``"<n>y"`` when that is a valid period.
    - Any other ``"<n>mo"`` / ``"<n>yr"`` label (e.g. the UI's ``"9mo"``) is
      converted into an explicit ``start`` date n months/years before today.
    - Anything else, including ``None``, is passed through unchanged so
      yfinance handles it as before.
    """
    if not isinstance(period, str) or period in _YF_VALID_PERIODS:
        return {"period": period}
    match = re.fullmatch(r"(\d+)(mo|yr?)", period)
    if match is None:
        return {"period": period}
    count, unit = int(match.group(1)), match.group(2)
    if unit == "mo":
        offset = pd.DateOffset(months=count)
    else:
        if f"{count}y" in _YF_VALID_PERIODS:
            return {"period": f"{count}y"}
        offset = pd.DateOffset(years=count)
    start = pd.Timestamp.today().normalize() - offset
    return {"start": start.strftime("%Y-%m-%d")}


def get_stock_data(ticker, period):
    """Download price history for ``ticker`` with yfinance.

    ``period`` is a UI label such as "3mo" or "1yr". The dropdown offers
    labels ("9mo", "1yr") outside yfinance's documented period list. Passing
    those straight through typically fails the download and yields an empty
    frame, which the ranking loop then skips. They are therefore translated
    into a valid ``period`` or an explicit ``start`` date first.

    Returns:
        The DataFrame returned by ``yf.download`` (may be empty).
    """
    return yf.download(ticker, **_period_to_download_kwargs(period))
def prepare_data_chronos(data):
    """Turn a yfinance price frame into a single-series TimeSeriesDataFrame.

    The frame's index is moved into a column with ``reset_index`` and read
    back as ``'Date'`` (assumes the index is named "Date", as for a daily
    yfinance download; verify for other intervals). The ``'Close'`` values
    become a float32 ``'target'`` column, matching the ``target="target"``
    used by the predictor. Every row shares the item_id ``'stock'``, and rows
    are sorted by timestamp before conversion.

    Raises:
        Re-raises any error from ``TimeSeriesDataFrame.from_data_frame``
        after printing it.
    """
    flat = data.reset_index()
    # `.values.ravel()` flattens 'Close' whether it is a Series or a
    # one-column DataFrame (presumably yfinance multi-level columns).
    columns = {
        'item_id': ['stock'] * len(flat),
        'timestamp': pd.to_datetime(flat['Date']),
        'target': flat['Close'].astype('float32').values.ravel(),
    }
    ordered = pd.DataFrame(columns).sort_values('timestamp')
    try:
        return TimeSeriesDataFrame.from_data_frame(
            ordered,
            id_column='item_id',
            timestamp_column='timestamp',
        )
    except Exception as err:
        print(f"Error creating TimeSeriesDataFrame: {str(err)}")
        raise
# def prepare_data_chronos(data):
# # 直接使用收盤價序列
# series = pd.Series(
# data['Close'].values,
# index=data.index,
# name='value'
# )
# # 創建基本的時間序列數據框
# df = pd.DataFrame({
# 'timestamp': series.index,
# 'value': series.values,
# 'item_id': ['stock'] * len(series)
# })
# return TimeSeriesDataFrame(df)
# Index constituent fetchers: each returns a list of ticker strings for yfinance.


def get_tw0050_stocks(timeout=10):
    """Return the Taiwan 50 (TW0050) constituents as yfinance tickers.

    Fetches JSON from answerbook.david888.com/TW0050, takes the keys of its
    ``'TW0050'`` object as stock codes, and appends ``.TW`` to each.

    Args:
        timeout: seconds to wait for the HTTP response. requests applies no
            timeout by default, so without one a stalled server would hang
            the whole request.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
    """
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0050'].keys()]
def get_sp500_stocks(limit=50, timeout=10):
    """Return the first ``limit`` tickers from the SP500 feed.

    Fetches JSON from answerbook.david888.com/SP500 and returns the keys of
    its ``'SP500'`` object, in response order, truncated to ``limit``.

    Args:
        limit: maximum number of tickers to return.
        timeout: seconds to wait for the HTTP response (requests has no
            default timeout).

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
    """
    response = requests.get('https://answerbook.david888.com/SP500', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'].keys())[:limit]
def get_nasdaq_stocks(limit=50, timeout=10):
    """Return the first ``limit`` tickers from the NASDAQ100 feed.

    Reads the keys of the ``'stocks'`` object in the JSON response, in
    response order, truncated to ``limit``.

    Args:
        limit: maximum number of tickers to return.
        timeout: seconds to wait for the HTTP response (requests has no
            default timeout).

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
    """
    # NOTE(review): plain HTTP to a bare IP address, so the response is
    # neither encrypted nor authenticated. Prefer an HTTPS endpoint if one exists.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'].keys())[:limit]
def get_tw0051_stocks(timeout=10):
    """Return the Taiwan Mid-Cap 100 (TW0051) constituents as yfinance tickers.

    Fetches JSON from answerbook.david888.com/TW0051, takes the keys of its
    ``'TW0051'`` object as stock codes, and appends ``.TW`` to each.

    Args:
        timeout: seconds to wait for the HTTP response (requests has no
            default timeout).

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
    """
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0051'].keys()]
def get_sox_stocks():
    """Return the hard-coded SOX ticker list (30 symbols).

    Unlike the other index helpers this makes no network call. The list is a
    static snapshot and can drift from the live index composition. A fresh
    list object is returned on every call.
    """
    roster = """
        NVDA AVGO GFS  CRUS ON   ASML QCOM SWKS MPWR ADI
        TSM  AMD  TXN  QRVO AMKR MU   ARM  NXPI TER  ENTG
        LSCC COHR ONTO MTSI KLAC LRCX MRVL AMAT INTC MCHP
    """
    return roster.split()
def get_dji_stocks(timeout=10):
    """Return every ticker from the DOWJONES feed.

    Reads the keys of the ``'stocks'`` object in the JSON response, in
    response order (no truncation).

    Args:
        timeout: seconds to wait for the HTTP response (requests has no
            default timeout).

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
    """
    # NOTE(review): plain HTTP to a bare IP address, so the response is
    # neither encrypted nor authenticated. Prefer an HTTPS endpoint if one exists.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'].keys())
def _collect_tickers(selected_indices):
    """Return tickers for every selected index label, de-duplicated (first occurrence wins).

    Indices are checked in a fixed order, not in the order the user ticked
    them. Errors from the fetchers propagate to the caller.
    """
    # Labels must match the Gradio CheckboxGroup choices byte-for-byte.
    # NOTE(review): the SOX label decodes to "賽城半字體" and the DJI label to
    # "道環". These look like typos for "費城半導體" / "道瓊". Any fix must also
    # change the UI choices or the checkboxes stop matching.
    index_sources = (
        ("\u53f0\u706350", get_tw0050_stocks),
        ("\u53f0\u7063\u4e2d\u578b100", get_tw0051_stocks),
        ("S&P\u7cbe\u7c21\u724850", get_sp500_stocks),
        ("NASDAQ\u7cbe\u7c21\u724850", get_nasdaq_stocks),
        ("\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", get_sox_stocks),
        ("\u9053\u74b0DJI", get_dji_stocks),
    )
    selected = selected_indices or ()
    tickers = []
    for label, fetch in index_sources:
        if label in selected:
            tickers.extend(fetch())
    # Universes can overlap (a hard-coded SOX ticker may also come back from
    # the NASDAQ or S&P feeds). Without this, one stock could occupy several
    # top-10 rows.
    return list(dict.fromkeys(tickers))


def _forecast_potential(ticker, period, prediction_length):
    """Forecast one ticker with Chronos-Bolt and score its upside.

    Returns:
        ``(ticker, potential, last_close, highest_forecast)``, where
        ``potential = (highest_forecast - last_close) / last_close`` as a
        fraction, or ``None`` if the download returned no rows.

    Raises:
        Whatever the download, data preparation, or AutoGluon calls raise.
        The caller treats any exception as "skip this ticker".
    """
    data = get_stock_data(ticker, period)
    if data.empty:
        return None
    ts_data = prepare_data_chronos(data)

    # Fit into a throw-away directory. With no `path`, AutoGluon creates a new
    # timestamped model folder per predictor, which would pile up on disk for
    # every ticker on every request.
    scratch_dir = tempfile.mkdtemp(prefix="chronos_")
    try:
        predictor = TimeSeriesPredictor(
            prediction_length=prediction_length,
            # Business days: daily trading bars normally have no weekend rows,
            # so "D" would insert imputed weekend values and forecast Sat/Sun.
            freq="B",
            target="target",
            path=os.path.join(scratch_dir, "predictor"),
        )
        predictor.fit(
            ts_data,
            hyperparameters={
                "Chronos": {"model_path": "autogluon/chronos-bolt-base"}
            },
        )
        # Predict before cleanup: models may be loaded lazily from `path`.
        predictions = predictor.predict(ts_data)
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)

    # 'Close' may be a Series or a one-column DataFrame; flatten it and use
    # the last non-missing close.
    closes = pd.Series(data["Close"].to_numpy().ravel()).dropna()
    last_actual = float(closes.iloc[-1])
    # Highest point forecast over the horizon. Only the "mean" column is used:
    # the other prediction columns are quantiles, and taking the max across
    # them picks the upper quantile and overstates the upside.
    highest_pred = float(predictions["mean"].max())
    potential = (highest_pred - last_actual) / last_actual
    return (str(ticker), float(potential), last_actual, highest_pred)


def get_top_10_potential_stocks(period, selected_indices, top_n=10, prediction_length=2):
    """Rank tickers from the selected indices by forecast upside.

    Args:
        period: look-back window label forwarded to ``get_stock_data``.
        selected_indices: index labels ticked in the UI (``None`` = none).
        top_n: number of rows to return (default 10, as the name says).
        prediction_length: forecast horizon, in business days.

    Returns:
        Up to ``top_n`` tuples ``(ticker, potential, last_close,
        highest_forecast)`` of plain ``str``/``float``, sorted by
        ``potential`` descending. Tickers that fail are printed and skipped.
    """
    stock_predictions = []
    for ticker in _collect_tickers(selected_indices):
        try:
            result = _forecast_potential(ticker, period, prediction_length)
        except Exception as e:
            # Best-effort: one failing ticker must not sink the whole ranking.
            print(f"Stock {ticker} error: {str(e)}")
            continue
        if result is not None:
            stock_predictions.append(result)
    stock_predictions.sort(key=lambda row: row[1], reverse=True)
    return stock_predictions[:top_n]
def stock_prediction_app(period, selected_indices):
    """Gradio callback: run the ranking and return it as a display table.

    Args:
        period: look-back window chosen in the dropdown.
        selected_indices: index labels ticked in the checkbox group.

    Returns:
        DataFrame with columns ticker / potential (percent) / current price /
        predicted price, one row per ranked stock.
    """
    columns = [
        "\u80a1\u7968\u4ee3\u865f",           # ticker
        "\u6f5b\u529b (\u767e\u5206\u6bd4)",  # potential (percent)
        "\u73fe\u50f9",                       # current price
        "\u9810\u6e2c\u50f9\u683c",           # predicted price
    ]
    top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
    df = pd.DataFrame(top_10_stocks, columns=columns)
    if not df.empty:
        # The ranking yields potential as a fraction (0.05 == 5%), but the
        # header promises a percentage, so scale it for display.
        df[columns[1]] = (df[columns[1]] * 100).round(2)
    return df
# Gradio UI: a look-back period dropdown plus a checkbox group of index
# universes. The checkbox choices must stay byte-identical to the labels the
# ranking code checks for.
inputs = [
    gr.Dropdown(
        choices=["3mo", "6mo", "9mo", "1yr"],
        # Explicit default so the callback always receives a period; some
        # Gradio versions otherwise start with nothing selected (None).
        value="3mo",
        label="\u6642\u9593\u7bc4\u570d",
    ),
    gr.CheckboxGroup(
        choices=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100", "S&P\u7cbe\u7c21\u724850", "NASDAQ\u7cbe\u7c21\u724850", "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", "\u9053\u74b0DJI"],
        label="\u6307\u6578\u9078\u64c7",
        value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"],
    ),
]
outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")
# Keep a module-level handle named `demo` (the name Gradio's reload-mode CLI
# looks for). The app is still launched at import time, as before.
demo = gr.Interface(
    fn=stock_prediction_app,
    inputs=inputs,
    outputs=outputs,
    title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b",
)
demo.launch()
|