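# AI Li Ao assistant: a RAG pipeline over a corpus of Li Ao's books.
# Flow: bge-m3 embeddings -> FAISS HNSW retrieval -> SiliconFlow
# bge-reranker-v2-m3 reranking -> DeepSeek R1 (via OpenRouter) streamed to a Gradio UI.
# Likely dependencies (package names only; versions are assumptions):
#   pip install gradio requests numpy torch tqdm sentence-transformers faiss-cpu \
#       langchain langchain-community langchain-openai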
import os
import gradio as gr
import requests
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
import numpy as np
import faiss
from collections import deque
from langchain_core.embeddings import Embeddings
import threading
import queue
from langchain_core.messages import HumanMessage, AIMessage
from sentence_transformers import SentenceTransformer
import pickle
import torch
import time
from tqdm import tqdm
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Required API keys, read from environment variables
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "")
if not os.environ["OPENROUTER_API_KEY"]:
    raise ValueError("OPENROUTER_API_KEY 未设置")
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not SILICONFLOW_API_KEY:
    raise ValueError("SILICONFLOW_API_KEY 未设置")

# SiliconFlow rerank API endpoint
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/rerank"
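# The rerank endpoint is assumed to return the common rerank-API shape:
# {"results": [{"index": int, "relevance_score": float}, ...]};
# rerank_documents below relies on exactly those fields.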

# Custom embedding class with a thread-safe query cache
class SentenceTransformerEmbeddings(Embeddings):
    def __init__(self, model_name="BAAI/bge-m3"):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=device)
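        # bge-m3 dense embeddings are 1024-dimensional; with normalize_embeddings=True
        # (used in encode() below), inner-product search equals cosine similarity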
        self.batch_size = 32  # encode() batch size, kept small for low-memory machines
        self.query_cache = {}
        self.cache_lock = threading.Lock()

    def embed_documents(self, texts):
        # LangChain's Embeddings contract passes plain strings, but build_hnsw_index
        # passes Document objects; accept both
        embeddings_list = []
        batch_size = 1000  # encode the corpus in slices of 1000 texts to limit peak memory
        total_chunks = len(texts)
        logger.info(f"生成嵌入,文档数: {total_chunks}")
        with torch.no_grad():
            for i in tqdm(range(0, total_chunks, batch_size), desc="生成嵌入"):
                batch = texts[i:i + batch_size]
                batch_texts = [t.page_content if hasattr(t, "page_content") else t for t in batch]
                batch_emb = self.model.encode(
                    batch_texts,
                    normalize_embeddings=True,
                    batch_size=self.batch_size
                )
                embeddings_list.append(batch_emb)
        embeddings_array = np.vstack(embeddings_list)
        np.save("embeddings.npy", embeddings_array)  # persist raw vectors for reuse/debugging
        return embeddings_array

    def embed_query(self, text):
        with self.cache_lock:
            if text in self.query_cache:
                return self.query_cache[text]
        with torch.no_grad():
            emb = self.model.encode([text], normalize_embeddings=True, batch_size=1)[0]
            with self.cache_lock:
                self.query_cache[text] = emb
                if len(self.query_cache) > 1000:  # cap the cache; evict the oldest entry (FIFO)
                    self.query_cache.pop(next(iter(self.query_cache)))
        return emb

# Rerank: second-stage scoring of the FAISS candidates with bge-reranker-v2-m3
def rerank_documents(query, documents, top_n=15):
    try:
        # Truncate each candidate to 2048 chars to respect the reranker's input limit
        doc_texts = [doc.page_content[:2048] for doc in documents[:50]]
        headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}", "Content-Type": "application/json"}
        payload = {"model": "BAAI/bge-reranker-v2-m3", "query": query, "documents": doc_texts, "top_n": top_n}
        response = requests.post(SILICONFLOW_API_URL, headers=headers, json=payload, timeout=30)  # avoid hanging on a stalled request
        response.raise_for_status()
        result = response.json()
        reranked_docs = []
        for res in result["results"]:
            index = res["index"]
            score = res["relevance_score"]
            if index < len(documents):
                reranked_docs.append((documents[index], score))
        return sorted(reranked_docs, key=lambda x: x[1], reverse=True)[:top_n]
    except Exception as e:
        logger.error(f"重排序失败: {str(e)}")
        raise

# Build the HNSW index from the knowledge base
def build_hnsw_index(knowledge_base_path, index_path):
    loader = DirectoryLoader(knowledge_base_path, glob="*.txt", loader_cls=lambda path: TextLoader(path, encoding="utf-8"))
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    for doc in texts:
        doc.metadata["book"] = os.path.basename(doc.metadata.get("source", "未知来源")).replace(".txt", "")
    embeddings_array = embeddings.embed_documents(texts)
    dimension = embeddings_array.shape[1]
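    # IndexHNSWFlat(d, M): M=16 sets graph connectivity (links per node);
    # efConstruction=100 deepens the build-time candidate search for a better graph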
    index = faiss.IndexHNSWFlat(dimension, 16)
    index.hnsw.efConstruction = 100
    index.add(embeddings_array)
    # Pass metadatas so each chunk keeps its "book" field in the FAISS docstore
    # (otherwise doc.metadata.get("book") comes back empty at query time)
    vector_store = FAISS.from_embeddings(
        [(doc.page_content, embeddings_array[i]) for i, doc in enumerate(texts)],
        embeddings,
        metadatas=[doc.metadata for doc in texts],
    )
    vector_store.index = index
    vector_store.save_local(index_path)
    with open("chunks.pkl", "wb") as f:
        pickle.dump(texts, f)
    return vector_store, texts

# Initialize the embedding model, then build or load the index
embeddings = SentenceTransformerEmbeddings()
index_path = "faiss_index_hnsw_new"
knowledge_base_path = "knowledge_base"
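# First run builds the index from knowledge_base/*.txt; later runs load the
# persisted artifacts (faiss_index_hnsw_new/, chunks.pkl, embeddings.npy);
# delete them to force a rebuild.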

if not os.path.exists(index_path):
    vector_store, all_documents = build_hnsw_index(knowledge_base_path, index_path)
else:
    vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
    vector_store.index.hnsw.efSearch = 200  # raise efSearch well above the default (16) for better recall at query time
    with open("chunks.pkl", "rb") as f:
        all_documents = pickle.load(f)

# Initialize the LLM: DeepSeek R1 served via OpenRouter, with streaming enabled
llm = ChatOpenAI(
    model="deepseek/deepseek-r1:free",
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
    timeout=100,
    temperature=0.3,
    max_tokens=130000,
    streaming=True
)
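# Note: max_tokens caps the completion, not the context window; a value this
# close to the model's full context may be rejected by some providers
# (assumption: a smaller cap such as 8192 is safer).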

# Prompt template (intentionally kept in Chinese: it instructs the model to
# answer in Li Ao's style for a Chinese-language app)
prompt_template = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template="""  

    你是一个研究李敖的专家,根据用户提出的问题{question}、最近7轮对话历史{chat_history}以及从李敖相关书籍和评论中检索的至少10篇文本内容{context}回答问题。  

    在回答时,请注意以下几点:  

    - 结合李敖的写作风格和思想,筛选出与问题和对话历史最相关的检索内容,避免无关信息。  

    - 必须在回答中引用至少10篇不同的文本内容,引用格式为[引用: 文本序号],例如[引用: 1][引用: 2],并确保每篇文本在回答中都有明确使用。  

    - 在回答的末尾,必须以“引用文献”标题列出所有引用的文本序号及其内容摘要(每篇不超过50字)以及具体的书目信息(例如书名和章节),格式为:  

      - 引用文献:  

        1. [文本 1] 摘要... 出自:书名,第X页/章节。  

        2. [文本 2] 摘要... 出自:书名,第X页/章节。  

        (依此类推,至少10篇)  

    - 如果问题涉及李敖对某人或某事的评价,优先引用李敖的直接言论或文字,并说明出处。  

    - 回答应结构化、分段落,确保逻辑清晰,语言生动,类似李敖的犀利风格。  

    - 如果检索内容和历史不足以直接回答问题,可根据李敖的性格和观点推测其可能的看法,但需说明这是推测。  

    - 只能基于提供的知识库内容{context}和对话历史{chat_history}回答,不得引入外部信息。  

    - 对于列举类问题,控制在10个要点以内,并优先提供最相关项。  

    - 如果回答较长,结构化分段总结,分点作答控制在8个点以内。  

    - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。  

    - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。  

    - 你的回答应该综合多个相关知识库内容来回答,不能重复引用一个知识库内容。  

    - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。

    """
)
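# The template's input_variables are filled in generate_answer_thread via
# prompt_template.format(context=..., question=..., chat_history=...)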

# Conversation history management
class ConversationHistory:
    def __init__(self, max_length=7):  # deque(maxlen=7) auto-evicts the oldest turn
        self.history = deque(maxlen=max_length)

    def add_turn(self, question, answer):
        self.history.append((question, answer))

    def get_history(self):
        return [(q, a) for q, a in self.history]

# Per-user session state: history, streaming output queue, and stop flag
class UserSession:
    def __init__(self):
        self.conversation = ConversationHistory()
        self.output_queue = queue.Queue()
        self.stop_flag = threading.Event()

# Answer-generation worker (runs in a background thread per request)
def generate_answer_thread(question, session):
    stop_flag = session.stop_flag
    output_queue = session.output_queue
    conversation = session.conversation

    stop_flag.clear()
    try:
        # Log the incoming user question
        logger.info(f"用户问题: {question}")

        history_list = conversation.get_history()
        history_text = "\n".join([f"问: {q}\n答: {a}" for q, a in history_list[-4:]])  # include only the last 4 turns in the retrieval query
        query_with_context = f"{history_text}\n问题: {question}" if history_text else question

        # Time the query embedding; the original spawned a worker thread and joined
        # it immediately, which is equivalent to a direct call
        start = time.time()
        query_embedding = embeddings.embed_query(query_with_context)
        embed_time = time.time() - start

        if stop_flag.is_set():
            output_queue.put("生成已停止")
            return

        # First-stage retrieval: fetch the 50 nearest chunks from the HNSW index
        start = time.time()
        docs_with_scores = vector_store.similarity_search_with_score_by_vector(query_embedding, k=50)
        search_time = time.time() - start

        if stop_flag.is_set():
            output_queue.put("生成已停止")
            return

        # Second stage: rerank the candidates and keep the top 10
        initial_docs = [doc for doc, _ in docs_with_scores]
        start = time.time()
        reranked_docs_with_scores = rerank_documents(query_with_context, initial_docs)
        rerank_time = time.time() - start
        final_docs = [doc for doc, _ in reranked_docs_with_scores][:10]

        # Log the reranked snippets and their scores
        logger.info("重排序结果(最终保留的片段及其得分):")
        for i, (doc, score) in enumerate(reranked_docs_with_scores[:10], 1):
            logger.info(f"片段 {i}:")
            logger.info(f"  内容: {doc.page_content[:100]}...")
            logger.info(f"  来源: {doc.metadata.get('book', '未知来源')}")
            logger.info(f"  得分: {score:.4f}")

        context = "\n".join([f"[文本 {i+1}] {doc.page_content} (出处: {doc.metadata.get('book')})" for i, doc in enumerate(final_docs)])
        prompt = prompt_template.format(context=context, question=question, chat_history=history_text)

        # Prepend timing statistics to the answer
        timing_info = (
            f"处理时间统计:\n"
            f"- 嵌入时间: {embed_time:.2f} 秒\n"
            f"- 检索时间: {search_time:.2f} 秒\n"
            f"- 重排序时间: {rerank_time:.2f} 秒\n\n"
        )

        answer = timing_info
        output_queue.put(answer)  # push the timing info first

        # Stream the LLM answer
        start = time.time()
        for chunk in llm.stream([HumanMessage(content=prompt)]):
            if stop_flag.is_set():
                output_queue.put(answer + "\n(生成已停止)")
                return
            answer += chunk.content
            output_queue.put(answer)
        llm_time = time.time() - start
        answer += f"\n\n生成耗时: {llm_time:.2f} 秒"
        output_queue.put(answer)

        conversation.add_turn(question, answer)  # record the turn; the final answer was already pushed above

    except Exception as e:
        output_queue.put(f"Error: {str(e)}")

# Gradio callbacks: answer_question is a generator, so Gradio streams each
# yielded value into the output Textbox
def answer_question(question, session_state):
    if session_state is None:
        session_state = UserSession()
    
    thread = threading.Thread(target=generate_answer_thread, args=(question, session_state))
    thread.start()
    
    while thread.is_alive() or not session_state.output_queue.empty():
        try:
            output = session_state.output_queue.get(timeout=0.1)
            yield output, session_state
        except queue.Empty:
            continue

def stop_generation(session_state):
    if session_state:
        session_state.stop_flag.set()
    return "生成已停止"

def clear_conversation():
    # Reset by handing Gradio a brand-new UserSession via gr.State
    return "对话已清空", UserSession()

# Gradio UI
with gr.Blocks(title="AI李敖助手") as interface:
    gr.Markdown("### AI李敖助手")
    gr.Markdown("基于李敖163本相关书籍构建的知识库,支持上下文关联,记住最近7轮对话,输入问题以获取李敖风格的回答。")
    session_state = gr.State(value=None)
    question_input = gr.Textbox(label="问题")
    submit_button = gr.Button("提交")
    clear_button = gr.Button("新建对话")
    stop_button = gr.Button("停止")
    output_text = gr.Textbox(label="回答", interactive=False)

    submit_button.click(fn=answer_question, inputs=[question_input, session_state], outputs=[output_text, session_state])
    clear_button.click(fn=clear_conversation, inputs=None, outputs=[output_text, session_state])
    stop_button.click(fn=stop_generation, inputs=[session_state], outputs=output_text)

if __name__ == "__main__":
    interface.launch(share=True)  # share=True also exposes a temporary public *.gradio.live URL