File size: 12,411 Bytes
04e426f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
<<<<<<< HEAD
import torch
print(torch.__version__)  # 如 2.4.0+cu118
print(torch.cuda.is_available())  # 应返回 True
print(torch.cuda.get_device_name(0))  # 应返回 GPU 型号
=======
import os
import gradio as gr
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain_core.embeddings import Embeddings
from langchain.prompts import PromptTemplate
import requests
import numpy as np
import json
import faiss
from langchain_community.embeddings import OllamaEmbeddings

# 自定义 SiliconFlow 嵌入类
class SiliconFlowEmbeddings(Embeddings):
    def __init__(self, model="BAAI/bge-m3", api_key=None):
        self.model = model
        self.api_key = api_key

    def embed_documents(self, texts):
        return self._get_embeddings(texts)

    def embed_query(self, text):
        return self._get_embeddings([text])[0]

    def _get_embeddings(self, texts):
        url = "https://api.siliconflow.cn/v1/embeddings"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model,
            "input": texts
        }
        response = requests.post(url, json=payload, headers=headers, timeout=30)
        if response.status_code == 200:
            data = response.json()
            return np.array([item["embedding"] for item in data["data"]])
        else:
            raise Exception(f"API 调用失败: {response.status_code}, {response.text}")

# SiliconFlow 重排序函数
def rerank_documents(query, documents, api_key, top_n=10):
    url = "https://api.siliconflow.cn/v1/rerank"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    doc_texts = [doc.page_content for doc in documents]
    payload = {
        "model": "BAAI/bge-reranker-v2-m3",
        "query": query,
        "documents": doc_texts,
        "top_n": top_n
    }
    response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=30)
    if response.status_code == 200:
        result = response.json()
        reranked_results = result.get("results", [])
        if not reranked_results:
            raise Exception("重排序结果为空")
        reranked_docs_with_scores = [
            (documents[res["index"]], res["relevance_score"])
            for res in reranked_results
        ]
        return reranked_docs_with_scores
    else:
        raise Exception(f"重排序失败: {response.status_code}, {response.text}")

# 设置 API Keys
os.environ["SILICONFLOW_API_KEY"] = os.getenv("SILICONFLOW_API_KEY", "sk-cigytzyzghoziznvniugfihuicjcgmborusgodktydremtvd")
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "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")

# 初始化嵌入模型
embeddings = OllamaEmbeddings(model="bge-m3", base_url="http://localhost:11434")

# 从 knowledge_base 生成 HNSW 索引
def build_hnsw_index(knowledge_base_path, index_path):
    loader = DirectoryLoader(
        knowledge_base_path,
        glob="*.txt",
        loader_cls=lambda path: TextLoader(path, encoding="utf-8")
    )
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    
    # 使用 FAISS.from_documents 创建向量存储
    vector_store = FAISS.from_documents(texts, embeddings)
    
    # 获取嵌入并转换为 HNSW
    embeddings_array = np.array(embeddings.embed_documents([doc.page_content for doc in texts]))
    dimension = embeddings_array.shape[1]
    index = faiss.IndexHNSWFlat(dimension, 16)  # M=16
    index.hnsw.efConstruction = 100
    index.hnsw.efSearch = 50
    index.add(embeddings_array)
    
    # 更新 FAISS 的索引
    vector_store.index = index
    vector_store.save_local(index_path)
    print(f"HNSW 索引已生成并保存到 '{index_path}'")
    return vector_store

# 将已有 faiss_index 转为 HNSW
def convert_to_hnsw(existing_index_path, new_index_path):
    # 加载现有索引
    old_vector_store = FAISS.load_local(existing_index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
    
    # 获取文档内容
    if hasattr(old_vector_store, 'docstore') and hasattr(old_vector_store.docstore, '_dict'):
        docs = list(old_vector_store.docstore._dict.values())
        doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in docs]
    else:
        doc_ids = list(old_vector_store.index_to_docstore_id.keys())
        doc_texts = [old_vector_store.docstore._dict[old_vector_store.index_to_docstore_id[i]].page_content 
                     if hasattr(old_vector_store.docstore._dict[old_vector_store.index_to_docstore_id[i]], 'page_content') 
                     else str(old_vector_store.docstore._dict[old_vector_store.index_to_docstore_id[i]]) 
                     for i in doc_ids]
    
    # 使用全局 embeddings 对象生成嵌入
    embeddings_array = np.array(embeddings.embed_documents(doc_texts))
    
    # 创建 HNSW 索引
    dimension = embeddings_array.shape[1]
    index = faiss.IndexHNSWFlat(dimension, 16)  # M=16
    index.hnsw.efConstruction = 100
    index.hnsw.efSearch = 50
    index.add(embeddings_array)
    
    # 创建新的 FAISS 向量存储,注意不直接传递 index,而是稍后赋值
    new_vector_store = FAISS.from_texts(doc_texts, embeddings)
    new_vector_store.index = index  # 直接替换索引
    new_vector_store.save_local(new_index_path)
    print(f"已将 '{existing_index_path}' 转换为 HNSW 并保存到 '{new_index_path}'")
    return new_vector_store

# 加载或生成索引
index_path = "faiss_index_hnsw"
knowledge_base_path = "knowledge_base"

if not os.path.exists(index_path):
    if os.path.exists("faiss_index"):
        print("检测到已有 faiss_index,正在转换为 HNSW...")
        vector_store = convert_to_hnsw("faiss_index", index_path)
    elif os.path.exists(knowledge_base_path):
        print("检测到 knowledge_base,正在生成 HNSW 索引...")
        vector_store = build_hnsw_index(knowledge_base_path, index_path)
    else:
        raise FileNotFoundError("未找到 'faiss_index' 或 'knowledge_base',请提供知识库数据")
else:
    vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
    print("已加载 HNSW 索引 'faiss_index_hnsw'")

# 初始化 ChatOpenAI 使用 OpenRouter
llm = ChatOpenAI(
    model="deepseek/deepseek-r1:free",
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
    timeout=60,
    temperature=0.3,
    max_tokens=88888,
)

# 定义提示词模板
prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template="""  

    你是一个研究李敖的专家,根据用户提出的问题{question}以及从李敖相关书籍和评论中检索的内容{context}回答问题。  



    在回答时,请注意以下几点:  

    - 结合李敖的写作风格和思想,筛选出与问题最相关的检索内容,避免无关信息。  

    - 如果问题涉及李敖对某人或某事的评价,优先引用李敖的直接言论或文字,并说明出处。  

    - 回答应结构化、分段落,确保逻辑清晰,语言生动,类似李敖的犀利风格。  

    - 如果检索内容不足以直接回答问题,可根据李敖的性格和观点推测其可能的看法,但需说明这是推测。  

    - 列出引用的书籍或文章名称及章节(如有),如《李敖大全集》第X卷或具体书名。  

    - 只能基于提供的知识库内容{context}回答,不得引入外部信息。  

    - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。

    - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。

    - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。

    - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。

    - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。

    - 你的回答应该综合多个相关知识库内容来回答,不能重复引用一个知识库内容。

    - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。

    """
)

# 创建检索问答链
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 30}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt_template}
)

# 定义 Gradio 接口函数
def answer_question(question):
    try:
        # Step 1: FAISS 初始检索
        initial_docs_with_scores = vector_store.similarity_search_with_score(question, k=30)
        print(f"初始检索数量: {len(initial_docs_with_scores)}")
        
        # FAISS 返回的是距离,转换为相似度
        similarities = [1 - score for _, score in initial_docs_with_scores]
        print(f"相似度范围: {min(similarities):.4f} - {max(similarities):.4f}")
        
        # 打印前 5 个文档内容和相似度
        for i, (doc, score) in enumerate(initial_docs_with_scores[:5]):
            print(f"Top {i+1} - 相似度: {1 - score:.4f}, 内容: {doc.page_content[:100]}")

        # Step 2: 动态阈值过滤
        similarity_threshold = max(similarities) * 0.8
        filtered_docs_with_scores = [
            (doc, 1 - score)
            for doc, score in initial_docs_with_scores
            if (1 - score) >= similarity_threshold
        ]
        if len(filtered_docs_with_scores) < 5:
            filtered_docs_with_scores = initial_docs_with_scores[:10]
            print(f"过滤后数量不足,保留前 10 个文档")
        else:
            print(f"过滤后数量: {len(filtered_docs_with_scores)}")

        initial_docs = [doc for doc, _ in filtered_docs_with_scores]
        vector_similarities = [sim for _, sim in filtered_docs_with_scores]

        # Step 3: 重排序
        reranked_docs_with_scores = rerank_documents(question, initial_docs, os.environ["SILICONFLOW_API_KEY"], top_n=10)
        reranked_docs = [doc for doc, score in reranked_docs_with_scores]
        rerank_scores = [score for _, score in reranked_docs_with_scores]

        # Step 4: 融合得分并排序
        combined_scores = [
            0.2 * vector_similarities[i] + 0.8 * rerank_scores[i]
            for i in range(len(reranked_docs))
        ]
        sorted_docs_with_scores = sorted(
            zip(reranked_docs, combined_scores),
            key=lambda x: x[1],
            reverse=True
        )
        final_docs = [doc for doc, _ in sorted_docs_with_scores][:5]

        # Step 5: 生成回答
        context = "\n\n".join([doc.page_content for doc in final_docs])
        response = qa_chain.invoke({"query": question, "context": context})
        
        return response["result"]
    except Exception as e:
        return f"Error: {str(e)}"

# 创建 Gradio 界面
interface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(label="请输入您的问题"),
    outputs=gr.Textbox(label="回答"),
    title="AI李敖助手",
    description="基于李敖163本相关书籍构建的知识库,输入问题以获取李敖风格的回答。"
)

# 启动应用
if __name__ == "__main__":
    interface.launch(share=True)
>>>>>>> 921dc7e73a28368974490d7eba946303cf2129ba