guoerjun commited on
Commit
cc74372
·
1 Parent(s): 1315943
Files changed (9) hide show
  1. .gitignore +6 -0
  2. Dockerfile +25 -0
  3. app.py +199 -0
  4. config.py +8 -0
  5. embedding.py +60 -0
  6. llm.py +73 -0
  7. makefile +54 -0
  8. requirements.txt +6 -0
  9. retriever.py +170 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ db.json
2
+ __pycache__/embedding.cpython-312.pyc
3
+ __pycache__/retriever.cpython-312.pyc
4
+ files/input/SenseNebula AIS产品培训_20220727.pdf
5
+ knowledge_bases/ceshi .faiss
6
+ knowledge_bases/ceshi .pkl
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 使用官方的 Python 基础镜像
2
+ FROM python:3.12-slim
3
+
4
+ # 设置工作目录
5
+ WORKDIR /app
6
+
7
+ # 复制依赖文件
8
+ COPY requirements.txt .
9
+
10
+ ENV PIP_NO_CACHE_DIR=off
11
+ # 设置环境变量
12
+ ENV PYTHONUNBUFFERED=1
13
+
14
+ # 安装 Python 依赖包
15
+ RUN pip install --upgrade pip
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # 复制项目文件
19
+ COPY . .
20
+
21
+ # 暴露端口(如果需要)
22
+ EXPOSE 7860
23
+
24
+ # 运行项目
25
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import numpy as np
4
+ from gradio_image_prompter import ImagePrompter
5
+ import time
6
+ from pathlib import Path
7
+ from retriever import knowledgeBase
8
+ import llm
9
+
10
+ current_file_path = Path(__file__).resolve()
11
+ absolute_path = (current_file_path.parent / "files" / "input").resolve()
12
+
13
+ components = {}
14
+
15
+ params = {
16
+ "algo_type": None,
17
+ "input_image":None
18
+ }
19
+
20
+
21
+ def gradio(*keys):
22
+ if len(keys) == 1 and type(keys[0]) in [list, tuple]:
23
+ keys = keys[0]
24
+
25
+ return [components[k] for k in keys]
26
+
27
+ def create_ui():
28
+ with gr.Blocks() as demo:
29
+ with gr.Tab("知识库"):
30
+ with gr.Row():
31
+ with gr.Column(scale=1):
32
+ with gr.Group():
33
+ components["db_view"] = gr.Dataframe(
34
+ headers=["列表"],
35
+ datatype=["str"],
36
+ row_count=2,
37
+ col_count=(1, "fixed"),
38
+ interactive=False
39
+ )
40
+ components["file_expr"] = gr.FileExplorer(
41
+ scale=1,
42
+ value=[],
43
+ file_count="single",
44
+ root=absolute_path,
45
+ # ignore_glob="**/__init__.py",
46
+ elem_id="file_expr",
47
+ )
48
+ with gr.Column(scale=2):
49
+ with gr.Row():
50
+ with gr.Column(scale=2):
51
+ components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
52
+ with gr.Column(scale=2):
53
+ components["db_submit_btn"] = gr.Button(value="提交")
54
+ components["file_upload"] = gr.File(elem_id='file_upload',file_count='multiple',label='文档上传', file_types=[".pdf", ".doc", '.docx', '.json', '.csv'])
55
+ with gr.Row():
56
+ with gr.Column(scale=2):
57
+ components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
58
+ with gr.Column(scale=1):
59
+ components["db_test_select"] = gr.Dropdown(knowledgeBase.get_bases(),multiselect=True, label="知识库选择")
60
+ with gr.Column(scale=1):
61
+ components["dbtest_submit_btn"] = gr.Button(value="检索")
62
+ with gr.Row():
63
+ with gr.Group():
64
+ components["db_search_result"] = gr.JSON(label="检索结果")
65
+
66
+ with gr.Tab("问答"):
67
+ with gr.Row():
68
+ with gr.Column(scale=2):
69
+ with gr.Group():
70
+ components["chatbot"] = gr.Chatbot(
71
+ [(None,"你好,有什么需要帮助的?")],
72
+ elem_id="chatbot",
73
+ bubble_full_width=False,
74
+ height=600
75
+ )
76
+ components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
77
+ components["db_select"] = gr.CheckboxGroup(knowledgeBase.get_bases(),label="知识库", info="可选择1个或多个知识库")
78
+ create_event_handlers()
79
+ demo.load(init,None,gradio("db_view","db_select","db_test_select"))
80
+ return demo
81
+
82
+ def init():
83
+ db_list = knowledgeBase.get_bases()
84
+ db_df_list = knowledgeBase.get_df_bases()
85
+ return db_df_list,gr.CheckboxGroup(db_list,label="知识库", info="可选择1个或多个知识库"),gr.Dropdown(db_list,multiselect=True, label="知识库选择")
86
+
87
+ def create_event_handlers():
88
+
89
+ components["db_submit_btn"].click(
90
+ file_handler,gradio('file_upload','db_name'),gradio("db_view",'db_select',"db_test_select")
91
+ )
92
+
93
+ components["chat_input"].submit(
94
+ do_llm_request, gradio("chatbot", "chat_input"), gradio("chatbot", "chat_input")
95
+ ).then(
96
+ do_llm_response, gradio("chatbot","db_select"), gradio("chatbot"), api_name="bot_response"
97
+ ).then(
98
+ lambda: gr.MultimodalTextbox(interactive=True), None, gradio('chat_input')
99
+ )
100
+
101
+ # components["chatbot"].like(print_like_dislike, None, None)
102
+
103
+ components['dbtest_submit_btn'].click(
104
+ do_search, gradio('db_test_select','db_input'), gradio('db_search_result')
105
+ )
106
+
107
+ components['db_view'].select(
108
+ db_expr, gradio('db_view'), gradio('file_expr')
109
+ )
110
+
111
+ def print_like_dislike(x: gr.LikeData):
112
+ print(x.index, x.value, x.liked)
113
+
114
+ def do_llm_request(history, message):
115
+ for x in message["files"]:
116
+ history.append(((x,), None))
117
+ if message["text"] is not None:
118
+ history.append((message["text"], None))
119
+ return history, gr.MultimodalTextbox(value=None, interactive=False)
120
+
121
+ def do_llm_response(history,selected_dbs):
122
+ print("do_llm_response:",history,selected_dbs)
123
+ user_input = history[-1][0]
124
+ prompt = ""
125
+ quote = ""
126
+ if len(selected_dbs) > 0:
127
+ knowledge = knowledgeBase.retrieve_documents(selected_dbs,user_input)
128
+ print("do_llm_response context:",knowledge)
129
+ prompt = f'''
130
+ 背景1:{knowledge[0]["content"]}
131
+ 背景2:{knowledge[1]["content"]}
132
+ 背景3:{knowledge[2]["content"]}
133
+ 基于以上事实回答问题:{user_input}
134
+ '''
135
+
136
+ quote = f'''
137
+ > 文档:{knowledge[0]["meta"]["source"]},页码:{knowledge[0]["meta"]["page"]}
138
+ > 文档:{knowledge[1]["meta"]["source"]},页码:{knowledge[1]["meta"]["page"]}
139
+ > 文档:{knowledge[2]["meta"]["source"]},页码:{knowledge[2]["meta"]["page"]}
140
+ '''
141
+ else:
142
+ prompt = user_input
143
+
144
+ history[-1][1] = ""
145
+ if llm_client is None:
146
+ gr.Warning("请先设置大模型")
147
+ response = "模型参数未设置"
148
+ else:
149
+ print("do_llm_response prompt:",prompt)
150
+ response = llm_client(prompt)
151
+ response = response.removeprefix(prompt)
152
+ response += quote
153
+
154
+ for character in response:
155
+ history[-1][1] += character
156
+ time.sleep(0.01)
157
+ yield history
158
+
159
+
160
+ llm_client = llm.baidu_client
161
+
162
+
163
+ def file_handler(file_objs,name):
164
+ import shutil
165
+ import os
166
+
167
+ print("file_obj:",file_objs)
168
+
169
+ os.makedirs(os.path.dirname("./files/input/"), exist_ok=True)
170
+
171
+ for idx, file in enumerate(file_objs):
172
+ print(file)
173
+ file_path = "./files/input/" + os.path.basename(file.name)
174
+ if not os.path.exists(file_path):
175
+ shutil.move(file.name,"./files/input/")
176
+
177
+ knowledgeBase.add_documents_to_kb(name,[file_path])
178
+
179
+ dbs = knowledgeBase.get_bases()
180
+ dfs = knowledgeBase.get_df_bases()
181
+ return dfs,gr.CheckboxGroup(dbs,label="知识库", info="可选择1个或多个知识库"),gr.Dropdown(dbs,multiselect=True, label="知识库选择")
182
+
183
+ def db_expr(selected_index: gr.SelectData, dataframe_origin):
184
+ print("db_expr",selected_index.index)
185
+
186
+ dbname = dataframe_origin.iloc[selected_index.index[0],selected_index.index[1]]
187
+ print("db_expr",dbname)
188
+
189
+ return knowledgeBase.get_db_files(dbname)
190
+
191
+ def do_search(selected_dbs,user_input):
192
+ print("do_search:",selected_dbs,user_input)
193
+ context = knowledgeBase.retrieve_documents(selected_dbs,user_input)
194
+ return context
195
+
196
+ if __name__ == "__main__":
197
+ demo = create_ui()
198
+ # demo.launch(server_name="10.151.124.137")
199
+ demo.launch()
config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+ wenxin_ak = ""
3
+ wenxin_sk = ""
4
+
5
+ tongyi_ak = ""
6
+ tongyi_sk = ""
7
+
8
+ hg_token = ""
embedding.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ from typing import Any, List, Mapping, Optional,Union
3
+ from langchain.callbacks.manager import (
4
+ CallbackManagerForLLMRun
5
+ )
6
+ from langchain_core.embeddings import Embeddings
7
+ import torch
8
+
9
+ class Embedding(Embeddings):
10
+
11
+ def __init__(self,**kwargs):
12
+ self.model=AutoModel.from_pretrained('BAAI/bge-small-zh-v1.5')
13
+ self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-zh-v1.5')
14
+ self.model.eval()
15
+
16
+ @property
17
+ def _llm_type(self) -> str:
18
+ return "BAAI/bge-small-zh-v1.5"
19
+
20
+ @property
21
+ def model_name(self) -> str:
22
+ return "embedding"
23
+
24
+ def _call(
25
+ self,
26
+ prompt: List[str],
27
+ stop: Optional[List[str]] = None,
28
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
29
+ **kwargs: Any,
30
+ ) -> str:
31
+ encoded_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors='pt')
32
+
33
+ with torch.no_grad():
34
+ model_output = self.model(**encoded_input)
35
+ # Perform pooling. In this case, cls pooling.
36
+ sentence_embeddings = model_output[0][:, 0]
37
+ print(sentence_embeddings.shape)
38
+ # normalize embeddings
39
+ sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
40
+ return sentence_embeddings.numpy()
41
+
42
+ @property
43
+ def _identifying_params(self) -> Mapping[str, Any]:
44
+ """Get the identifying parameters."""
45
+ return {"model_path": self.model_path}
46
+
47
+ def embed_documents(self, texts) -> List[List[float]]:
48
+ # Embed a list of documents
49
+ embeddings = []
50
+ print("embed_documents:",len(texts),type(texts))
51
+ embedding = self._call(texts)
52
+ for row in embedding:
53
+ embeddings.append(row)
54
+ # print("embed_documents: shape",embeddings.shape)
55
+ return embeddings
56
+
57
+ def embed_query(self, text) -> List[float]:
58
+ # Embed a single query
59
+ embedding = self._call([text])
60
+ return embedding[0]
llm.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ from http import HTTPStatus
4
+ from dashscope import Application
5
+ import config
6
+
7
+ def baidu_client(input):
8
+
9
+ url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k?access_token=" + get_access_token()
10
+
11
+ payload = json.dumps({
12
+ "temperature": 0.95,
13
+ "top_p": 0.7,
14
+ "penalty_score": 1,
15
+ "messages": [
16
+ {
17
+ "role": "user",
18
+ "content": input
19
+ }
20
+ ],
21
+ "system": ""
22
+ })
23
+ headers = {
24
+ 'Content-Type': 'application/json'
25
+ }
26
+
27
+ response = requests.request("POST", url, headers=headers, data=payload)
28
+
29
+ print("baidu_client",response.text)
30
+ return response.json()["result"]
31
+
32
+
33
+ def get_access_token():
34
+ """
35
+ 使用 AK,SK 生成鉴权签名(Access Token)
36
+ :return: access_token,或是None(如果错误)
37
+ """
38
+ url = "https://aip.baidubce.com/oauth/2.0/token"
39
+ params = {"grant_type": "client_credentials", "client_id": config.wenxin_ak, "client_secret": config.wenxin_sk}
40
+ return str(requests.post(url, params=params).json().get("access_token"))
41
+
42
+
43
+ def qwen_agent_app(input):
44
+ response = Application.call(app_id=config.tongyi_ak,
45
+ prompt=input,
46
+ api_key=config.tongyi_sk,
47
+ )
48
+
49
+ if response.status_code != HTTPStatus.OK:
50
+ print('request_id=%s, code=%s, message=%s\n' % (response.request_id, response.status_code, response.message))
51
+ return ""
52
+ else:
53
+ print('request_id=%s\n output=%s\n usage=%s\n' % (response.request_id, response.output, response.usage))
54
+ return response.output["text"]
55
+
56
+
57
+ def hg_client(input):
58
+
59
+ import requests
60
+ API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
61
+ headers = {"Authorization": f"Bearer {config.hg_token}"}
62
+
63
+ def query(payload):
64
+ response = requests.post(API_URL, headers=headers, json=payload)
65
+ return response.json()
66
+
67
+ output = query({
68
+ "inputs": input,
69
+ })
70
+ print(output)
71
+ if len(output) >0:
72
+ return output[0]['generated_text']
73
+ return ""
makefile ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 定义变量
2
+ IMAGE_NAME=guojingneo/rag-app
3
+ DOCKERFILE_PATH=Dockerfile
4
+ CONTAINER_NAME=rag-app-container
5
+ PORT=7860
6
+
7
+ # 获取 Git 提交 ID
8
+ COMMIT_ID := $(shell git rev-parse --short HEAD)
9
+
10
+ # 默认目标
11
+ .PHONY: all
12
+ all: build
13
+
14
+ # 构建 Docker 镜像
15
+ .PHONY: build
16
+ build:
17
+ docker build -t $(IMAGE_NAME):$(COMMIT_ID) -f $(DOCKERFILE_PATH) .
18
+
19
+ # 运行 Docker 容器
20
+ .PHONY: run
21
+ run:
22
+ docker run -d --name $(CONTAINER_NAME) -p $(PORT):$(PORT) $(IMAGE_NAME):$(COMMIT_ID)
23
+
24
+ # 停止并删除容器
25
+ .PHONY: stop
26
+ stop:
27
+ docker stop $(CONTAINER_NAME) || true
28
+ docker rm $(CONTAINER_NAME) || true
29
+
30
+ # 推送 Docker 镜像到注册表
31
+ .PHONY: push
32
+ push:
33
+ docker push $(IMAGE_NAME):$(COMMIT_ID)
34
+
35
+ # 清理未使用的 Docker 镜像和容器
36
+ .PHONY: clean
37
+ clean:
38
+ docker system prune -f
39
+
40
+ # 打包镜像并推送
41
+ .PHONY: package
42
+ package: build push
43
+
44
+ # 显示帮助信息
45
+ .PHONY: help
46
+ help:
47
+ @echo "使用方法:"
48
+ @echo " make build 构建 Docker 镜像"
49
+ @echo " make run 运行 Docker 容器"
50
+ @echo " make stop 停止并删除容器"
51
+ @echo " make push 推送 Docker 镜像到注册表"
52
+ @echo " make clean 清理未使用的 Docker 镜像和容器"
53
+ @echo " make package 构建并推送 Docker 镜像"
54
+ @echo " make help 显示帮助信息"
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ faiss-cpu==1.8.0
2
+ pypdf==4.2.0
3
+ langchain==0.2.5
4
+ langchain-community==0.2.5
5
+ transformers==4.32.1
6
+ dashscope==1.20.0
retriever.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_community.docstore.in_memory import InMemoryDocstore
5
+ import faiss
6
+ import os
7
+ import glob
8
+ import json
9
+ from typing import Any,List,Dict
10
+ from embedding import Embedding
11
+
12
+
13
+ class KnowledgeBaseManager:
14
+ def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
15
+ self.base_path = base_path
16
+ self.embedding_dim = embedding_dim
17
+ self.batch_size = batch_size
18
+ self.embeddings = Embedding()
19
+ self.knowledge_bases: Dict[str, FAISS] = {}
20
+ self.db_files_map: Dict[str, list] = {}
21
+ os.makedirs(self.base_path, exist_ok=True)
22
+
23
+ faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
24
+ # 获取不带后缀的名称
25
+ file_names_without_extension = [os.path.splitext(os.path.basename(file))[0] for file in faiss_files]
26
+ for name in file_names_without_extension:
27
+ self.load_knowledge_base(name)
28
+
29
+
30
+ def create_knowledge_base(self, name: str):
31
+ index = faiss.IndexFlatL2(self.embedding_dim)
32
+ kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
33
+ if name in self.knowledge_bases:
34
+ print(f"Knowledge base '{name}' already exists.")
35
+ return
36
+
37
+ self.knowledge_bases[name] = kb
38
+ self.db_files_map[name] = []
39
+ self.save_knowledge_base(name)
40
+ print(f"Knowledge base '{name}' created.")
41
+
42
+ def delete_knowledge_base(self, name: str):
43
+ if name in self.knowledge_bases:
44
+ del self.knowledge_bases[name]
45
+ del self.db_files_map[name]
46
+ os.remove(os.path.join(self.base_path, f"{name}.faiss"))
47
+ print(f"Knowledge base '{name}' deleted.")
48
+ else:
49
+ print(f"Knowledge base '{name}' does not exist.")
50
+
51
+ def load_knowledge_base(self, name: str):
52
+ kb_path = os.path.join(self.base_path, f"{name}.faiss")
53
+ if os.path.exists(kb_path):
54
+ self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
55
+ # 加载文件中的数据
56
+ try:
57
+ with open('db.json', 'r+') as f:
58
+ self.db_files_map = json.load(f)
59
+ except FileNotFoundError:
60
+ # 如果文件不存在,则创建一个空的文件并初始化 self.db_files_map
61
+ with open('db.json', 'w+') as f:
62
+ self.db_files_map = {}
63
+ json.dump(self.db_files_map, f)
64
+ print(f"Knowledge base '{name}' loaded.")
65
+ else:
66
+ print(f"Knowledge base '{name}' does not exist.")
67
+
68
+ def save_knowledge_base(self, name: str):
69
+ if name in self.knowledge_bases:
70
+ self.knowledge_bases[name].save_local(self.base_path, name)
71
+ with open('db.json', 'w') as f:
72
+ json.dump(self.db_files_map, f)
73
+ print(f"Knowledge base '{name}' saved.")
74
+ else:
75
+ print(f"Knowledge base '{name}' does not exist.")
76
+
77
+ # Document(page_content = '渠道版', metadata = {
78
+ # 'source': './files/input/PS004.pdf',
79
+ # 'page': 0
80
+ # }), Document(page_content = '2/20.', metadata = {
81
+ # 'source': './files/input/PS004.pdf',
82
+ # 'page': 1
83
+ # })
84
+ def add_documents_to_kb(self, name: str, file_paths: List[str]):
85
+ if name not in self.knowledge_bases:
86
+ print(f"Knowledge base '{name}' does not exist.")
87
+ self.create_knowledge_base(name)
88
+
89
+ kb = self.knowledge_bases[name]
90
+ self.db_files_map[name].extend([os.path.basename(file_path) for file_path in file_paths])
91
+ documents = self.load_documents(file_paths)
92
+ print(f"Loaded {len(documents)} documents.")
93
+ print(documents)
94
+ pages = self.split_documents(documents)
95
+ print(f"Split documents into {len(pages)} pages.")
96
+ # print(pages)
97
+
98
+ doc_ids = []
99
+ for i in range(0, len(pages), self.batch_size):
100
+ batch = pages[i:i+self.batch_size]
101
+ doc_ids.extend(kb.add_documents(batch))
102
+
103
+ self.save_knowledge_base(name)
104
+ return doc_ids
105
+
106
+ def load_documents(self, file_paths: List[str]):
107
+ documents = []
108
+ for file_path in file_paths:
109
+ loader = self.get_loader(file_path)
110
+ documents.extend(loader.load())
111
+ return documents
112
+
113
+ def get_loader(self, file_path: str):
114
+ if file_path.endswith('.txt'):
115
+ return TextLoader(file_path)
116
+ elif file_path.endswith('.json'):
117
+ return JSONLoader(file_path)
118
+ elif file_path.endswith('.pdf'):
119
+ return PyPDFLoader(file_path)
120
+ else:
121
+ raise ValueError("Unsupported file format")
122
+
123
+ def split_documents(self, documents):
124
+ text_splitter = RecursiveCharacterTextSplitter(separators=[
125
+ "\n\n",
126
+ "\n",
127
+ " ",
128
+ ".",
129
+ ",",
130
+ "\u200b", # Zero-width space
131
+ "\uff0c", # Fullwidth comma
132
+ "\u3001", # Ideographic comma
133
+ "\uff0e", # Fullwidth full stop
134
+ "\u3002", # Ideographic full stop
135
+ "",
136
+ ],
137
+ chunk_size=512, chunk_overlap=0)
138
+ return text_splitter.split_documents(documents)
139
+
140
+ def retrieve_documents(self, names: List[str], query: str):
141
+ results = []
142
+ for name in names:
143
+ if name not in self.knowledge_bases:
144
+ print(f"Knowledge base '{name}' does not exist.")
145
+ continue
146
+
147
+ retriever = self.knowledge_bases[name].as_retriever(
148
+ search_type="mmr",
149
+ search_kwargs={"score_threshold": 0.5, "k": 3}
150
+ )
151
+ docs = retriever.get_relevant_documents(query)
152
+ results.extend([{"name": name, "content": doc.page_content,"meta": doc.metadata} for doc in docs])
153
+
154
+
155
+ return results
156
+
157
+ def get_db_files(self,name):
158
+ data = self.db_files_map[name]
159
+ return data
160
+
161
+ def get_bases(self):
162
+ data = self.knowledge_bases.keys()
163
+ return list(data)
164
+
165
+ def get_df_bases(self):
166
+ import pandas as pd
167
+ data = self.knowledge_bases.keys()
168
+ return pd.DataFrame(list(data), columns=['列表'])
169
+
170
+ knowledgeBase = KnowledgeBaseManager()