parkerjj committed on
Commit
ac60dc3
·
1 Parent(s): 298d9ab

feat: Add data source configuration and implement yfinance integration for US stock indices

Browse files

- Created a new configuration file `data_source_config.py` to manage data source versioning between yfinance and akshare.
- Implemented `preprocess_yfinance.py` to initialize stock index data using yfinance, including error handling and data formatting.
- Developed `us_stock_yfinance.py` to fetch stock data, manage stock indices, and handle real-time price retrieval with caching and retry mechanisms.
- Added functions for processing stock history and extracting sentiment scores from news text.
- Introduced asynchronous data fetching for improved performance and responsiveness.

Files changed (6) hide show
  1. data_source_config.py +30 -0
  2. preprocess.py +131 -4
  3. preprocess_yfinance.py +248 -0
  4. requirements.txt +21 -3
  5. us_stock.py +254 -62
  6. us_stock_yfinance.py +577 -0
data_source_config.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 配置文件:控制数据源版本切换
2
+ # ================================
3
+ # 数据源版本切换配置
4
+ # ================================
5
+
6
+ # 设置为 True 使用新版 yfinance 实现
7
+ # 设置为 False 使用旧版 akshare 实现
8
+ USE_YFINANCE_VERSION = True
9
+
10
+ # 你可以在这里快速切换版本:
11
+ # USE_YFINANCE_VERSION = False # 切换到 akshare
12
+ # USE_YFINANCE_VERSION = True # 切换到 yfinance
13
+
14
+ # ================================
15
+ # 其他配置选项
16
+ # ================================
17
+
18
+ # 数据缓存时间(分钟)
19
+ PRICE_CACHE_MINUTES = 30
20
+
21
+ # API 超时时间(秒)
22
+ API_TIMEOUT_SECONDS = 30
23
+
24
+ # 最大重试次数
25
+ MAX_RETRY_ATTEMPTS = 3
26
+
27
+ # 调试模式
28
+ DEBUG_MODE = False
29
+
30
+ print(f"📊 数据源配置: {'yfinance (新版)' if USE_YFINANCE_VERSION else 'akshare (旧版)'}")
preprocess.py CHANGED
@@ -1,3 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import sys
3
  import os
@@ -18,6 +52,14 @@ from transformers import pipeline
18
 
19
  # 还需要导入 pickle 模块(如果你在代码的其他部分使用了它来处理序列化/反序列化)
20
  import pickle
 
 
 
 
 
 
 
 
21
  from gensim.models import KeyedVectors
22
  import akshare as ak
23
 
@@ -81,10 +123,95 @@ def get_tokenizer_and_model(model_type="one"):
81
 
82
  return _models[model_type]
83
 
84
- index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX")
85
- index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI")
86
- index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC")
87
- index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
 
90
  class LazyWord2Vec:
 
1
+ # ================================
2
+ # 版本切换开关 - 从配置文件导入
3
+ # ================================
4
+ from data_source_config import USE_YFINANCE_VERSION
5
+
6
+ import re
7
+ import sys
8
+ import os
9
+ import trace
10
+ import traceback
11
+ from typing import final
12
+ import numpy as np
13
+ from collections import defaultdict
14
+ import pandas as pd
15
+ import time
16
+
17
+ # 如果使用 spaCy 进行 NLP 处理
18
+ from regex import R
19
+ import spacy
20
+
21
+ # 如果使用某种情感分析工具,比如 Hugging Face 的模型
22
+ from transformers import pipeline
23
+
24
+ # 还需要导入 pickle 模块(如果你在代码的其他部分使用了它来处理序列化/反序列化)
25
+ import pickle
26
+
27
+ # 根据开关导入不同的模块
28
+ if USE_YFINANCE_VERSION:
29
+ import yfinance as yf
30
+ print("🔄 Using yfinance version in preprocess (new)")
31
+ else:
32
+ import akshare as ak
33
+ print("🔄 Using akshare version in preprocess (old)")
34
+
35
  import re
36
  import sys
37
  import os
 
52
 
53
  # 还需要导入 pickle 模块(如果你在代码的其他部分使用了它来处理序列化/反序列化)
54
  import pickle
55
+
56
+ # 根据开关导入不同的模块
57
+ if USE_YFINANCE_VERSION:
58
+ import yfinance as yf
59
+ print("🔄 Using yfinance version in preprocess (new)")
60
+ else:
61
+ import akshare as ak
62
+ print("🔄 Using akshare version in preprocess (old)")
63
  from gensim.models import KeyedVectors
64
  import akshare as ak
65
 
 
123
 
124
  return _models[model_type]
125
 
126
+ # 初始化股票指数数据,根据开关选择不同的实现
127
+ def init_stock_indices():
128
+ """根据版本开关初始化股票指数数据"""
129
+ global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX
130
+
131
+ if USE_YFINANCE_VERSION:
132
+ print("Initializing stock indices using yfinance...")
133
+ try:
134
+ from datetime import datetime, timedelta
135
+
136
+ # 计算日期范围
137
+ end_date = datetime.now()
138
+ start_date = end_date - timedelta(weeks=8)
139
+
140
+ # 定义指数映射
141
+ indices = {
142
+ '^GSPC': 'INX', # S&P 500
143
+ '^DJI': 'DJI', # Dow Jones
144
+ '^IXIC': 'IXIC', # NASDAQ Composite
145
+ '^NDX': 'NDX' # NASDAQ 100
146
+ }
147
+
148
+ results = {}
149
+
150
+ for yf_symbol, var_name in indices.items():
151
+ try:
152
+ ticker = yf.Ticker(yf_symbol)
153
+ hist_data = ticker.history(start=start_date, end=end_date)
154
+
155
+ if not hist_data.empty:
156
+ # 转换为与akshare相同的格式
157
+ formatted_data = pd.DataFrame({
158
+ 'date': hist_data.index.strftime('%Y-%m-%d'),
159
+ '开盘': hist_data['Open'].values,
160
+ '收盘': hist_data['Close'].values,
161
+ '最高': hist_data['High'].values,
162
+ '最低': hist_data['Low'].values,
163
+ '成交量': hist_data['Volume'].values,
164
+ '成交额': (hist_data['Close'] * hist_data['Volume']).values
165
+ })
166
+ results[var_name] = formatted_data
167
+ else:
168
+ results[var_name] = pd.DataFrame()
169
+
170
+ except Exception as e:
171
+ print(f"Error fetching {yf_symbol}: {e}")
172
+ results[var_name] = pd.DataFrame()
173
+
174
+ # 设置全局变量
175
+ index_us_stock_index_INX = results.get('INX', pd.DataFrame())
176
+ index_us_stock_index_DJI = results.get('DJI', pd.DataFrame())
177
+ index_us_stock_index_IXIC = results.get('IXIC', pd.DataFrame())
178
+ index_us_stock_index_NDX = results.get('NDX', pd.DataFrame())
179
+
180
+ except Exception as e:
181
+ print(f"Error initializing indices with yfinance: {e}")
182
+ # 设置空DataFrame作为fallback
183
+ index_us_stock_index_INX = pd.DataFrame()
184
+ index_us_stock_index_DJI = pd.DataFrame()
185
+ index_us_stock_index_IXIC = pd.DataFrame()
186
+ index_us_stock_index_NDX = pd.DataFrame()
187
+ else:
188
+ print("Initializing stock indices using akshare...")
189
+ try:
190
+ index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX")
191
+ index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI")
192
+ index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC")
193
+ index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX")
194
+ except Exception as e:
195
+ print(f"Error initializing indices with akshare: {e}")
196
+ index_us_stock_index_INX = pd.DataFrame()
197
+ index_us_stock_index_DJI = pd.DataFrame()
198
+ index_us_stock_index_IXIC = pd.DataFrame()
199
+ index_us_stock_index_NDX = pd.DataFrame()
200
+
201
+ # 延迟初始化索引数据
202
+ import threading
203
+ def delayed_init():
204
+ time.sleep(5) # 等待5秒
205
+ init_stock_indices()
206
+
207
+ init_thread = threading.Thread(target=delayed_init, daemon=True)
208
+ init_thread.start()
209
+
210
+ # 设置初始值为None,等待延迟初始化
211
+ index_us_stock_index_INX = None
212
+ index_us_stock_index_DJI = None
213
+ index_us_stock_index_IXIC = None
214
+ index_us_stock_index_NDX = None
215
 
216
 
217
  class LazyWord2Vec:
preprocess_yfinance.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from datetime import datetime, timedelta, date
3
+ import numpy as np
4
+ import asyncio
5
+ import threading
6
+ import time
7
+ import yfinance as yf
8
+
9
+ # 索引变量初始化
10
+ # 以下变量在外部模块中定义并在运行时更新
11
+ index_us_stock_index_INX = None
12
+ index_us_stock_index_DJI = None
13
+ index_us_stock_index_IXIC = None
14
+ index_us_stock_index_NDX = None
15
+
16
def _fetch_index_frame(yf_symbol, var_name, start_date, end_date):
    """Download one index's history and reshape it; return an empty frame on failure."""
    try:
        print(f"Fetching {var_name} data using yfinance...")
        history = yf.Ticker(yf_symbol).history(start=start_date, end=end_date)

        if history.empty:
            print(f"No data for {yf_symbol}")
            return pd.DataFrame()

        # Column names mirror the layout the rest of the project reads.
        frame = pd.DataFrame({
            'date': history.index.strftime('%Y-%m-%d'),
            '开盘': history['Open'].values,
            '收盘': history['Close'].values,
            '最高': history['High'].values,
            '最低': history['Low'].values,
            '成交量': history['Volume'].values,
            # Turnover is derived here as close * volume.
            '成交额': (history['Close'] * history['Volume']).values
        })
        print(f"Successfully fetched {var_name}: {len(frame)} records")
        return frame
    except Exception as e:
        print(f"Error fetching {yf_symbol}: {e}")
        return pd.DataFrame()


def init_stock_index_data():
    """Populate the four module-level index DataFrames from yfinance.

    Requests the trailing eight weeks of history for the S&P 500, Dow Jones,
    NASDAQ Composite and NASDAQ 100 and stores each result in its
    ``index_us_stock_index_*`` global. An index that cannot be fetched gets an
    empty DataFrame; if anything outside the per-index fetch fails, all four
    globals are reset to empty DataFrames.
    """
    global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX

    try:
        # Date window: now back to eight weeks ago.
        window_end = datetime.now()
        window_start = window_end - timedelta(weeks=8)

        # (yfinance ticker, short name) pairs, fetched in this order.
        symbol_pairs = (
            ('^GSPC', 'INX'),   # S&P 500
            ('^DJI', 'DJI'),    # Dow Jones
            ('^IXIC', 'IXIC'),  # NASDAQ Composite
            ('^NDX', 'NDX'),    # NASDAQ 100
        )

        frames = {}
        for ticker_symbol, short_name in symbol_pairs:
            frames[short_name] = _fetch_index_frame(
                ticker_symbol, short_name, window_start, window_end
            )

        # Publish the results through the module-level globals.
        index_us_stock_index_INX = frames['INX']
        index_us_stock_index_DJI = frames['DJI']
        index_us_stock_index_IXIC = frames['IXIC']
        index_us_stock_index_NDX = frames['NDX']

        print("Stock indices initialized successfully using yfinance")

    except Exception as e:
        print(f"Error initializing stock indices: {e}")
        # Fall back to empty DataFrames for every index.
        index_us_stock_index_INX = pd.DataFrame()
        index_us_stock_index_DJI = pd.DataFrame()
        index_us_stock_index_IXIC = pd.DataFrame()
        index_us_stock_index_NDX = pd.DataFrame()
77
+
78
def delayed_init_indices(delay_seconds=5):
    """Sleep for a while, then load the index data.

    Args:
        delay_seconds: Seconds to wait before calling
            ``init_stock_index_data()``. The default of 5 keeps the original
            behaviour when the function is called without arguments.
    """
    time.sleep(delay_seconds)  # wait before starting the initialization
    init_stock_index_data()
82
+
83
+ # 启动延迟初始化
84
+ init_thread = threading.Thread(target=delayed_init_indices, daemon=True)
85
+ init_thread.start()
86
+
87
+ # 下面是原有的其他函数,保持不变...
88
+
89
+ # 新的文本时间处理函数
90
def parse_time(time_str):
    """Normalize a date expression to a ``YYYY-MM-DD`` string.

    Args:
        time_str: A date string, a Chinese relative expression (for example
            ``'昨天'`` or ``'上周'``), or a ``date``/``datetime`` object.

    Returns:
        ``None`` when ``time_str`` is falsy; otherwise the normalized date.
        Input that matches no known expression or format falls back to
        today's date.
    """
    if not time_str:
        return None

    # date/datetime objects need no parsing (datetime is a date subclass).
    if isinstance(time_str, date):
        return time_str.strftime('%Y-%m-%d')

    today = date.today()
    # Surrounding whitespace would otherwise defeat every strptime format.
    text = str(time_str).strip()

    # Relative expressions, checked in the same order as before.
    relative_offsets = (
        (('昨天', '昨日'), timedelta(days=1)),
        (('今天', '今日'), timedelta(days=0)),
        (('前天',), timedelta(days=2)),
        (('上周',), timedelta(weeks=1)),
        (('上月',), timedelta(days=30)),  # "last month" is approximated as 30 days
    )
    for keywords, offset in relative_offsets:
        if any(keyword in text for keyword in keywords):
            return (today - offset).strftime('%Y-%m-%d')

    # Absolute dates: the first matching format wins, so month-first forms
    # take precedence over day-first forms for ambiguous input.
    formats = ('%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y', '%m-%d-%Y', '%d/%m/%Y', '%d-%m-%Y')
    for fmt in formats:
        try:
            return datetime.strptime(text, fmt).strftime('%Y-%m-%d')
        except ValueError:
            continue

    # Unparseable input: fall back to today's date.
    return today.strftime('%Y-%m-%d')
124
+
125
+ # 原有的其他函数...
126
def preprocess_news_text(text):
    """Collapse whitespace runs in ``text`` and lowercase it.

    Args:
        text: Raw news text, or ``None``.

    Returns:
        The cleaned text; an empty string when ``text`` is ``None``.
    """
    if text is None:
        return ''
    # split() with no argument drops leading/trailing and repeated whitespace.
    return ' '.join(text.split()).lower()
133
+
134
def extract_sentiment_score(text):
    """Placeholder keyword-based sentiment score in the range [-1.0, 1.0].

    Counts how many distinct positive and negative keywords appear in the
    text as whole words. The larger group sets the sign and each distinct
    keyword contributes 0.2, capped at +/-1.0. A tie, or empty text, scores
    0.0. A real sentiment model can be plugged in here later.

    Args:
        text: News text; may be empty or ``None``.

    Returns:
        The sentiment score as a float.
    """
    import re  # local import: this module does not import re at file level

    if not text:
        return 0.0

    positive_words = ('good', 'great', 'excellent', 'positive', 'growth', 'profit', 'gain', 'rise', 'up')
    negative_words = ('bad', 'poor', 'negative', 'loss', 'decline', 'fall', 'down', 'crash')

    # Match whole words only; substring matching would count "up" inside
    # "support" or "gain" inside "again".
    tokens = set(re.findall(r"[a-z]+", text.lower()))
    positive_count = sum(1 for word in positive_words if word in tokens)
    negative_count = sum(1 for word in negative_words if word in tokens)

    if positive_count > negative_count:
        return min(1.0, positive_count * 0.2)
    if negative_count > positive_count:
        return max(-1.0, -negative_count * 0.2)
    return 0.0
154
+
155
def calculate_technical_indicators(price_data):
    """Compute simple technical indicators from a price DataFrame.

    Args:
        price_data: DataFrame with a closing-price column named ``'close'``,
            ``'Close'`` or ``'收盘'`` (checked in that order). May be
            ``None`` or empty.

    Returns:
        ``{}`` for ``None`` or empty input; otherwise a dict with the float
        values ``sma_5``, ``sma_10``, ``rsi`` and ``price_change_pct``.

    Raises:
        KeyError: If none of the accepted closing-price columns is present.
    """
    if price_data is None or price_data.empty:
        return {}

    for column in ('close', 'Close', '收盘'):
        if column in price_data.columns:
            close_prices = price_data[column]
            break
    else:
        raise KeyError('close')

    count = len(close_prices)
    first_close = close_prices.iloc[0]
    last_close = close_prices.iloc[-1]

    # Simple moving averages; with too little history use the last close.
    sma_5 = close_prices.rolling(window=5).mean().iloc[-1] if count >= 5 else last_close
    sma_10 = close_prices.rolling(window=10).mean().iloc[-1] if count >= 10 else last_close

    # RSI (relative strength index)
    def calculate_rsi(prices, window=14):
        if len(prices) < window:
            return 50.0  # neutral default when history is too short

        delta = prices.diff()
        # where() also turns the leading NaN of diff() into 0.
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)

        avg_gain = gain.rolling(window=window).mean().iloc[-1]
        avg_loss = loss.rolling(window=window).mean().iloc[-1]

        if avg_loss == 0:
            # No losses in the window: 100 if there were gains. With no
            # movement at all the ratio is 0/0, so report neutral 50.
            return 100.0 if avg_gain > 0 else 50.0

        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    rsi = calculate_rsi(close_prices)

    # Percentage change from first to last close; 0 when it is undefined.
    if count > 1 and first_close != 0:
        price_change = (last_close - first_close) / first_close * 100
    else:
        price_change = 0.0

    return {
        'sma_5': float(sma_5),
        'sma_10': float(sma_10),
        'rsi': float(rsi),
        'price_change_pct': float(price_change)
    }
193
+
194
def normalize_features(features_dict):
    """Squash numeric feature values into roughly [-1, 1].

    Scaling rules:
        * ``'rsi'``: ``(value - 50) / 50``
        * keys ending in ``'_pct'``: ``tanh(value / 100)``
        * any other numeric value: ``tanh(value / 1000)``

    Non-numeric and NaN values map to ``0.0``. NumPy integer and floating
    scalars count as numeric, alongside Python ``int`` and ``float``.

    Args:
        features_dict: Mapping of feature name to raw value.

    Returns:
        A new dict with the same keys and float values.
    """
    # np.int64 and np.float32 are not instances of int/float, so list them.
    numeric_types = (int, float, np.integer, np.floating)
    normalized = {}

    for key, value in features_dict.items():
        if not isinstance(value, numeric_types) or pd.isna(value):
            normalized[key] = 0.0
        elif key == 'rsi':
            normalized[key] = float((value - 50) / 50)  # centre RSI on 50
        elif key.endswith('_pct'):
            normalized[key] = float(np.tanh(value / 100))  # percentage changes
        else:
            normalized[key] = float(np.tanh(value / 1000))  # other magnitudes

    return normalized
211
+
212
+ # 主要的预处理函数
213
def preprocess_for_model(news_text, stock_symbol, news_date):
    """Build the basic feature dict for one news item.

    Cleans the text, normalizes the date and scores sentiment. Stock price
    data is not fetched here, to avoid a circular import, so only the basic
    features are returned.

    Args:
        news_text: Raw news text.
        stock_symbol: Ticker symbol, passed through unchanged.
        news_date: Date expression handed to ``parse_time``.

    Returns:
        A dict with the keys ``processed_text``, ``sentiment_score``,
        ``news_date`` and ``stock_symbol``. If any step raises, the error is
        printed and a fallback dict is returned holding the raw text, a score
        of 0.0 and today's date.
    """
    try:
        cleaned_text = preprocess_news_text(news_text)    # clean the text
        normalized_date = parse_time(news_date)           # normalize the date
        score = extract_sentiment_score(cleaned_text)     # score the sentiment
    except Exception as e:
        print(f"Error in preprocess_for_model: {e}")
        fallback = {
            'processed_text': news_text,
            'sentiment_score': 0.0,
            'news_date': date.today().strftime('%Y-%m-%d'),
            'stock_symbol': stock_symbol
        }
        return fallback

    features = {
        'processed_text': cleaned_text,
        'sentiment_score': score,
        'news_date': normalized_date,
        'stock_symbol': stock_symbol
    }
    return features
243
+
244
+ if __name__ == "__main__":
245
+ # 测试函数
246
+ test_text = "Apple Inc. reported strong quarterly earnings, beating expectations."
247
+ result = preprocess_for_model(test_text, "AAPL", "2024-02-14")
248
+ print(f"Preprocessing result: {result}")
requirements.txt CHANGED
@@ -3,12 +3,12 @@ blis==0.7.11
3
  spacy==3.7.5
4
  gensim
5
  numpy
6
- gensim
7
  fastapi
8
  requests
9
  sentencepiece
 
10
  transformers
11
- uvicorn[standard]==0.17.*
12
  keras==3.6.0
13
  yfinance==0.2.65
14
  jsonpath==0.82.2
@@ -17,5 +17,23 @@ pydantic==2.9.2
17
  pydantic_core==2.23.4
18
  nltk
19
  gunicorn
 
 
 
 
 
 
 
 
 
 
 
 
20
  --only-binary torch
21
- torch==2.8.0
 
 
 
 
 
 
 
3
  spacy==3.7.5
4
  gensim
5
  numpy
 
6
  fastapi
7
  requests
8
  sentencepiece
9
+ # 建议锁个合理范围,避免上游突发大版本: transformers>=4.41,<4.45
10
  transformers
11
+ uvicorn[standard]==0.35.0
12
  keras==3.6.0
13
  yfinance==0.2.65
14
  jsonpath==0.82.2
 
17
  pydantic_core==2.23.4
18
  nltk
19
  gunicorn
20
+
21
+ # ---------------- 关键约束:解决你看到的冲突 ----------------
22
+ # 1) uvicorn[standard] 会安装 websockets;为兼容 gradio-client 等生态,限制 <13
23
+ # websockets>=10,<13
24
+
25
+ # 2) TensorFlow 2.16.2 需要 protobuf < 5(建议锁到 4.25.x)
26
+ protobuf>=4.25.0,<5
27
+
28
+ # 3) 有些依赖会把 grpcio-status 拉到 1.63+(它要求 protobuf>=5.26.1)→ 与 TF 冲突
29
+ grpcio-status<1.63
30
+
31
+ # 4) 避免在某些平台触发 PyTorch 源码编译(非常耗内存/时间)
32
  --only-binary torch
33
+
34
+ # ---------------- PyTorch:按平台安装不同版本 ----------------
35
+ # Intel Mac(macOS x86_64)只到 2.2.2
36
+ torch==2.2.2; platform_system == "Darwin" and platform_machine == "x86_64" and python_version < "3.13"
37
+
38
+ # Linux / Windows / macOS arm64 → 2.8.0(注意也限制 Python < 3.13)
39
+ torch==2.8.0; (platform_system == "Linux" or platform_system == "Windows" or (platform_system == "Darwin" and platform_machine == "arm64")) and python_version < "3.13"
us_stock.py CHANGED
@@ -1,18 +1,26 @@
 
 
 
 
 
1
  import logging
2
  import re
3
- import akshare as ak
4
  import pandas as pd
5
  from datetime import datetime, timedelta
6
  import time # 导入标准库的 time 模块
7
 
8
  import os
9
-
10
  import requests
11
  import threading
12
  import asyncio
13
 
14
- import yfinance
15
-
 
 
 
 
 
16
 
17
  logging.basicConfig(level=logging.INFO)
18
 
@@ -33,59 +41,154 @@ nasdaq_composite_stocks = pd.read_csv(nasdaq_composite_path)
33
 
34
 
35
  def fetch_stock_us_spot_data_with_retries():
36
- # 定义重试间隔时间序列(秒)
37
- retry_intervals = [10, 20, 60, 300, 600]
38
- retry_index = 0 # 初始重试序号
39
-
40
- while True:
41
- try:
42
- # 尝试获取API数据
43
- symbols = ak.stock_us_spot_em()
44
- return symbols # 成功获取数据后返回
45
-
46
- except Exception as e:
47
- print(f"Error fetching data: {e}")
48
-
49
- # 获取当前重试等待时间
50
- wait_time = retry_intervals[retry_index]
51
- print(f"Retrying in {wait_time} seconds...")
52
- time.sleep(wait_time) # 等待指定的秒数
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # 更新重试索引,但不要超出重试时间列表的范围
55
- retry_index = min(retry_index + 1, len(retry_intervals) - 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
 
59
  async def fetch_stock_us_spot_data_with_retries_async():
60
- retry_intervals = [10, 20] # 减少重试次数
61
- retry_index = 0
62
- max_retries = 2 # 最多重试2次
63
-
64
- for attempt in range(max_retries + 1):
65
  try:
66
- # 添加30秒超时
67
- symbols = await asyncio.wait_for(
68
- asyncio.to_thread(ak.stock_us_spot_em),
69
- timeout=30.0
70
- )
71
- return symbols
72
- except asyncio.TimeoutError:
73
- print(f"Timeout error fetching data (attempt {attempt + 1}/{max_retries + 1})")
74
  except Exception as e:
75
- print(f"Error fetching data (attempt {attempt + 1}/{max_retries + 1}): {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- if attempt < max_retries:
78
- wait_time = retry_intervals[min(retry_index, len(retry_intervals) - 1)]
79
- print(f"Retrying in {wait_time} seconds...")
80
- await asyncio.sleep(wait_time)
81
- retry_index += 1
82
-
83
- # 如果所有重试都失败,返回空数据
84
- print("All retries failed, returning empty data")
85
- return pd.DataFrame()
86
 
87
  symbols = None
88
 
 
 
 
 
 
 
 
 
 
 
 
89
  async def fetch_symbols():
90
  global symbols
91
  try:
@@ -114,10 +217,65 @@ def update_stock_indices():
114
  global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX
115
  try:
116
  print("Starting stock indices update...")
117
- index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX")
118
- index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI")
119
- index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC")
120
- index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  print("Stock indices updated successfully")
122
  except Exception as e:
123
  print(f"Error updating stock indices: {e}")
@@ -128,7 +286,7 @@ def update_stock_indices():
128
  # 程序开始时不立即更新,而是延迟启动
129
  def start_indices_update():
130
  """延迟启动股票指数更新,避免阻塞应用启动"""
131
- threading.Timer(60, update_stock_indices).start() # 60秒后开始第一次更新
132
 
133
  # 延迟启动股票指数更新
134
  start_indices_update()
@@ -206,13 +364,18 @@ def get_last_minute_stock_price(symbol: str, max_retries=3) -> float:
206
  for attempt in range(max_retries):
207
  try:
208
  # 缓存无效或不存在,从yfinance获取新数据
209
- stock_data = yfinance.download(
210
- symbol,
211
- period='1d',
212
- interval='5m',
213
- progress=False, # 禁用进度条
214
- timeout=10 # 设置超时时间
215
- )
 
 
 
 
 
216
 
217
  if stock_data.empty:
218
  print(f"Warning: Empty data received for {symbol}, attempt {attempt + 1}/{max_retries}")
@@ -263,10 +426,39 @@ def get_stock_history(symbol, news_date, retries=10):
263
 
264
  while retry_count <= retries and len(symbol) != 0: # 无限循环重试
265
  try:
266
- # 尝试获取API数据
267
- stock_hist_df = ak.stock_us_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust="")
268
-
269
- if stock_hist_df.empty: # 检查是否为空数据
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  # print(f"No data for {symbol} on {news_date}.")
271
  stock_hist_df = None # 将 DataFrame 设置为 None
272
  break
 
1
+ # ================================
2
+ # 版本切换开关 - 从配置文件导入
3
+ # ================================
4
+ from data_source_config import USE_YFINANCE_VERSION, API_TIMEOUT_SECONDS, MAX_RETRY_ATTEMPTS
5
+
6
  import logging
7
  import re
 
8
  import pandas as pd
9
  from datetime import datetime, timedelta
10
  import time # 导入标准库的 time 模块
11
 
12
  import os
 
13
  import requests
14
  import threading
15
  import asyncio
16
 
17
+ # 根据开关导入不同的模块
18
+ if USE_YFINANCE_VERSION:
19
+ import yfinance as yf
20
+ print("🔄 Using yfinance version (new)")
21
+ else:
22
+ import akshare as ak
23
+ print("🔄 Using akshare version (old)")
24
 
25
  logging.basicConfig(level=logging.INFO)
26
 
 
41
 
42
 
43
  def fetch_stock_us_spot_data_with_retries():
44
+ """根据开关选择不同的数据源获取股票列表"""
45
+ if USE_YFINANCE_VERSION:
46
+ return fetch_stock_us_spot_data_yfinance()
47
+ else:
48
+ return fetch_stock_us_spot_data_akshare()
49
+
50
+ def fetch_stock_us_spot_data_akshare():
51
+ """原始的 akshare 实现"""
52
+ if not USE_YFINANCE_VERSION:
53
+ # 定义重试间隔时间序列(秒)
54
+ retry_intervals = [10, 20, 60, 300, 600]
55
+ retry_index = 0 # 初始重试序号
56
+
57
+ while True:
58
+ try:
59
+ # 尝试获取API数据
60
+ symbols = ak.stock_us_spot_em()
61
+ return symbols # 成功获取数据后返回
62
+
63
+ except Exception as e:
64
+ print(f"Error fetching data: {e}")
65
+
66
+ # 获取当前重试等待时间
67
+ wait_time = retry_intervals[retry_index]
68
+ print(f"Retrying in {wait_time} seconds...")
69
+ time.sleep(wait_time) # 等待指定的秒数
70
+
71
+ # 更新重试索引,但不要超出重试时间列表的范围
72
+ retry_index = min(retry_index + 1, len(retry_intervals) - 1)
73
+ else:
74
+ print("Warning: akshare function called while using yfinance version")
75
+ return pd.DataFrame()
76
 
77
+ def fetch_stock_us_spot_data_yfinance():
78
+ """新的 yfinance 实现"""
79
+ try:
80
+ # 从本地CSV文件收集所有股票代码
81
+ all_symbols = set()
82
+
83
+ # 从各个指数CSV文件中提取股票代码
84
+ for df, name in [
85
+ (nasdaq_100_stocks, "NASDAQ-100"),
86
+ (dow_jones_stocks, "Dow Jones"),
87
+ (sp500_stocks, "S&P 500"),
88
+ (nasdaq_composite_stocks, "NASDAQ Composite")
89
+ ]:
90
+ if 'Symbol' in df.columns:
91
+ symbols_from_csv = df['Symbol'].dropna().astype(str).tolist()
92
+ all_symbols.update(symbols_from_csv)
93
+ elif 'Code' in df.columns:
94
+ symbols_from_csv = df['Code'].dropna().astype(str).tolist()
95
+ all_symbols.update(symbols_from_csv)
96
+
97
+ # 添加一些常见的ETF和热门股票
98
+ additional_symbols = [
99
+ # 主要ETF
100
+ 'SPY', 'QQQ', 'IWM', 'VTI', 'ARKK', 'TQQQ', 'SQQQ', 'SPXL',
101
+ # 热门科技股
102
+ 'AAPL', 'MSFT', 'GOOGL', 'GOOG', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
103
+ 'AMD', 'INTC', 'ORCL', 'CRM', 'ADBE', 'PYPL', 'UBER', 'LYFT',
104
+ # 中概股
105
+ 'BABA', 'JD', 'PDD', 'NIO', 'XPEV', 'LI', 'DIDI', 'TME',
106
+ # 其他热门股票
107
+ 'COST', 'WMT', 'JPM', 'BAC', 'XOM', 'CVX', 'PFE', 'JNJ', 'KO', 'PEP'
108
+ ]
109
+ all_symbols.update(additional_symbols)
110
+
111
+ # 创建DataFrame
112
+ symbols_list = sorted(list(all_symbols))
113
+ symbols_df = pd.DataFrame({
114
+ '代码': symbols_list,
115
+ '名称': [f'{symbol} Inc.' for symbol in symbols_list] # 简单的名称映射
116
+ })
117
+
118
+ print(f"Created symbols dataframe with {len(symbols_df)} symbols using yfinance version")
119
+ return symbols_df
120
+
121
+ except Exception as e:
122
+ print(f"Error creating symbols dataframe: {e}")
123
+ # 返回基本的fallback数据
124
+ fallback_symbols = [
125
+ 'AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
126
+ 'SPY', 'QQQ', 'IWM', 'VTI'
127
+ ]
128
+ return pd.DataFrame({
129
+ '代码': fallback_symbols,
130
+ '名称': [f'{symbol} Inc.' for symbol in fallback_symbols]
131
+ })
132
 
133
 
134
 
135
  async def fetch_stock_us_spot_data_with_retries_async():
136
+ """异步版本的股票数据获取,支持版本切换"""
137
+ if USE_YFINANCE_VERSION:
 
 
 
138
  try:
139
+ return await asyncio.to_thread(fetch_stock_us_spot_data_yfinance)
 
 
 
 
 
 
 
140
  except Exception as e:
141
+ print(f"Error in async yfinance fetch: {e}")
142
+ return pd.DataFrame()
143
+ else:
144
+ return await fetch_stock_us_spot_data_akshare_async()
145
+
146
+ async def fetch_stock_us_spot_data_akshare_async():
147
+ """原始的 akshare 异步实现"""
148
+ if not USE_YFINANCE_VERSION:
149
+ retry_intervals = [10, 20] # 减少重试次数
150
+ retry_index = 0
151
+ max_retries = 2 # 最多重试2次
152
+
153
+ for attempt in range(max_retries + 1):
154
+ try:
155
+ # 添加30秒超时
156
+ symbols = await asyncio.wait_for(
157
+ asyncio.to_thread(ak.stock_us_spot_em),
158
+ timeout=30.0
159
+ )
160
+ return symbols
161
+ except asyncio.TimeoutError:
162
+ print(f"Timeout error fetching data (attempt {attempt + 1}/{max_retries + 1})")
163
+ except Exception as e:
164
+ print(f"Error fetching data (attempt {attempt + 1}/{max_retries + 1}): {e}")
165
+
166
+ if attempt < max_retries:
167
+ wait_time = retry_intervals[min(retry_index, len(retry_intervals) - 1)]
168
+ print(f"Retrying in {wait_time} seconds...")
169
+ await asyncio.sleep(wait_time)
170
+ retry_index += 1
171
 
172
+ # 如果所有重试都失败,返回空数据
173
+ print("All retries failed, returning empty data")
174
+ return pd.DataFrame()
175
+ else:
176
+ print("Warning: akshare async function called while using yfinance version")
177
+ return pd.DataFrame()
 
 
 
178
 
179
  symbols = None
180
 
181
+ def create_fallback_symbols():
182
+ """创建fallback符号数据,用于测试"""
183
+ fallback_symbols = [
184
+ 'AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
185
+ 'SPY', 'QQQ', 'IWM', 'VTI'
186
+ ]
187
+ return pd.DataFrame({
188
+ '代码': fallback_symbols,
189
+ '名称': [f'{symbol} Inc.' for symbol in fallback_symbols]
190
+ })
191
+
192
  async def fetch_symbols():
193
  global symbols
194
  try:
 
217
  global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX
218
  try:
219
  print("Starting stock indices update...")
220
+
221
+ if USE_YFINANCE_VERSION:
222
+ print("Updating indices using yfinance...")
223
+ # 使用 yfinance 更新指数数据
224
+ from datetime import datetime, timedelta
225
+
226
+ # 计算日期范围
227
+ end_date = datetime.now()
228
+ start_date = end_date - timedelta(weeks=8)
229
+
230
+ # 定义指数映射
231
+ indices = {
232
+ '^GSPC': 'INX', # S&P 500
233
+ '^DJI': 'DJI', # Dow Jones
234
+ '^IXIC': 'IXIC', # NASDAQ Composite
235
+ '^NDX': 'NDX' # NASDAQ 100
236
+ }
237
+
238
+ for yf_symbol, var_name in indices.items():
239
+ try:
240
+ ticker = yf.Ticker(yf_symbol)
241
+ hist_data = ticker.history(start=start_date, end=end_date)
242
+
243
+ if not hist_data.empty:
244
+ # 转换为与akshare相同的格式
245
+ formatted_data = pd.DataFrame({
246
+ 'date': hist_data.index.strftime('%Y-%m-%d'),
247
+ '开盘': hist_data['Open'].values,
248
+ '收盘': hist_data['Close'].values,
249
+ '最高': hist_data['High'].values,
250
+ '最低': hist_data['Low'].values,
251
+ '成交量': hist_data['Volume'].values,
252
+ '成交额': (hist_data['Close'] * hist_data['Volume']).values
253
+ })
254
+
255
+ # 设置全局变量
256
+ if var_name == 'INX':
257
+ index_us_stock_index_INX = formatted_data
258
+ elif var_name == 'DJI':
259
+ index_us_stock_index_DJI = formatted_data
260
+ elif var_name == 'IXIC':
261
+ index_us_stock_index_IXIC = formatted_data
262
+ elif var_name == 'NDX':
263
+ index_us_stock_index_NDX = formatted_data
264
+
265
+ print(f"Successfully updated {var_name}: {len(formatted_data)} records")
266
+ else:
267
+ print(f"No data received for {yf_symbol}")
268
+
269
+ except Exception as e:
270
+ print(f"Error fetching {yf_symbol}: {e}")
271
+ else:
272
+ print("Updating indices using akshare...")
273
+ # 使用 akshare 更新指数数据
274
+ index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX")
275
+ index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI")
276
+ index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC")
277
+ index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX")
278
+
279
  print("Stock indices updated successfully")
280
  except Exception as e:
281
  print(f"Error updating stock indices: {e}")
 
286
  # 程序开始时不立即更新,而是延迟启动
287
  def start_indices_update():
288
  """延迟启动股票指数更新,避免阻塞应用启动"""
289
+ threading.Timer(5, update_stock_indices).start() # 5秒后开始第一次更新
290
 
291
  # 延迟启动股票指数更新
292
  start_indices_update()
 
364
  for attempt in range(max_retries):
365
  try:
366
  # 缓存无效或不存在,从yfinance获取新数据
367
+ if USE_YFINANCE_VERSION:
368
+ stock_data = yf.download(
369
+ symbol,
370
+ period='1d',
371
+ interval='5m',
372
+ progress=False, # 禁用进度条
373
+ timeout=10 # 设置超时时间
374
+ )
375
+ else:
376
+ # 使用akshare获取数据的逻辑
377
+ ticker = ak.stock_us_hist(symbol=symbol, period="daily", start_date="20240101", end_date="20240201")
378
+ stock_data = ticker if not ticker.empty else pd.DataFrame()
379
 
380
  if stock_data.empty:
381
  print(f"Warning: Empty data received for {symbol}, attempt {attempt + 1}/{max_retries}")
 
426
 
427
  while retry_count <= retries and len(symbol) != 0: # 无限循环重试
428
  try:
429
+ # 根据版本开关选择不同的API
430
+ if USE_YFINANCE_VERSION:
431
+ # 使用 yfinance 获取数据
432
+ ticker = yf.Ticker(symbol)
433
+ # 将日期格式转换为 yfinance 期望的格式 (YYYY-MM-DD)
434
+ yf_start_date = datetime.strptime(start_date, "%Y%m%d").strftime("%Y-%m-%d")
435
+ yf_end_date = datetime.strptime(end_date, "%Y%m%d").strftime("%Y-%m-%d")
436
+
437
+ stock_hist_df = ticker.history(start=yf_start_date, end=yf_end_date)
438
+
439
+ if not stock_hist_df.empty:
440
+ # 转换为与akshare相同的格式
441
+ stock_hist_df = stock_hist_df.reset_index()
442
+ stock_hist_df = pd.DataFrame({
443
+ 'date': stock_hist_df['Date'].dt.strftime('%Y-%m-%d'),
444
+ '开盘': stock_hist_df['Open'],
445
+ '收盘': stock_hist_df['Close'],
446
+ '最高': stock_hist_df['High'],
447
+ '最低': stock_hist_df['Low'],
448
+ '成交量': stock_hist_df['Volume'],
449
+ '成交额': stock_hist_df['Close'] * stock_hist_df['Volume'],
450
+ '振幅': 0, # yfinance没有直接提供,设为0
451
+ '涨跌幅': 0, # 可以计算,但这里简化为0
452
+ '涨跌额': 0, # 可以计算,但这里简化为0
453
+ '换手率': 0 # yfinance没有直接提供,设为0
454
+ })
455
+ else:
456
+ stock_hist_df = None
457
+ else:
458
+ # 使用 akshare 获取数据
459
+ stock_hist_df = ak.stock_us_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust="")
460
+
461
+ if stock_hist_df is None or stock_hist_df.empty: # 检查是否为空数据
462
  # print(f"No data for {symbol} on {news_date}.")
463
  stock_hist_df = None # 将 DataFrame 设置为 None
464
  break
us_stock_yfinance.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ import pandas as pd
4
+ from datetime import datetime, timedelta
5
+ import time # 导入标准库的 time 模块
6
+
7
+ import os
8
+
9
+ import requests
10
+ import threading
11
+ import asyncio
12
+
13
+ import yfinance as yf
14
+
15
+
16
# Configure the root logger once, at import time.
logging.basicConfig(level=logging.INFO)

# Directory containing this module; the constituent CSVs are resolved against it.
base_dir = os.path.dirname(os.path.abspath(__file__))

# NASDAQ-100 constituents.
nasdaq_100_path = os.path.join(base_dir, './model/nasdaq100.csv')
# Dow Jones Industrial Average constituents.
dow_jones_path = os.path.join(base_dir, './model/dji.csv')
# S&P 500 constituents.
sp500_path = os.path.join(base_dir, './model/sp500.csv')
# NASDAQ Composite constituents.
nasdaq_composite_path = os.path.join(base_dir, './model/nasdaq_all.csv')

# Constituent tables are loaded eagerly so later lookups are pure in-memory work.
nasdaq_100_stocks = pd.read_csv(nasdaq_100_path)
dow_jones_stocks = pd.read_csv(dow_jones_path)
sp500_stocks = pd.read_csv(sp500_path)
nasdaq_composite_stocks = pd.read_csv(nasdaq_composite_path)
32
+
33
+
34
def fetch_stock_us_spot_data_with_retries():
    """Build the symbol table from the local constituent CSVs.

    Tickers are collected from the ``Symbol`` column (or ``Code`` when
    ``Symbol`` is absent) of every constituent table, merged with a fixed list
    of popular ETFs and stocks, de-duplicated and sorted.

    Returns:
        A DataFrame with columns ``代码`` (ticker) and ``名称`` (a placeholder
        name of the form ``"<ticker> Inc."``).  If anything goes wrong a small
        hard-coded fallback table with the same columns is returned instead.
    """

    def _as_table(codes):
        # The display name is only a placeholder derived from the ticker.
        return pd.DataFrame({
            '代码': codes,
            '名称': [f'{symbol} Inc.' for symbol in codes],
        })

    try:
        collected = set()

        constituent_tables = (
            nasdaq_100_stocks,
            dow_jones_stocks,
            sp500_stocks,
            nasdaq_composite_stocks,
        )
        for table in constituent_tables:
            # Prefer 'Symbol'; fall back to 'Code'; skip tables with neither.
            ticker_column = next(
                (column for column in ('Symbol', 'Code') if column in table.columns),
                None,
            )
            if ticker_column is not None:
                collected.update(table[ticker_column].dropna().astype(str).tolist())

        # Well-known tickers that should resolve even if missing from the CSVs.
        collected.update([
            # Major ETFs
            'SPY', 'QQQ', 'IWM', 'VTI', 'ARKK', 'TQQQ', 'SQQQ', 'SPXL',
            # Popular technology stocks
            'AAPL', 'MSFT', 'GOOGL', 'GOOG', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
            'AMD', 'INTC', 'ORCL', 'CRM', 'ADBE', 'PYPL', 'UBER', 'LYFT',
            # US-listed Chinese companies
            'BABA', 'JD', 'PDD', 'NIO', 'XPEV', 'LI', 'DIDI', 'TME',
            # Other popular stocks
            'COST', 'WMT', 'JPM', 'BAC', 'XOM', 'CVX', 'PFE', 'JNJ', 'KO', 'PEP',
        ])

        symbols_df = _as_table(sorted(collected))
        print(f"Created symbols dataframe with {len(symbols_df)} symbols")
        return symbols_df

    except Exception as e:
        print(f"Error creating symbols dataframe: {e}")
        # Minimal fallback so callers still receive a usable table.
        return _as_table([
            'AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
            'SPY', 'QQQ', 'IWM', 'VTI',
        ])
89
+
90
+
91
async def fetch_stock_us_spot_data_with_retries_async():
    """Run ``fetch_stock_us_spot_data_with_retries`` in a worker thread.

    Returns:
        The symbol table produced by the synchronous function, or an empty
        DataFrame if running it raised an exception.
    """
    try:
        table = await asyncio.to_thread(fetch_stock_us_spot_data_with_retries)
    except Exception as e:
        print(f"Error in async fetch: {e}")
        return pd.DataFrame()
    return table
98
+
99
+
100
# Symbol table shared by the lookup helpers; stays None until fetch_symbols() runs.
symbols = None


async def fetch_symbols():
    """Populate the module-level ``symbols`` table.

    After this coroutine finishes ``symbols`` is always a DataFrame: the loaded
    table on success, or an empty DataFrame when loading failed or produced no
    rows.
    """
    global symbols
    try:
        print("Starting symbols initialization...")
        loaded = await fetch_stock_us_spot_data_with_retries_async()
        if loaded is None or loaded.empty:
            print("Symbols initialization failed, using empty dataset")
            loaded = pd.DataFrame()
        else:
            print(f"Symbols initialized successfully: {len(loaded)} symbols loaded")
        symbols = loaded
    except Exception as e:
        print(f"Error in fetch_symbols: {e}")
        symbols = pd.DataFrame()
    finally:
        print("Symbols initialization completed")
118
+
119
+
120
# Latest cached history of each US index; filled in by update_stock_indices().
index_us_stock_index_INX = None
index_us_stock_index_DJI = None
index_us_stock_index_IXIC = None
index_us_stock_index_NDX = None


def update_stock_indices():
    """Refresh the cached index histories from yfinance and re-arm the timer.

    Downloads the last 8 weeks of daily data for the S&P 500, Dow Jones,
    NASDAQ Composite and NASDAQ-100 and stores each as a DataFrame with the
    columns ``date``, ``开盘``, ``收盘``, ``最高``, ``最低``, ``成交量`` and
    ``成交额`` (computed as close * volume) in the matching
    ``index_us_stock_index_*`` global.  An index that cannot be fetched is
    stored as an empty DataFrame.

    Whether or not the refresh succeeded, the function schedules itself to run
    again in 12 hours.
    """
    global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX
    try:
        print("Starting stock indices update using yfinance...")

        # Window: the 8 weeks ending now.
        end_date = datetime.now()
        start_date = end_date - timedelta(weeks=8)

        # yfinance ticker -> suffix of the module-level cache variable.
        indices = {
            '^GSPC': 'INX',   # S&P 500
            '^DJI': 'DJI',    # Dow Jones
            '^IXIC': 'IXIC',  # NASDAQ Composite
            '^NDX': 'NDX',    # NASDAQ 100
        }

        results = {}
        for yf_symbol, var_name in indices.items():
            # One failing index must not prevent the others from updating.
            try:
                hist_data = yf.Ticker(yf_symbol).history(start=start_date, end=end_date)
                if hist_data.empty:
                    print(f"No data received for {yf_symbol}")
                    results[var_name] = pd.DataFrame()
                    continue

                # Same column layout as the akshare-based implementation.
                formatted_data = pd.DataFrame({
                    'date': hist_data.index.strftime('%Y-%m-%d'),
                    '开盘': hist_data['Open'].values,
                    '收盘': hist_data['Close'].values,
                    '最高': hist_data['High'].values,
                    '最低': hist_data['Low'].values,
                    '成交量': hist_data['Volume'].values,
                    '成交额': (hist_data['Close'] * hist_data['Volume']).values,
                })
                results[var_name] = formatted_data
                print(f"Successfully fetched {var_name} data: {len(formatted_data)} records")
            except Exception as e:
                print(f"Error fetching {yf_symbol}: {e}")
                results[var_name] = pd.DataFrame()

        index_us_stock_index_INX = results.get('INX', pd.DataFrame())
        index_us_stock_index_DJI = results.get('DJI', pd.DataFrame())
        index_us_stock_index_IXIC = results.get('IXIC', pd.DataFrame())
        index_us_stock_index_NDX = results.get('NDX', pd.DataFrame())

        print("Stock indices updated successfully using yfinance")

    except Exception as e:
        print(f"Error updating stock indices: {e}")

    # Re-arm the 12-hour refresh.  The timer thread is marked as a daemon so a
    # pending refresh never keeps the interpreter alive at shutdown.
    refresh_timer = threading.Timer(12 * 60 * 60, update_stock_indices)
    refresh_timer.daemon = True
    refresh_timer.start()
185
+
186
# The first refresh is deferred instead of running synchronously at import.
def start_indices_update():
    """Schedule the first index refresh 5 seconds from now.

    Deferring the refresh keeps the network download from blocking application
    start-up.  The timer thread is a daemon so that it cannot prevent the
    process from exiting.
    """
    startup_timer = threading.Timer(5, update_stock_indices)
    startup_timer.daemon = True
    startup_timer.start()


# Kick off the deferred refresh as soon as the module is imported.
start_indices_update()
193
+
194
+
195
# Translation table from the akshare-style Chinese column headers to the
# English names used by the rest of the pipeline.
column_mapping = dict([
    ('日期', 'date'),
    ('开盘', 'open'),
    ('收盘', 'close'),
    ('最高', 'high'),
    ('最低', 'low'),
    ('成交量', 'volume'),
    ('成交额', 'amount'),
    ('振幅', 'amplitude'),
    ('涨跌幅', 'price_change_percentage'),
    ('涨跌额', 'price_change_amount'),
    ('换手率', 'turnover_rate'),
])

# Canonical column order of every history frame returned by this module.
standard_columns = [
    'date',
    'open',
    'close',
    'high',
    'low',
    'volume',
    'amount',
]
212
+
213
+
214
# Symbol lookup helper.
def find_stock_entry(stock_code):
    """Resolve a ticker to the code stored in the global ``symbols`` table.

    Matching is case-insensitive and tried in this order:

    1. an exact match on the ``代码`` column;
    2. a code that ends with ``"." + stock_code`` (presumably the
       market-prefixed akshare form such as ``105.AAPL`` -- kept for backward
       compatibility with such tables);
    3. otherwise the upper-cased input is returned unchanged, on the
       assumption that it is already a valid ticker.

    A plain suffix match is deliberately not used: with bare tickers it would
    resolve ``T`` to any symbol that merely ends in ``T``.

    Args:
        stock_code: Ticker to look up.

    Returns:
        The matching code, the upper-cased input when nothing matches, or
        ``""`` when ``stock_code`` is ``None``/empty or the symbol table has
        not been loaded.
    """
    if symbols is None or symbols.empty:
        print("Warning: symbols data is empty")
        return ""
    if stock_code is None:
        return ""

    wanted = str(stock_code).strip().upper()
    if not wanted:
        return ""

    try:
        codes = symbols['代码'].dropna().astype(str)
        normalized = codes.str.upper()

        exact = codes[normalized == wanted]
        if not exact.empty:
            return str(exact.iloc[0])

        prefixed = codes[normalized.str.endswith('.' + wanted)]
        if not prefixed.empty:
            return str(prefixed.iloc[0])

        # Not in the table: assume the caller passed a valid ticker.
        return wanted
    except Exception as e:
        print(f"Error in find_stock_entry: {e}")
        return wanted
231
+
232
+
233
def reduce_columns(df, columns_to_keep):
    """Return ``df`` restricted to ``columns_to_keep`` (plain ``df[...]`` selection)."""
    selection = df[columns_to_keep]
    return selection
235
+
236
+
237
# Price cache: symbol -> (price, time the price was fetched).
_price_cache = {}


def get_last_minute_stock_price(symbol: str, max_retries=3, cache_minutes=30) -> float:
    """Return the latest price of ``symbol``, using a short-lived cache.

    The price is taken from ``Ticker.info`` (``regularMarketPrice``, then
    ``currentPrice``); if neither is available, the most recent non-NaN close
    of today's 1-minute history is used.  NaN prices are treated as missing so
    they are never cached or returned.

    Args:
        symbol: Ticker to query.  A falsy value or the sentinel
            ``"NONE_SYMBOL_FOUND"`` yields ``-1.0`` immediately.
        max_retries: Maximum number of fetch attempts; one second is slept
            between attempts.
        cache_minutes: How long a fetched price stays valid in the cache.

    Returns:
        The price as a float, or ``-1.0`` when no price could be obtained.
    """
    if not symbol or symbol == "NONE_SYMBOL_FOUND":
        return -1.0

    current_time = datetime.now()

    # Serve from the cache while the entry is still fresh.
    cached = _price_cache.get(symbol)
    if cached is not None:
        cached_price, cached_time = cached
        if current_time - cached_time < timedelta(minutes=cache_minutes):
            return cached_price

    for attempt in range(max_retries):
        try:
            ticker = yf.Ticker(symbol)
            info = ticker.info

            current_price = info.get('regularMarketPrice') or info.get('currentPrice')

            if current_price is None:
                # Fall back to the latest valid close of the intraday history.
                hist = ticker.history(period='1d', interval='1m')
                if not hist.empty:
                    closes = hist['Close'].dropna()
                    if not closes.empty:
                        current_price = closes.iloc[-1]

            if current_price is not None and pd.notna(current_price):
                price = float(current_price)
                _price_cache[symbol] = (price, current_time)
                return price

            print(f"Warning: No price data for {symbol}, attempt {attempt + 1}/{max_retries}")

        except Exception as e:
            print(f"Error fetching price for {symbol}, attempt {attempt + 1}/{max_retries}: {str(e)}")

        # Back off before the next attempt, but not after the final one.
        if attempt < max_retries - 1:
            time.sleep(1)

    return -1.0
290
+
291
+
292
# History of a single stock.
def get_stock_history(symbol, news_date, retries=10):
    """Return roughly the last 60 days of daily history for ``symbol``.

    Note that the window always ends at the current time; ``news_date`` is
    only used in log messages.

    Args:
        symbol: Ticker.  A symbol without any digit is first resolved through
            ``find_stock_entry``; ``None`` or an empty string means "no
            symbol" and produces the zero-filled frame.
        news_date: Date of the news item, used for logging only.
        retries: Number of additional attempts after a failed download; two
            seconds are slept between attempts.

    Returns:
        A DataFrame with the columns listed in ``standard_columns``.  When no
        data could be obtained, a frame with one row per calendar day of the
        window and every value set to 0 is returned.
    """
    symbol = symbol or ""
    # Symbols that contain no digit are resolved against the symbol table.
    if symbol and not any(char.isdigit() for char in symbol):
        symbol = find_stock_entry(symbol) or ""

    end_date = datetime.now()
    start_date = end_date - timedelta(days=60)

    stock_hist_df = None
    retry_count = 0

    while symbol and retry_count <= retries:
        try:
            raw = yf.Ticker(symbol).history(start=start_date, end=end_date)

            if raw.empty:
                print(f"No data for {symbol} on {news_date}.")
                stock_hist_df = None
            else:
                # Same column layout as the akshare-based implementation; only
                # the columns that survive the final reindex are built.
                raw = raw.reset_index()
                stock_hist_df = pd.DataFrame({
                    'date': raw['Date'].dt.strftime('%Y-%m-%d'),
                    '开盘': raw['Open'],
                    '收盘': raw['Close'],
                    '最高': raw['High'],
                    '最低': raw['Low'],
                    '成交量': raw['Volume'],
                    '成交额': raw['Close'] * raw['Volume'],
                })
            # A completed request -- with or without rows -- ends the loop;
            # only exceptions are retried.
            break

        except Exception as e:
            print(f"Error {e} scraping data for {symbol} on {news_date}. Retrying...")
            retry_count += 1
            if retry_count <= retries:
                time.sleep(2)  # wait before the next attempt

    # No usable data: fall back to a zero-filled frame covering the window.
    if stock_hist_df is None or stock_hist_df.empty:
        date_range = pd.date_range(start=start_date, end=end_date)
        stock_hist_df = pd.DataFrame({
            'date': date_range.strftime('%Y-%m-%d'),
            '开盘': 0,
            '收盘': 0,
            '最高': 0,
            '最低': 0,
            '成交量': 0,
            '成交额': 0,
        })

    # Translate the headers and enforce the canonical column set and order.
    stock_hist_df = stock_hist_df.rename(columns=column_mapping)
    return stock_hist_df.reindex(columns=standard_columns)
372
+
373
+
374
# History of the index a stock belongs to.
def get_stock_index_history(symbol, news_date, force_index=0):
    """Return the last 8 weeks of history for the index ``symbol`` belongs to.

    Index selection:

    * ``force_index`` 1/2/3/4 selects NASDAQ-100 / Dow Jones / S&P 500 /
      NASDAQ Composite regardless of ``symbol``;
    * otherwise the first constituent table containing ``symbol`` decides, in
      the order NASDAQ-100, Dow Jones, S&P 500;
    * anything else (including ``None`` or an empty symbol) falls back to the
      NASDAQ Composite.

    The cached module-level history is used when available; if it is missing
    or empty the data is fetched from yfinance, and a zero-filled frame is
    used when that also fails.  The window always ends today; ``news_date`` is
    not used.

    Args:
        symbol: Ticker whose index should be returned.
        news_date: Unused; kept for interface compatibility.
        force_index: Optional override of the index selection (see above).

    Returns:
        A DataFrame with the columns listed in ``standard_columns`` whose
        ``date`` column is of datetime type.
    """
    forced_codes = {1: ".NDX", 2: ".DJI", 3: ".INX", 4: ".IXIC"}
    index_code = forced_codes.get(force_index)
    if index_code is None:
        if symbol and symbol in nasdaq_100_stocks['Symbol'].values:
            index_code = ".NDX"
        elif symbol and symbol in dow_jones_stocks['Symbol'].values:
            index_code = ".DJI"
        elif symbol and symbol in sp500_stocks['Symbol'].values:
            index_code = ".INX"
        else:
            index_code = ".IXIC"

    cached_by_code = {
        ".NDX": index_us_stock_index_NDX,
        ".DJI": index_us_stock_index_DJI,
        ".INX": index_us_stock_index_INX,
        ".IXIC": index_us_stock_index_IXIC,
    }
    index_data = cached_by_code[index_code]

    current_date = datetime.now()
    start_date = (current_date - timedelta(weeks=8)).strftime("%Y-%m-%d")
    end_date = current_date.strftime("%Y-%m-%d")

    if index_data is None or index_data.empty:
        # Cache not populated (yet): fetch this index directly.
        print(f"Index data for {index_code} is empty, fetching real-time data...")
        index_data = None
        try:
            yf_symbol_map = {
                '.INX': '^GSPC',
                '.DJI': '^DJI',
                '.IXIC': '^IXIC',
                '.NDX': '^NDX',
            }
            yf_symbol = yf_symbol_map.get(index_code, '^IXIC')
            hist_data = yf.Ticker(yf_symbol).history(start=start_date, end=end_date)

            if not hist_data.empty:
                index_data = pd.DataFrame({
                    'date': hist_data.index.strftime('%Y-%m-%d'),
                    '开盘': hist_data['Open'].values,
                    '收盘': hist_data['Close'].values,
                    '最高': hist_data['High'].values,
                    '最低': hist_data['Low'].values,
                    '成交量': hist_data['Volume'].values,
                    '成交额': (hist_data['Close'] * hist_data['Volume']).values,
                })
        except Exception as e:
            print(f"Error fetching real-time index data: {e}")

        if index_data is None:
            # Nothing available: zero-filled placeholder covering the window.
            date_range = pd.date_range(start=start_date, end=end_date)
            index_data = pd.DataFrame({
                'date': date_range.strftime('%Y-%m-%d'),
                '开盘': 0, '收盘': 0, '最高': 0, '最低': 0, '成交量': 0, '成交额': 0,
            })
    else:
        # Work on a copy so the shared module-level cache is never mutated.
        index_data = index_data.copy()

    index_data['date'] = pd.to_datetime(index_data['date'])

    # Keep only the rows inside the 8-week window.
    in_window = (
        (index_data['date'] >= pd.Timestamp(start_date))
        & (index_data['date'] <= pd.Timestamp(end_date))
    )
    index_hist_df = index_data[in_window]

    # Translate the headers and enforce the canonical column set and order.
    index_hist_df = index_hist_df.rename(columns=column_mapping)
    return index_hist_df.reindex(columns=standard_columns)
454
+
455
+
456
def find_stock_codes_or_names(entities):
    """Map recognised entities to stock tickers.

    An entity matches when its upper-cased text is a known ticker, or when it
    occurs as a whole word/phrase (case-insensitive) inside a company name
    from one of the constituent tables.

    The name-to-ticker map is built row by row from each table, so a company
    name is always paired with the ticker from its own row.

    Args:
        entities: Iterable of ``(text, entity_type)`` pairs.  The entity type
            is ignored; blank texts are skipped.

    Returns:
        A list of matching tickers in no particular order, or
        ``['NONE_SYMBOL_FOUND']`` when nothing matched.
    """
    # (table, column holding the company name)
    constituent_tables = (
        (nasdaq_100_stocks, 'Name'),
        (nasdaq_composite_stocks, 'Name'),
        (sp500_stocks, 'Security'),
        (dow_jones_stocks, 'Company'),
    )

    known_symbols = set()
    name_to_symbol = {}
    for table, name_column in constituent_tables:
        known_symbols.update(table['Symbol'].dropna().astype(str))
        pairs = table[['Symbol', name_column]].dropna()
        for symbol, name in zip(pairs['Symbol'].astype(str), pairs[name_column].astype(str)):
            # The first table that lists a name wins.
            name_to_symbol.setdefault(name.lower(), symbol)

    stock_codes = set()
    for entity, _entity_type in entities:
        entity_text = str(entity).strip()
        if not entity_text:
            # An empty pattern would match every company name.
            continue

        entity_upper = entity_text.upper()
        if entity_upper in known_symbols:
            stock_codes.add(entity_upper)

        # Whole-word match so that e.g. "app" does not match "apple".
        pattern = re.compile(rf'\b{re.escape(entity_text.lower())}\b')
        stock_codes.update(
            symbol.upper()
            for name, symbol in name_to_symbol.items()
            if pattern.search(name)
        )

    if not stock_codes:
        return ['NONE_SYMBOL_FOUND']
    return list(stock_codes)
499
+
500
+
501
def process_history(stock_history, target_date, history_days=30, following_days=3):
    """Split a history frame into the rows up to and the rows after a date.

    The anchor row is the latest row whose ``date`` is on or before
    ``target_date``.  The caller's frame is not modified.

    Args:
        stock_history: Frame with a ``date`` column followed by value columns.
        target_date: Anchor date; anything ``pandas.to_datetime`` accepts.
        history_days: Number of rows to return up to and including the anchor.
        following_days: Number of rows to return after the anchor.

    Returns:
        A ``(previous_rows, following_rows)`` tuple.  Each frame has exactly
        the requested number of rows (padded by ``handle_insufficient_data``
        when the history is too short) and at most the first six value
        columns.  If the input is empty, has no ``date`` column, or contains
        no row on or before ``target_date``, both frames come from
        ``create_empty_data``.
    """
    if stock_history.empty or 'date' not in stock_history.columns:
        return create_empty_data(history_days), create_empty_data(following_days)

    target_date = pd.to_datetime(target_date)

    # assign() returns a new frame, so the caller's data is left untouched.
    ordered = (
        stock_history
        .assign(date=pd.to_datetime(stock_history['date']))
        .sort_values('date')
    )

    # Position of the last row on or before the target date.  Counting on the
    # sorted frame is purely positional, so it also works with duplicate or
    # non-default index labels.
    target_pos = int((ordered['date'] <= target_date).sum()) - 1
    if target_pos < 0:
        return create_empty_data(history_days), create_empty_data(following_days)

    # History window, anchor row included.
    start_pos = max(0, target_pos - history_days + 1)
    previous_rows = ordered.iloc[start_pos:target_pos + 1].drop(columns=['date'])

    # Rows strictly after the anchor.
    following_rows = ordered.iloc[target_pos + 1:target_pos + following_days + 1].drop(columns=['date'])

    # Pad both windows to the requested length.
    previous_rows = handle_insufficient_data(previous_rows, history_days)
    following_rows = handle_insufficient_data(following_rows, following_days)

    return previous_rows.iloc[:, :6], following_rows.iloc[:, :6]
541
+
542
+
543
def create_empty_data(days):
    """Return a placeholder frame of ``days`` rows with every value set to -1.

    The six columns are ``开盘``, ``收盘``, ``最高``, ``最低``, ``成交量`` and
    ``成交额``, in that order.
    """
    placeholder_columns = ('开盘', '收盘', '最高', '最低', '成交量', '成交额')
    return pd.DataFrame({column: [-1] * days for column in placeholder_columns})
552
+
553
+
554
def handle_insufficient_data(data, required_days):
    """Pad ``data`` at the top with -1 rows until it has ``required_days`` rows.

    The padding rows use the same columns as ``data``, so the real rows keep
    their values and no NaN-filled extra columns are introduced.

    Args:
        data: Frame to pad.
        required_days: Minimum number of rows the result must have.

    Returns:
        ``data`` itself when it already has enough rows; otherwise a new frame
        with the padding rows first, followed by the rows of ``data``, and a
        fresh 0..n-1 index.
    """
    current_rows = len(data)
    if current_rows >= required_days:
        return data

    missing_rows = required_days - current_rows
    if len(data.columns) == 0:
        # No column layout to mirror: fall back to the default placeholder.
        padding = create_empty_data(missing_rows)
    else:
        padding = pd.DataFrame(-1, index=range(missing_rows), columns=data.columns)

    if current_rows == 0:
        return padding
    return pd.concat([padding, data]).reset_index(drop=True)
561
+
562
+
563
if __name__ == "__main__":
    # Manual smoke test of the public helpers (requires network access).
    # The symbol table must be loaded first; otherwise find_stock_entry()
    # returns "" and every history below is only the zero-filled fallback.
    asyncio.run(fetch_symbols())

    result = find_stock_entry('AAPL')
    print(f"find_stock_entry: {result}")

    stock_history = get_stock_history('AAPL', '20240214')
    print(f"get_stock_history: {stock_history}")

    index_history = get_stock_index_history('AAPL', '20240214')
    print(f"get_stock_index_history: {index_history}")

    result = find_stock_codes_or_names([('苹果', 'ORG'), ('苹果公司', 'ORG')])
    print(f"find_stock_codes_or_names: {result}")

    result = process_history(get_stock_history('AAPL', '20240214'), '20240214')
    print(f"process_history: {result}")

    result = process_history(get_stock_index_history('AAPL', '20240214'), '20240214')
    print(f"process_history: {result}")