import os
import gradio as gr
import requests
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
import numpy as np
import faiss
from collections import deque
from langchain_core.embeddings import Embeddings
import threading
import queue
from langchain_core.messages import HumanMessage, AIMessage
from sentence_transformers import SentenceTransformer
import pickle
import torch
import time
from tqdm import tqdm
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Read required API keys from the environment
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "")
if not os.environ["OPENROUTER_API_KEY"]:
    raise ValueError("OPENROUTER_API_KEY is not set")
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not SILICONFLOW_API_KEY:
    raise ValueError("SILICONFLOW_API_KEY is not set")
# SiliconFlow rerank API configuration
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/rerank"
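# The rerank endpoint is expected to accept a JSON payload of
# {"model", "query", "documents", "top_n"} and to return
# {"results": [{"index": <position in the submitted documents list>,
#               "relevance_score": <float>}, ...]},
# which is exactly the shape consumed by rerank_documents() below.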
# Custom embedding class with a thread-safe query cache
class SentenceTransformerEmbeddings(Embeddings):
def __init__(self, model_name="BAAI/bge-m3"):
device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = SentenceTransformer(model_name, device=device)
        self.batch_size = 32  # smaller encode batch to fit low memory
self.query_cache = {}
self.cache_lock = threading.Lock()
    def embed_documents(self, texts):
        # Note: `texts` is a list of LangChain Documents here (not plain strings,
        # as the Embeddings interface normally expects); build_hnsw_index() calls
        # it that way, and page_content is unpacked per batch below.
        embeddings_list = []
        batch_size = 1000  # outer chunking keeps peak memory low
        total_chunks = len(texts)
        logger.info(f"Generating embeddings for {total_chunks} chunks")
        with torch.no_grad():
            for i in tqdm(range(0, total_chunks, batch_size), desc="Embedding"):
                batch_texts = [text.page_content for text in texts[i:i + batch_size]]
batch_emb = self.model.encode(
batch_texts,
normalize_embeddings=True,
batch_size=self.batch_size
)
embeddings_list.append(batch_emb)
embeddings_array = np.vstack(embeddings_list)
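        # Side effect: dump the raw embedding matrix for inspection/reuse.
        # Nothing in this script reads embeddings.npy back; the FAISS index
        # and chunks.pkl are the artifacts actually loaded on startup.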
np.save("embeddings.npy", embeddings_array)
return embeddings_array
def embed_query(self, text):
with self.cache_lock:
if text in self.query_cache:
return self.query_cache[text]
with torch.no_grad():
emb = self.model.encode([text], normalize_embeddings=True, batch_size=1)[0]
with self.cache_lock:
self.query_cache[text] = emb
            if len(self.query_cache) > 1000:  # cap the cache size (FIFO eviction)
self.query_cache.pop(next(iter(self.query_cache)))
return emb
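# Design note: embed_query() serves the per-request path, so results are
# memoized under a lock; eviction is FIFO (Python dicts preserve insertion
# order), which is cheaper than a true LRU and adequate for a ~1000-entry cache.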
# Rerank retrieved chunks with SiliconFlow's bge-reranker
def rerank_documents(query, documents, top_n=15):
try:
        doc_texts = [(doc.page_content[:2048], doc.metadata.get("book", "unknown source")) for doc in documents[:50]]
headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}", "Content-Type": "application/json"}
payload = {"model": "BAAI/bge-reranker-v2-m3", "query": query, "documents": [text for text, _ in doc_texts], "top_n": top_n}
response = requests.post(SILICONFLOW_API_URL, headers=headers, json=payload)
response.raise_for_status()
result = response.json()
        reranked_docs = []
        for res in result["results"]:
            index = res["index"]
            score = res["relevance_score"]
            # index refers to the candidate list sent to the API (at most 50)
            if index < len(doc_texts):
                reranked_docs.append((documents[index], score))
        return sorted(reranked_docs, key=lambda x: x[1], reverse=True)[:top_n]
except Exception as e:
logger.error(f"重排序失败: {str(e)}")
raise
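# Hypothetical usage sketch (names are illustrative, not part of this app):
#   candidates = vector_store.similarity_search("李敖 胡适", k=50)
#   top = rerank_documents("李敖如何评价胡适?", candidates, top_n=10)
#   for doc, score in top:
#       print(score, doc.metadata.get("book"))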
# Build the HNSW index over the knowledge base
def build_hnsw_index(knowledge_base_path, index_path):
loader = DirectoryLoader(knowledge_base_path, glob="*.txt", loader_cls=lambda path: TextLoader(path, encoding="utf-8"))
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(documents)
for i, doc in enumerate(texts):
doc.metadata["book"] = os.path.basename(doc.metadata.get("source", "未知来源")).replace(".txt", "")
embeddings_array = embeddings.embed_documents(texts)
dimension = embeddings_array.shape[1]
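    # HNSW parameters: 16 is M (graph neighbors per node); a higher
    # efConstruction (100) spends more build time for a better graph.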
index = faiss.IndexHNSWFlat(dimension, 16)
index.hnsw.efConstruction = 100
index.add(embeddings_array)
    # Pass metadatas so each chunk keeps its "book" field after retrieval
    vector_store = FAISS.from_embeddings(
        [(doc.page_content, embeddings_array[i]) for i, doc in enumerate(texts)],
        embeddings,
        metadatas=[doc.metadata for doc in texts],
    )
vector_store.index = index
vector_store.save_local(index_path)
with open("chunks.pkl", "wb") as f:
pickle.dump(texts, f)
return vector_store, texts
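# chunks.pkl mirrors the exact chunk list that was embedded, so the chunk
# order (and hence the index-to-docstore mapping) survives restarts alongside
# the saved FAISS index.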
# Initialize the embedding model and load (or build) the index
embeddings = SentenceTransformerEmbeddings()
index_path = "faiss_index_hnsw_new"
knowledge_base_path = "knowledge_base"
if not os.path.exists(index_path):
vector_store, all_documents = build_hnsw_index(knowledge_base_path, index_path)
else:
vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
    vector_store.index.hnsw.efSearch = 200  # higher efSearch improves recall at some speed cost
with open("chunks.pkl", "rb") as f:
all_documents = pickle.load(f)
# Initialize the LLM (DeepSeek R1 via OpenRouter)
llm = ChatOpenAI(
model="deepseek/deepseek-r1:free",
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
timeout=100,
temperature=0.3,
max_tokens=130000,
streaming=True
)
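# Note: max_tokens is the completion cap forwarded to OpenRouter; streaming=True
# lets generate_answer_thread() forward partial chunks to the UI as they arrive.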
# Prompt template
prompt_template = PromptTemplate(
input_variables=["context", "question", "chat_history"],
template="""
You are an expert on Li Ao. Answer the user's question {question} using the last 7 turns of conversation history {chat_history} and at least 10 passages {context} retrieved from books by and about Li Ao.
When answering, observe the following:
- Drawing on Li Ao's writing style and thought, select the retrieved passages most relevant to the question and the conversation history, and leave out irrelevant material.
- Cite at least 10 distinct passages in the answer, using the format [citation: passage number], e.g. [citation: 1][citation: 2], and make sure every cited passage is explicitly used in the answer.
- At the end of the answer, under the heading "References", list every cited passage number with a summary (at most 50 characters each) and its bibliographic details (e.g. book title and chapter), in this format:
- References:
1. [Passage 1] Summary... Source: book title, page/chapter X.
2. [Passage 2] Summary... Source: book title, page/chapter X.
(and so on, at least 10 passages)
- If the question concerns Li Ao's assessment of a person or an event, prefer his direct statements or writings, and state their source.
- Structure the answer into paragraphs with clear logic and vivid language, in a sharp style reminiscent of Li Ao.
- If the retrieved content and the history are insufficient to answer directly, you may infer Li Ao's likely view from his character and opinions, but state clearly that it is an inference.
- Answer only from the provided knowledge base content {context} and conversation history {chat_history}; do not introduce outside information.
- For enumeration questions, keep to at most 10 points and lead with the most relevant items.
- If the answer runs long, summarize in structured paragraphs and keep bullet points to at most 8.
- For factual questions whose answer is very short, you may add one or two sentences of related information to enrich the reply.
- Choose a clear, readable format appropriate to the request and the content.
- Synthesize several relevant passages; do not cite the same passage repeatedly.
- Unless the user requests otherwise, answer in the same language as the question.
"""
)
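# The three template variables are filled in generate_answer_thread():
# {context} with the reranked passages, {question} with the raw user input,
# and {chat_history} with the recent Q/A turns.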
# Conversation history management
class ConversationHistory:
    def __init__(self, max_length=7):  # keep only the last 7 turns
self.history = deque(maxlen=max_length)
def add_turn(self, question, answer):
self.history.append((question, answer))
def get_history(self):
return [(q, a) for q, a in self.history]
# Per-user session state
class UserSession:
def __init__(self):
self.conversation = ConversationHistory()
self.output_queue = queue.Queue()
self.stop_flag = threading.Event()
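# One UserSession is held per Gradio session (via gr.State); the queue decouples
# the worker thread that streams LLM output from the UI loop that polls it.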
# Generate an answer (runs in a worker thread)
def generate_answer_thread(question, session):
stop_flag = session.stop_flag
output_queue = session.output_queue
conversation = session.conversation
stop_flag.clear()
try:
        # Log the user question
        logger.info(f"User question: {question}")
        history_list = conversation.get_history()
        history_text = "\n".join([f"Q: {q}\nA: {a}" for q, a in history_list[-4:]])  # only the last 4 turns go into the retrieval query
        query_with_context = f"{history_text}\nQuestion: {question}" if history_text else question
        # Generate the query embedding (timed). The original worker thread was
        # joined immediately, making it synchronous anyway, so call directly.
        start = time.time()
        query_embedding = embeddings.embed_query(query_with_context)
        embed_time = time.time() - start
if stop_flag.is_set():
output_queue.put("生成已停止")
return
        # Initial ANN retrieval
start = time.time()
docs_with_scores = vector_store.similarity_search_with_score_by_vector(query_embedding, k=50)
search_time = time.time() - start
if stop_flag.is_set():
output_queue.put("生成已停止")
return
        # Rerank the candidates
initial_docs = [doc for doc, _ in docs_with_scores]
start = time.time()
reranked_docs_with_scores = rerank_documents(query_with_context, initial_docs)
rerank_time = time.time() - start
final_docs = [doc for doc, _ in reranked_docs_with_scores][:10]
        # Log the rerank results
        logger.info("Rerank results (final chunks kept, with scores):")
        for i, (doc, score) in enumerate(reranked_docs_with_scores[:10], 1):
            logger.info(f"Chunk {i}:")
            logger.info(f"  Content: {doc.page_content[:100]}...")
            logger.info(f"  Source: {doc.metadata.get('book', 'unknown source')}")
            logger.info(f"  Score: {score:.4f}")
context = "\n".join([f"[文本 {i+1}] {doc.page_content} (出处: {doc.metadata.get('book')})" for i, doc in enumerate(final_docs)])
prompt = prompt_template.format(context=context, question=question, chat_history=history_text)
        # Prepend timing stats to the answer
        timing_info = (
            f"Processing time:\n"
            f"- Embedding: {embed_time:.2f} s\n"
            f"- Retrieval: {search_time:.2f} s\n"
            f"- Reranking: {rerank_time:.2f} s\n\n"
        )
answer = timing_info
        output_queue.put(answer)  # show the timing info first
        # Stream the LLM answer
start = time.time()
for chunk in llm.stream([HumanMessage(content=prompt)]):
if stop_flag.is_set():
                output_queue.put(answer + "\n(generation stopped)")
return
answer += chunk.content
output_queue.put(answer)
llm_time = time.time() - start
answer += f"\n\n生成耗时: {llm_time:.2f} 秒"
output_queue.put(answer)
conversation.add_turn(question, answer)
output_queue.put(answer)
except Exception as e:
output_queue.put(f"Error: {str(e)}")
# Gradio callbacks
def answer_question(question, session_state):
if session_state is None:
session_state = UserSession()
thread = threading.Thread(target=generate_answer_thread, args=(question, session_state))
thread.start()
while thread.is_alive() or not session_state.output_queue.empty():
try:
output = session_state.output_queue.get(timeout=0.1)
yield output, session_state
except queue.Empty:
continue
def stop_generation(session_state):
if session_state:
session_state.stop_flag.set()
return "生成已停止"
def clear_conversation():
return "对话已清空", UserSession()
# Gradio UI
with gr.Blocks(title="AI Li Ao Assistant") as interface:
    gr.Markdown("### AI Li Ao Assistant")
    gr.Markdown("Knowledge base built from 163 books by and about Li Ao. Conversation context is kept for the last 7 turns. Enter a question to get an answer in Li Ao's style.")
    session_state = gr.State(value=None)
    question_input = gr.Textbox(label="Question")
    submit_button = gr.Button("Submit")
    clear_button = gr.Button("New conversation")
    stop_button = gr.Button("Stop")
    output_text = gr.Textbox(label="Answer", interactive=False)
submit_button.click(fn=answer_question, inputs=[question_input, session_state], outputs=[output_text, session_state])
clear_button.click(fn=clear_conversation, inputs=None, outputs=[output_text, session_state])
stop_button.click(fn=stop_generation, inputs=[session_state], outputs=output_text)
if __name__ == "__main__":
    interface.launch(share=True)