Spaces:
Sleeping
Sleeping
neoguojing
committed on
Commit
·
864919f
1
Parent(s):
1405c00
fix
Browse files- .gitignore +10 -0
- app.py +24 -0
- retriever.py +20 -0
.gitignore
CHANGED
@@ -5,3 +5,13 @@ __pycache__/
|
|
5 |
files/input/ir2023_ashare.pdf
|
6 |
knowledge_bases/中国移动.faiss
|
7 |
knowledge_bases/中国移动.pkl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
files/input/ir2023_ashare.pdf
|
6 |
knowledge_bases/中国移动.faiss
|
7 |
knowledge_bases/中国移动.pkl
|
8 |
+
files/input/ClientList.csv
|
9 |
+
db.json
|
10 |
+
files/input/history_Example_20240109T225059.json
|
11 |
+
files/input/ir2023_ashare.txt
|
12 |
+
knowledge_bases/json.faiss
|
13 |
+
knowledge_bases/json.pkl
|
14 |
+
knowledge_bases/list.faiss
|
15 |
+
knowledge_bases/list.pkl
|
16 |
+
knowledge_bases/text.faiss
|
17 |
+
knowledge_bases/text.pkl
|
app.py
CHANGED
@@ -8,6 +8,10 @@ from sam_everything import SamAnything
|
|
8 |
from ocr import do_ocr
|
9 |
from retriever import knowledgeBase
|
10 |
import time
|
|
|
|
|
|
|
|
|
11 |
|
12 |
components = {}
|
13 |
|
@@ -137,6 +141,14 @@ def create_ui():
|
|
137 |
col_count=(1, "fixed"),
|
138 |
interactive=False
|
139 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
with gr.Column(scale=2):
|
141 |
with gr.Row():
|
142 |
with gr.Column(scale=2):
|
@@ -246,6 +258,10 @@ def create_event_handlers():
|
|
246 |
llm, gradio('ak','sk','llm_client'), None
|
247 |
)
|
248 |
|
|
|
|
|
|
|
|
|
249 |
def do_refernce(algo_type,input_image):
|
250 |
# def do_refernce():
|
251 |
print("input image",input_image)
|
@@ -422,6 +438,14 @@ def file_handler(file_objs,name):
|
|
422 |
dfs = knowledgeBase.get_df_bases()
|
423 |
return dfs,gr.CheckboxGroup(dbs,label="知识库", info="可选择1个或多个知识库"),gr.Dropdown(dbs,multiselect=True, label="知识库选择")
|
424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
def do_search(selected_dbs,user_input):
|
426 |
print("do_search:",selected_dbs,user_input)
|
427 |
context = knowledgeBase.retrieve_documents(selected_dbs,user_input)
|
|
|
8 |
from ocr import do_ocr
|
9 |
from retriever import knowledgeBase
|
10 |
import time
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
current_file_path = Path(__file__).resolve()
|
14 |
+
absolute_path = (current_file_path.parent / "files" / "input").resolve()
|
15 |
|
16 |
components = {}
|
17 |
|
|
|
141 |
col_count=(1, "fixed"),
|
142 |
interactive=False
|
143 |
)
|
144 |
+
components["file_expr"] = gr.FileExplorer(
|
145 |
+
scale=1,
|
146 |
+
value=[],
|
147 |
+
file_count="single",
|
148 |
+
root=absolute_path,
|
149 |
+
# ignore_glob="**/__init__.py",
|
150 |
+
elem_id="file_expr",
|
151 |
+
)
|
152 |
with gr.Column(scale=2):
|
153 |
with gr.Row():
|
154 |
with gr.Column(scale=2):
|
|
|
258 |
llm, gradio('ak','sk','llm_client'), None
|
259 |
)
|
260 |
|
261 |
+
components['db_view'].select(
|
262 |
+
db_expr, gradio('db_view'), gradio('file_expr')
|
263 |
+
)
|
264 |
+
|
265 |
def do_refernce(algo_type,input_image):
|
266 |
# def do_refernce():
|
267 |
print("input image",input_image)
|
|
|
438 |
dfs = knowledgeBase.get_df_bases()
|
439 |
return dfs,gr.CheckboxGroup(dbs,label="知识库", info="可选择1个或多个知识库"),gr.Dropdown(dbs,multiselect=True, label="知识库选择")
|
440 |
|
441 |
+
def db_expr(selected_index: gr.SelectData, dataframe_origin):
    """Look up the files recorded for the knowledge base cell that was selected.

    Registered as the ``select`` handler of the ``db_view`` component, with
    ``file_expr`` as its output component.

    Args:
        selected_index: Gradio selection event. Its ``index`` is read as a
            two-element ``[row, column]`` position.
        dataframe_origin: Current value of ``db_view``; must support
            ``.iloc[row, column]`` (presumably a pandas DataFrame -- confirm
            against the component's configured type).

    Returns:
        Whatever ``knowledgeBase.get_db_files`` returns for the selected name.
    """
    position = selected_index.index
    print("db_expr",position)

    # The clicked cell's content is used as the knowledge base name.
    row, col = position[0], position[1]
    dbname = dataframe_origin.iloc[row, col]
    print("db_expr",dbname)

    return knowledgeBase.get_db_files(dbname)
|
448 |
+
|
449 |
def do_search(selected_dbs,user_input):
|
450 |
print("do_search:",selected_dbs,user_input)
|
451 |
context = knowledgeBase.retrieve_documents(selected_dbs,user_input)
|
retriever.py
CHANGED
@@ -5,6 +5,7 @@ from langchain_community.docstore.in_memory import InMemoryDocstore
|
|
5 |
import faiss
|
6 |
import os
|
7 |
import glob
|
|
|
8 |
from typing import Any,List,Dict
|
9 |
from embedding import Embedding
|
10 |
|
@@ -16,6 +17,7 @@ class KnowledgeBaseManager:
|
|
16 |
self.batch_size = batch_size
|
17 |
self.embeddings = Embedding()
|
18 |
self.knowledge_bases: Dict[str, FAISS] = {}
|
|
|
19 |
os.makedirs(self.base_path, exist_ok=True)
|
20 |
|
21 |
faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
|
@@ -33,12 +35,14 @@ class KnowledgeBaseManager:
|
|
33 |
return
|
34 |
|
35 |
self.knowledge_bases[name] = kb
|
|
|
36 |
self.save_knowledge_base(name)
|
37 |
print(f"Knowledge base '{name}' created.")
|
38 |
|
39 |
def delete_knowledge_base(self, name: str):
|
40 |
if name in self.knowledge_bases:
|
41 |
del self.knowledge_bases[name]
|
|
|
42 |
os.remove(os.path.join(self.base_path, f"{name}.faiss"))
|
43 |
print(f"Knowledge base '{name}' deleted.")
|
44 |
else:
|
@@ -48,6 +52,15 @@ class KnowledgeBaseManager:
|
|
48 |
kb_path = os.path.join(self.base_path, f"{name}.faiss")
|
49 |
if os.path.exists(kb_path):
|
50 |
self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
print(f"Knowledge base '{name}' loaded.")
|
52 |
else:
|
53 |
print(f"Knowledge base '{name}' does not exist.")
|
@@ -55,6 +68,8 @@ class KnowledgeBaseManager:
|
|
55 |
def save_knowledge_base(self, name: str):
|
56 |
if name in self.knowledge_bases:
|
57 |
self.knowledge_bases[name].save_local(self.base_path, name)
|
|
|
|
|
58 |
print(f"Knowledge base '{name}' saved.")
|
59 |
else:
|
60 |
print(f"Knowledge base '{name}' does not exist.")
|
@@ -72,6 +87,7 @@ class KnowledgeBaseManager:
|
|
72 |
self.create_knowledge_base(name)
|
73 |
|
74 |
kb = self.knowledge_bases[name]
|
|
|
75 |
documents = self.load_documents(file_paths)
|
76 |
print(f"Loaded {len(documents)} documents.")
|
77 |
print(documents)
|
@@ -138,6 +154,10 @@ class KnowledgeBaseManager:
|
|
138 |
|
139 |
return results
|
140 |
|
|
|
|
|
|
|
|
|
141 |
def get_bases(self):
    """Return the names of all knowledge bases currently held in memory.

    Returns:
        A new list containing the keys of ``self.knowledge_bases`` in the
        mapping's iteration order.
    """
    # Iterating a dict yields its keys, so no explicit .keys() is needed.
    return list(self.knowledge_bases)
|
|
|
5 |
import faiss
|
6 |
import os
|
7 |
import glob
|
8 |
+
import json
|
9 |
from typing import Any,List,Dict
|
10 |
from embedding import Embedding
|
11 |
|
|
|
17 |
self.batch_size = batch_size
|
18 |
self.embeddings = Embedding()
|
19 |
self.knowledge_bases: Dict[str, FAISS] = {}
|
20 |
+
self.db_files_map: Dict[str, list] = {}
|
21 |
os.makedirs(self.base_path, exist_ok=True)
|
22 |
|
23 |
faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
|
|
|
35 |
return
|
36 |
|
37 |
self.knowledge_bases[name] = kb
|
38 |
+
self.db_files_map[name] = []
|
39 |
self.save_knowledge_base(name)
|
40 |
print(f"Knowledge base '{name}' created.")
|
41 |
|
42 |
def delete_knowledge_base(self, name: str):
|
43 |
if name in self.knowledge_bases:
|
44 |
del self.knowledge_bases[name]
|
45 |
+
del self.db_files_map[name]
|
46 |
os.remove(os.path.join(self.base_path, f"{name}.faiss"))
|
47 |
print(f"Knowledge base '{name}' deleted.")
|
48 |
else:
|
|
|
52 |
kb_path = os.path.join(self.base_path, f"{name}.faiss")
|
53 |
if os.path.exists(kb_path):
|
54 |
self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
|
55 |
+
# Load the knowledge-base -> file-list mapping stored in db.json
|
56 |
+
try:
|
57 |
+
with open('db.json', 'r+') as f:
|
58 |
+
self.db_files_map = json.load(f)
|
59 |
+
except FileNotFoundError:
|
60 |
+
# If the file does not exist, create it and initialize self.db_files_map as an empty mapping
|
61 |
+
with open('db.json', 'w+') as f:
|
62 |
+
self.db_files_map = {}
|
63 |
+
json.dump(self.db_files_map, f)
|
64 |
print(f"Knowledge base '{name}' loaded.")
|
65 |
else:
|
66 |
print(f"Knowledge base '{name}' does not exist.")
|
|
|
68 |
def save_knowledge_base(self, name: str):
|
69 |
if name in self.knowledge_bases:
|
70 |
self.knowledge_bases[name].save_local(self.base_path, name)
|
71 |
+
with open('db.json', 'w') as f:
|
72 |
+
json.dump(self.db_files_map, f)
|
73 |
print(f"Knowledge base '{name}' saved.")
|
74 |
else:
|
75 |
print(f"Knowledge base '{name}' does not exist.")
|
|
|
87 |
self.create_knowledge_base(name)
|
88 |
|
89 |
kb = self.knowledge_bases[name]
|
90 |
+
self.db_files_map[name].extend([os.path.basename(file_path) for file_path in file_paths])
|
91 |
documents = self.load_documents(file_paths)
|
92 |
print(f"Loaded {len(documents)} documents.")
|
93 |
print(documents)
|
|
|
154 |
|
155 |
return results
|
156 |
|
157 |
+
def get_db_files(self,name):
    """Return the file names recorded for the knowledge base ``name``.

    Args:
        name: Knowledge base name, used as a key into ``self.db_files_map``.

    Returns:
        A new list with the entries recorded for ``name``, or an empty list
        when ``name`` has no entry in ``self.db_files_map``.
    """
    # db_files_map is replaced wholesale from db.json when a knowledge base is
    # loaded, so a knowledge base can exist without an entry here; treat that
    # as "no files recorded" instead of raising KeyError into the UI handler.
    # A copy is returned so callers cannot mutate the manager's own list.
    return list(self.db_files_map.get(name, []))
|
160 |
+
|
161 |
def get_bases(self):
    """Return the names of all knowledge bases currently held in memory.

    Returns:
        A new list containing the keys of ``self.knowledge_bases`` in the
        mapping's iteration order.
    """
    # Iterating a dict yields its keys, so no explicit .keys() is needed.
    return list(self.knowledge_bases)
|