Spaces:
Sleeping
Sleeping
guoerjun
committed on
Commit
·
cc74372
1
Parent(s):
1315943
fix
Browse files- .gitignore +6 -0
- Dockerfile +25 -0
- app.py +199 -0
- config.py +8 -0
- embedding.py +60 -0
- llm.py +73 -0
- makefile +54 -0
- requirements.txt +6 -0
- retriever.py +170 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
db.json
|
2 |
+
__pycache__/embedding.cpython-312.pyc
|
3 |
+
__pycache__/retriever.cpython-312.pyc
|
4 |
+
files/input/SenseNebula AIS产品培训_20220727.pdf
|
5 |
+
knowledge_bases/ceshi .faiss
|
6 |
+
knowledge_bases/ceshi .pkl
|
Dockerfile
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use the official slim Python base image
FROM python:3.12-slim

# Set the working directory
WORKDIR /app

# Copy the dependency manifest first so the install layer is cached
COPY requirements.txt .

ENV PIP_NO_CACHE_DIR=off
# Flush Python output immediately so logs show up in `docker logs`
ENV PYTHONUNBUFFERED=1
# Gradio binds to 127.0.0.1 unless told otherwise, which makes the published
# port unreachable from outside the container; listen on all interfaces.
ENV GRADIO_SERVER_NAME=0.0.0.0

# Install the Python dependencies
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r requirements.txt

# Copy the project files
COPY . .

# Port served by the Gradio app
EXPOSE 7860

# Start the application
CMD ["python", "app.py"]
|
app.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
from gradio_image_prompter import ImagePrompter
|
5 |
+
import time
|
6 |
+
from pathlib import Path
|
7 |
+
from retriever import knowledgeBase
|
8 |
+
import llm
|
9 |
+
|
10 |
+
# Resolved absolute path of this module file.
current_file_path = Path(__file__).resolve()
# "<module directory>/files/input"; passed as the FileExplorer root in create_ui().
absolute_path = (current_file_path.parent / "files" / "input").resolve()

# Registry of Gradio components keyed by name; populated inside create_ui()
# and read back through gradio(...).
components = {}

# NOTE(review): not referenced by any code visible in this file.
params = {
    "algo_type": None,
    "input_image":None
}
|
19 |
+
|
20 |
+
|
21 |
+
def gradio(*keys):
    """Look up registered UI components by name.

    Accepts either several names (``gradio("a", "b")``) or one list/tuple of
    names (``gradio(["a", "b"])``).

    Returns:
        list: The matching entries of ``components``, in the order requested.

    Raises:
        KeyError: If a name has not been registered in ``components``.
    """
    # isinstance (rather than an exact type comparison) also accepts
    # list/tuple subclasses.
    if len(keys) == 1 and isinstance(keys[0], (list, tuple)):
        keys = keys[0]

    return [components[k] for k in keys]
|
26 |
+
|
27 |
+
def create_ui():
    """Build the Gradio Blocks app and return it.

    Every created widget is stored in ``components`` under a string key so the
    event wiring can look it up by name.

    NOTE(review): the nesting of the layout containers below was reconstructed
    from a copy of the file that had lost its indentation -- confirm against
    the rendered UI.
    """
    with gr.Blocks() as demo:
        # Tab 1: knowledge-base management (list, upload, test retrieval).
        with gr.Tab("知识库"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group():
                        # Read-only, single-column table of knowledge-base names.
                        components["db_view"] = gr.Dataframe(
                            headers=["列表"],
                            datatype=["str"],
                            row_count=2,
                            col_count=(1, "fixed"),
                            interactive=False
                        )
                        # File browser rooted at the upload directory.
                        components["file_expr"] = gr.FileExplorer(
                            scale=1,
                            value=[],
                            file_count="single",
                            root=absolute_path,
                            # ignore_glob="**/__init__.py",
                            elem_id="file_expr",
                        )
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Column(scale=2):
                            components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
                        with gr.Column(scale=2):
                            components["db_submit_btn"] = gr.Button(value="提交")
                    components["file_upload"] = gr.File(elem_id='file_upload',file_count='multiple',label='文档上传', file_types=[".pdf", ".doc", '.docx', '.json', '.csv'])
                    with gr.Row():
                        with gr.Column(scale=2):
                            components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
                        with gr.Column(scale=1):
                            components["db_test_select"] = gr.Dropdown(knowledgeBase.get_bases(),multiselect=True, label="知识库选择")
                        with gr.Column(scale=1):
                            components["dbtest_submit_btn"] = gr.Button(value="检索")
                    with gr.Row():
                        with gr.Group():
                            components["db_search_result"] = gr.JSON(label="检索结果")

        # Tab 2: chat, optionally grounded on the selected knowledge bases.
        with gr.Tab("问答"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        components["chatbot"] = gr.Chatbot(
                            [(None,"你好,有什么需要帮助的?")],
                            elem_id="chatbot",
                            bubble_full_width=False,
                            height=600
                        )
                        components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
                        components["db_select"] = gr.CheckboxGroup(knowledgeBase.get_bases(),label="知识库", info="可选择1个或多个知识库")
        # Event listeners are attached while the Blocks context is still open.
        create_event_handlers()
        # Refresh the knowledge-base widgets on every page load.
        demo.load(init,None,gradio("db_view","db_select","db_test_select"))
    return demo
|
81 |
+
|
82 |
+
def init():
    """Produce fresh values for the knowledge-base widgets on page load.

    Returns:
        tuple: ``(dataframe, CheckboxGroup, Dropdown)`` in the order expected
        by the ``db_view``, ``db_select`` and ``db_test_select`` outputs.
    """
    names = knowledgeBase.get_bases()
    table = knowledgeBase.get_df_bases()
    checkbox_update = gr.CheckboxGroup(
        names, label="知识库", info="可选择1个或多个知识库"
    )
    dropdown_update = gr.Dropdown(names, multiselect=True, label="知识库选择")
    return table, checkbox_update, dropdown_update
|
86 |
+
|
87 |
+
def create_event_handlers():
    """Attach callbacks to the widgets registered in ``components``."""
    # Upload button: ingest the uploaded files, then refresh the three
    # knowledge-base widgets.
    components["db_submit_btn"].click(
        file_handler,
        gradio("file_upload", "db_name"),
        gradio("db_view", "db_select", "db_test_select"),
    )

    # Chat flow: record the user turn, stream the answer, re-enable the input.
    user_turn = components["chat_input"].submit(
        do_llm_request,
        gradio("chatbot", "chat_input"),
        gradio("chatbot", "chat_input"),
    )
    bot_turn = user_turn.then(
        do_llm_response,
        gradio("chatbot", "db_select"),
        gradio("chatbot"),
        api_name="bot_response",
    )
    bot_turn.then(
        lambda: gr.MultimodalTextbox(interactive=True),
        None,
        gradio("chat_input"),
    )

    # Like/dislike logging via print_like_dislike is currently not wired up.

    # Retrieval test button.
    components["dbtest_submit_btn"].click(
        do_search,
        gradio("db_test_select", "db_input"),
        gradio("db_search_result"),
    )

    # Selecting a cell in the knowledge-base table updates the file explorer.
    components["db_view"].select(
        db_expr,
        gradio("db_view"),
        gradio("file_expr"),
    )
|
110 |
+
|
111 |
+
def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag of a like/dislike event."""
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
|
113 |
+
|
114 |
+
def do_llm_request(history, message):
    """Append the submitted multimodal message to the chat history.

    Each uploaded file becomes its own ``((path,), None)`` entry, followed by
    the text (when it is not ``None``) as ``(text, None)``.

    Returns:
        tuple: The updated history and a cleared, non-interactive
        ``MultimodalTextbox`` so the user cannot submit again mid-answer.
    """
    history.extend(((path,), None) for path in message["files"])

    text = message["text"]
    if text is not None:
        history.append((text, None))

    locked_input = gr.MultimodalTextbox(value=None, interactive=False)
    return history, locked_input
|
120 |
+
|
121 |
+
def _build_rag_prompt(knowledge, user_input):
    """Return ``(prompt, quote)`` for the retrieved chunks.

    With no chunks the prompt is just the user input and the quote is empty.
    """
    if not knowledge:
        return user_input, ""

    background = [
        f'背景{i}:{item["content"]}' for i, item in enumerate(knowledge, start=1)
    ]
    prompt = "\n".join([*background, f"基于以上事实回答问题:{user_input}"])

    citations = []
    for item in knowledge:
        # Not every source carries a page number, so read metadata defensively.
        meta = item.get("meta") or {}
        citations.append(
            f'> 文档:{meta.get("source", "")},页码:{meta.get("page", "")}'
        )
    quote = "\n\n" + "\n".join(citations) + "\n"
    return prompt, quote


def do_llm_response(history, selected_dbs):
    """Stream the assistant reply for the last user turn in ``history``.

    When knowledge bases are selected, the retrieved chunks (however many are
    returned) are placed in the prompt as background and cited after the
    answer. Without a selection, or when retrieval returns nothing, the raw
    user input is sent.

    Args:
        history: Chat history; the last entry holds the pending user message.
        selected_dbs: Names of the knowledge bases to retrieve from.

    Yields:
        The history with the last entry's reply growing one character at a
        time, which gives a typing effect in the chat widget.
    """
    print("do_llm_response:", history, selected_dbs)
    user_input = history[-1][0]

    knowledge = []
    if selected_dbs:
        knowledge = knowledgeBase.retrieve_documents(selected_dbs, user_input)
        print("do_llm_response context:", knowledge)
    prompt, quote = _build_rag_prompt(knowledge, user_input)

    # Rebuild the last entry as a list so it is mutable even if it arrived as
    # a tuple.
    history[-1] = [user_input, ""]
    if llm_client is None:
        gr.Warning("请先设置大模型")
        response = "模型参数未设置"
    else:
        print("do_llm_response prompt:", prompt)
        response = llm_client(prompt)
        # Some backends echo the prompt at the start of the completion.
        response = response.removeprefix(prompt)
        response += quote

    if not response:
        # Still push one update so the empty reply is rendered.
        yield history
        return

    for character in response:
        history[-1][1] += character
        time.sleep(0.01)
        yield history
|
158 |
+
|
159 |
+
|
160 |
+
# Active LLM backend; do_llm_response calls it with the prompt string and
# treats the return value as the reply text.
llm_client = llm.baidu_client
|
161 |
+
|
162 |
+
|
163 |
+
def file_handler(file_objs, name):
    """Store uploaded files and index them into the knowledge base ``name``.

    Args:
        file_objs: Uploaded files from the ``gr.File`` widget (``None`` when
            nothing was uploaded). Each item is either a path string or an
            object exposing the path as ``.name``.
        name: Target knowledge-base name; it is created on first use.

    Returns:
        tuple: ``(dataframe, CheckboxGroup, Dropdown)`` refreshed with the
        current knowledge-base list. Invalid input or an unsupported file
        produces a UI warning instead of an exception.
    """
    import os
    import shutil

    print("file_obj:", file_objs)

    upload_dir = "./files/input/"

    if not name or not name.strip():
        gr.Warning("请输入知识库名称")
    elif not file_objs:
        gr.Warning("请先上传文档")
    else:
        os.makedirs(upload_dir, exist_ok=True)
        for file in file_objs:
            print(file)
            source = getattr(file, "name", file)
            file_path = upload_dir + os.path.basename(source)
            if not os.path.exists(file_path):
                shutil.move(source, file_path)
            try:
                knowledgeBase.add_documents_to_kb(name, [file_path])
            except ValueError as err:
                # The upload widget allows extensions the loader cannot parse.
                gr.Warning(f"{os.path.basename(file_path)}: {err}")

    dbs = knowledgeBase.get_bases()
    dfs = knowledgeBase.get_df_bases()
    return (
        dfs,
        gr.CheckboxGroup(dbs, label="知识库", info="可选择1个或多个知识库"),
        gr.Dropdown(dbs, multiselect=True, label="知识库选择"),
    )
|
182 |
+
|
183 |
+
def db_expr(selected_index: gr.SelectData, dataframe_origin):
    """Return the files recorded for the knowledge base in the clicked cell.

    Args:
        selected_index: Selection event; ``.index`` is indexed as
            ``[row, column]``.
        dataframe_origin: Current value of the knowledge-base table.
    """
    position = selected_index.index
    print("db_expr", position)

    row, column = position[0], position[1]
    dbname = dataframe_origin.iloc[row, column]
    print("db_expr", dbname)

    return knowledgeBase.get_db_files(dbname)
|
190 |
+
|
191 |
+
def do_search(selected_dbs, user_input):
    """Run a retrieval test against the selected knowledge bases.

    Returns whatever ``knowledgeBase.retrieve_documents`` produces for the
    given selection and query.
    """
    print("do_search:", selected_dbs, user_input)
    return knowledgeBase.retrieve_documents(selected_dbs, user_input)
|
195 |
+
|
196 |
+
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = create_ui()
    # demo.launch(server_name="10.151.124.137")
    demo.launch()
|
config.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os

# Credentials are read from the environment so real secrets never have to be
# committed; each one falls back to an empty string when unset.

# Baidu Wenxin (ERNIE) API key / secret key.
wenxin_ak = os.environ.get("WENXIN_AK", "")
wenxin_sk = os.environ.get("WENXIN_SK", "")

# Alibaba Tongyi: used by llm.qwen_agent_app as app id / api key.
tongyi_ak = os.environ.get("TONGYI_AK", "")
tongyi_sk = os.environ.get("TONGYI_SK", "")

# Hugging Face Inference API token.
hg_token = os.environ.get("HG_TOKEN", "")
|
embedding.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModel, AutoTokenizer
|
2 |
+
from typing import Any, List, Mapping, Optional,Union
|
3 |
+
from langchain.callbacks.manager import (
|
4 |
+
CallbackManagerForLLMRun
|
5 |
+
)
|
6 |
+
from langchain_core.embeddings import Embeddings
|
7 |
+
import torch
|
8 |
+
|
9 |
+
class Embedding(Embeddings):
    """LangChain ``Embeddings`` implementation backed by a local BGE model.

    Sentences are encoded with the ``BAAI/bge-small-zh-v1.5`` transformer,
    CLS-pooled and L2-normalised.
    """

    MODEL_NAME = 'BAAI/bge-small-zh-v1.5'

    def __init__(self, **kwargs):
        # ``kwargs`` is accepted for call-site compatibility and not used.
        # ``model_path`` is read by ``_identifying_params``.
        self.model_path = self.MODEL_NAME
        self.model = AutoModel.from_pretrained(self.model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model.eval()

    @property
    def _llm_type(self) -> str:
        return "BAAI/bge-small-zh-v1.5"

    @property
    def model_name(self) -> str:
        return "embedding"

    def _call(
        self,
        prompt: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Encode ``prompt`` and return a 2-D numpy array, one row per text.

        ``stop``, ``run_manager`` and ``kwargs`` are accepted but unused.
        """
        encoded_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors='pt')

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # Perform pooling. In this case, cls pooling.
            sentence_embeddings = model_output[0][:, 0]
            print(sentence_embeddings.shape)
        # normalize embeddings
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.numpy()

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_path": self.model_path}

    def embed_documents(self, texts) -> List[List[float]]:
        """Embed a list of documents; returns one list of floats per text."""
        texts = list(texts)
        print("embed_documents:", len(texts), type(texts))
        if not texts:
            # Nothing to encode; skip the tokenizer/model call entirely.
            return []
        # Convert to plain lists so the return value matches the annotation.
        return self._call(texts).tolist()

    def embed_query(self, text) -> List[float]:
        """Embed a single query string; returns one list of floats."""
        return self._call([text])[0].tolist()
|
llm.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import json
|
3 |
+
from http import HTTPStatus
|
4 |
+
from dashscope import Application
|
5 |
+
import config
|
6 |
+
|
7 |
+
def baidu_client(input):
    """Send ``input`` as a single user message to Baidu ERNIE-Lite-8K.

    Returns:
        str: The model reply, or ``""`` when the request fails or the API
        answers with an error payload. The failure is printed.
    """
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k"

    payload = json.dumps({
        "temperature": 0.95,
        "top_p": 0.7,
        "penalty_score": 1,
        "messages": [
            {
                "role": "user",
                "content": input
            }
        ],
        "system": ""
    })
    headers = {
        'Content-Type': 'application/json'
    }

    try:
        # Passing the token via ``params`` lets requests URL-encode it; the
        # timeout keeps a stalled connection from hanging the chat forever.
        response = requests.post(
            url,
            params={"access_token": get_access_token()},
            headers=headers,
            data=payload,
            timeout=60,
        )
    except requests.RequestException as err:
        print("baidu_client request failed:", err)
        return ""

    print("baidu_client", response.text)
    try:
        body = response.json()
    except ValueError:
        # Non-JSON body (e.g. a gateway error page).
        return ""

    result = body.get("result") if isinstance(body, dict) else None
    if result is None:
        # Error payloads carry no "result" key.
        print("baidu_client error:", body)
        return ""
    return result
|
31 |
+
|
32 |
+
|
33 |
+
def get_access_token():
    """Exchange the configured AK/SK pair for a Baidu access token.

    Returns:
        str: The access token. When the response has no ``access_token``
        field the literal string ``"None"`` is returned (``str(None)``);
        callers concatenate the value into a URL, so it is always a ``str``.
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials", "client_id": config.wenxin_ak, "client_secret": config.wenxin_sk}
    # requests has no default timeout; without one a stalled connection would
    # block the caller indefinitely.
    response = requests.post(url, params=params, timeout=30)
    return str(response.json().get("access_token"))
|
41 |
+
|
42 |
+
|
43 |
+
def qwen_agent_app(input):
    """Call the configured DashScope application with ``input`` as the prompt.

    Returns:
        str: ``response.output["text"]`` on HTTP 200, otherwise ``""``.
        The outcome is printed in both cases.
    """
    response = Application.call(
        app_id=config.tongyi_ak,
        prompt=input,
        api_key=config.tongyi_sk,
    )

    succeeded = response.status_code == HTTPStatus.OK
    if succeeded:
        print('request_id=%s\n output=%s\n usage=%s\n' % (response.request_id, response.output, response.usage))
        return response.output["text"]

    print('request_id=%s, code=%s, message=%s\n' % (response.request_id, response.status_code, response.message))
    return ""
|
55 |
+
|
56 |
+
|
57 |
+
def hg_client(input):
    """Query the Hugging Face Inference API (Mistral-7B-Instruct-v0.3).

    Returns:
        str: The ``generated_text`` of the first result, or ``""`` when the
        request fails or the API returns something other than a non-empty
        list of results (an error arrives as a dict).
    """
    API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
    headers = {"Authorization": f"Bearer {config.hg_token}"}

    try:
        response = requests.post(
            API_URL, headers=headers, json={"inputs": input}, timeout=60
        )
        output = response.json()
    except (requests.RequestException, ValueError) as err:
        # Network failure, timeout, or a non-JSON body.
        print("hg_client request failed:", err)
        return ""

    print(output)
    # Only index into the payload when it really is a list of result dicts.
    if isinstance(output, list) and output and isinstance(output[0], dict):
        return output[0].get("generated_text", "")
    return ""
|
makefile
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Variables
IMAGE_NAME=guojingneo/rag-app
DOCKERFILE_PATH=Dockerfile
CONTAINER_NAME=rag-app-container
PORT=7860

# Short Git commit id, used as the image tag
COMMIT_ID := $(shell git rev-parse --short HEAD)

# Default target
.PHONY: all
all: build

# Build the Docker image
.PHONY: build
build:
	docker build -t $(IMAGE_NAME):$(COMMIT_ID) -f $(DOCKERFILE_PATH) .

# Run the Docker container
.PHONY: run
run:
	docker run -d --name $(CONTAINER_NAME) -p $(PORT):$(PORT) $(IMAGE_NAME):$(COMMIT_ID)

# Stop and remove the container (errors are ignored if it does not exist)
.PHONY: stop
stop:
	docker stop $(CONTAINER_NAME) || true
	docker rm $(CONTAINER_NAME) || true

# Push the Docker image to the registry
.PHONY: push
push:
	docker push $(IMAGE_NAME):$(COMMIT_ID)

# Remove unused Docker images and containers
.PHONY: clean
clean:
	docker system prune -f

# Build the image and push it
.PHONY: package
package: build push

# Show usage information
.PHONY: help
help:
	@echo "使用方法:"
	@echo " make build 构建 Docker 镜像"
	@echo " make run 运行 Docker 容器"
	@echo " make stop 停止并删除容器"
	@echo " make push 推送 Docker 镜像到注册表"
	@echo " make clean 清理未使用的 Docker 镜像和容器"
	@echo " make package 构建并推送 Docker 镜像"
	@echo " make help 显示帮助信息"
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
faiss-cpu==1.8.0
|
2 |
+
pypdf==4.2.0
|
3 |
+
langchain==0.2.5
|
4 |
+
langchain-community==0.2.5
|
5 |
+
transformers==4.32.1
|
6 |
+
dashscope==1.20.0
|
retriever.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.vectorstores import FAISS
|
2 |
+
from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
|
3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
5 |
+
import faiss
|
6 |
+
import os
|
7 |
+
import glob
|
8 |
+
import json
|
9 |
+
from typing import Any,List,Dict
|
10 |
+
from embedding import Embedding
|
11 |
+
|
12 |
+
|
13 |
+
class KnowledgeBaseManager:
    """Manage named FAISS knowledge bases stored under ``base_path``.

    Each knowledge base is persisted as ``<name>.faiss`` / ``<name>.pkl``.
    The file names ingested into each base are tracked in ``db_files_map``
    and persisted to the JSON file named by ``DB_FILE``.
    """

    # JSON file mapping knowledge-base name -> list of ingested file names.
    DB_FILE = 'db.json'

    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
        """Load every ``*.faiss`` index found in ``base_path``.

        Args:
            base_path: Directory holding the persisted indexes.
            embedding_dim: Vector size used when creating a new index.
            batch_size: Number of chunks embedded per ``add_documents`` call.
        """
        self.base_path = base_path
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.embeddings = Embedding()
        self.knowledge_bases: Dict[str, FAISS] = {}
        # Read the file map once instead of once per knowledge base.
        self.db_files_map: Dict[str, list] = self._load_files_map()
        os.makedirs(self.base_path, exist_ok=True)

        for index_file in glob.glob(os.path.join(base_path, '*.faiss')):
            # File name without extension is the knowledge-base name.
            name = os.path.splitext(os.path.basename(index_file))[0]
            self.load_knowledge_base(name)

    def _load_files_map(self) -> Dict[str, list]:
        """Read the name -> files map from ``DB_FILE``; ``{}`` if unusable."""
        try:
            with open(self.DB_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except json.JSONDecodeError as err:
            print(f"Could not parse '{self.DB_FILE}': {err}")
            return {}
        return data if isinstance(data, dict) else {}

    def _save_files_map(self):
        """Write the name -> files map to ``DB_FILE``."""
        with open(self.DB_FILE, 'w', encoding='utf-8') as f:
            json.dump(self.db_files_map, f)

    def create_knowledge_base(self, name: str):
        """Create and persist an empty knowledge base unless it exists."""
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return

        index = faiss.IndexFlatL2(self.embedding_dim)
        self.knowledge_bases[name] = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.db_files_map[name] = []
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")

    def delete_knowledge_base(self, name: str):
        """Remove a knowledge base from memory, disk and the file map."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return

        del self.knowledge_bases[name]
        self.db_files_map.pop(name, None)
        # save_local writes both an index file and a pickle; remove both.
        for extension in ('.faiss', '.pkl'):
            try:
                os.remove(os.path.join(self.base_path, f"{name}{extension}"))
            except FileNotFoundError:
                # Already gone; deletion is best-effort per file.
                pass
        self._save_files_map()
        print(f"Knowledge base '{name}' deleted.")

    def load_knowledge_base(self, name: str):
        """Load ``<name>.faiss`` from ``base_path`` into memory if present."""
        kb_path = os.path.join(self.base_path, f"{name}.faiss")
        if not os.path.exists(kb_path):
            print(f"Knowledge base '{name}' does not exist.")
            return

        # SECURITY: allow_dangerous_deserialization unpickles the docstore;
        # only load index files this application wrote itself.
        self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
        # An index without a file-map entry still gets an (empty) file list.
        self.db_files_map.setdefault(name, [])
        print(f"Knowledge base '{name}' loaded.")

    def save_knowledge_base(self, name: str):
        """Persist the index for ``name`` and the file map."""
        if name in self.knowledge_bases:
            self.knowledge_bases[name].save_local(self.base_path, name)
            self._save_files_map()
            print(f"Knowledge base '{name}' saved.")
        else:
            print(f"Knowledge base '{name}' does not exist.")

    def add_documents_to_kb(self, name: str, file_paths: List[str]):
        """Load, split and index ``file_paths`` into knowledge base ``name``.

        The knowledge base is created when missing. Loaded documents look
        like ``Document(page_content='...', metadata={'source':
        './files/input/PS004.pdf', 'page': 0})``.

        Returns:
            list: Ids of the chunks added to the vector store.

        Raises:
            ValueError: If a file has an unsupported extension. Nothing is
                created or recorded in that case.
        """
        # Load first so an unsupported file fails before any state changes.
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")

        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            self.create_knowledge_base(name)
        kb = self.knowledge_bases[name]

        pages = self.split_documents(documents)
        print(f"Split documents into {len(pages)} pages.")

        doc_ids = []
        for start in range(0, len(pages), self.batch_size):
            doc_ids.extend(kb.add_documents(pages[start:start + self.batch_size]))

        # Record each file name once, even when it is ingested repeatedly.
        recorded = self.db_files_map.setdefault(name, [])
        for file_path in file_paths:
            file_name = os.path.basename(file_path)
            if file_name not in recorded:
                recorded.append(file_name)

        self.save_knowledge_base(name)
        return doc_ids

    def load_documents(self, file_paths: List[str]):
        """Load every file with the loader matching its extension."""
        documents = []
        for file_path in file_paths:
            documents.extend(self.get_loader(file_path).load())
        return documents

    def get_loader(self, file_path: str):
        """Return the document loader for ``file_path`` by extension.

        Raises:
            ValueError: If the extension is not ``.txt``, ``.json`` or ``.pdf``.
        """
        lowered = file_path.lower()
        if lowered.endswith('.txt'):
            return TextLoader(file_path)
        if lowered.endswith('.json'):
            # JSONLoader requires a jq schema; "." selects the whole document
            # and text_content=False lets non-string JSON be serialised.
            return JSONLoader(file_path, jq_schema='.', text_content=False)
        if lowered.endswith('.pdf'):
            return PyPDFLoader(file_path)
        raise ValueError("Unsupported file format")

    def split_documents(self, documents):
        """Split documents into chunks of at most 512 characters, no overlap.

        The separator list includes CJK punctuation so Chinese text is split
        at sentence and clause boundaries.
        """
        text_splitter = RecursiveCharacterTextSplitter(separators=[
            "\n\n",
            "\n",
            " ",
            ".",
            ",",
            "\u200b",  # Zero-width space
            "\uff0c",  # Fullwidth comma
            "\u3001",  # Ideographic comma
            "\uff0e",  # Fullwidth full stop
            "\u3002",  # Ideographic full stop
            "",
        ],
            chunk_size=512, chunk_overlap=0)
        return text_splitter.split_documents(documents)

    def retrieve_documents(self, names: List[str], query: str):
        """Run an MMR search for ``query`` in each named knowledge base.

        Unknown names are skipped with a printed message; ``None`` or an
        empty ``names`` yields an empty result.

        Returns:
            list: Dicts with keys ``name``, ``content`` and ``meta``, up to
            three per knowledge base.
        """
        results = []
        for name in names or []:
            kb = self.knowledge_bases.get(name)
            if kb is None:
                print(f"Knowledge base '{name}' does not exist.")
                continue

            # NOTE(review): score_threshold looks like it only applies to the
            # "similarity_score_threshold" search type and is presumably
            # ignored for "mmr" -- confirm before relying on it.
            retriever = kb.as_retriever(
                search_type="mmr",
                search_kwargs={"score_threshold": 0.5, "k": 3}
            )
            # NOTE(review): newer LangChain releases prefer retriever.invoke().
            docs = retriever.get_relevant_documents(query)
            results.extend(
                {"name": name, "content": doc.page_content, "meta": doc.metadata}
                for doc in docs
            )

        return results

    def get_db_files(self, name):
        """Return the file names recorded for ``name`` (``[]`` if none)."""
        return self.db_files_map.get(name, [])

    def get_bases(self):
        """Return the names of all loaded knowledge bases as a list."""
        return list(self.knowledge_bases)

    def get_df_bases(self):
        """Return the knowledge-base names as a one-column DataFrame."""
        import pandas as pd
        return pd.DataFrame(list(self.knowledge_bases), columns=['列表'])
|
169 |
+
|
170 |
+
# Module-level instance, created at import time; constructing it loads the
# embedding model and every persisted knowledge base.
knowledgeBase = KnowledgeBaseManager()
|