neoguojing committed on
Commit
864919f
·
1 Parent(s): 1405c00
Files changed (3) hide show
  1. .gitignore +10 -0
  2. app.py +24 -0
  3. retriever.py +20 -0
.gitignore CHANGED
@@ -5,3 +5,13 @@ __pycache__/
5
  files/input/ir2023_ashare.pdf
6
  knowledge_bases/中国移动.faiss
7
  knowledge_bases/中国移动.pkl
 
 
 
 
 
 
 
 
 
 
 
5
  files/input/ir2023_ashare.pdf
6
  knowledge_bases/中国移动.faiss
7
  knowledge_bases/中国移动.pkl
8
+ files/input/ClientList.csv
9
+ db.json
10
+ files/input/history_Example_20240109T225059.json
11
+ files/input/ir2023_ashare.txt
12
+ knowledge_bases/json.faiss
13
+ knowledge_bases/json.pkl
14
+ knowledge_bases/list.faiss
15
+ knowledge_bases/list.pkl
16
+ knowledge_bases/text.faiss
17
+ knowledge_bases/text.pkl
app.py CHANGED
@@ -8,6 +8,10 @@ from sam_everything import SamAnything
8
  from ocr import do_ocr
9
  from retriever import knowledgeBase
10
  import time
 
 
 
 
11
 
12
  components = {}
13
 
@@ -137,6 +141,14 @@ def create_ui():
137
  col_count=(1, "fixed"),
138
  interactive=False
139
  )
 
 
 
 
 
 
 
 
140
  with gr.Column(scale=2):
141
  with gr.Row():
142
  with gr.Column(scale=2):
@@ -246,6 +258,10 @@ def create_event_handlers():
246
  llm, gradio('ak','sk','llm_client'), None
247
  )
248
 
 
 
 
 
249
  def do_refernce(algo_type,input_image):
250
  # def do_refernce():
251
  print("input image",input_image)
@@ -422,6 +438,14 @@ def file_handler(file_objs,name):
422
  dfs = knowledgeBase.get_df_bases()
423
  return dfs,gr.CheckboxGroup(dbs,label="知识库", info="可选择1个或多个知识库"),gr.Dropdown(dbs,multiselect=True, label="知识库选择")
424
 
 
 
 
 
 
 
 
 
425
  def do_search(selected_dbs,user_input):
426
  print("do_search:",selected_dbs,user_input)
427
  context = knowledgeBase.retrieve_documents(selected_dbs,user_input)
 
8
  from ocr import do_ocr
9
  from retriever import knowledgeBase
10
  import time
11
+ from pathlib import Path
12
+
13
+ current_file_path = Path(__file__).resolve()
14
+ absolute_path = (current_file_path.parent / "files" / "input").resolve()
15
 
16
  components = {}
17
 
 
141
  col_count=(1, "fixed"),
142
  interactive=False
143
  )
144
+ components["file_expr"] = gr.FileExplorer(
145
+ scale=1,
146
+ value=[],
147
+ file_count="single",
148
+ root=absolute_path,
149
+ # ignore_glob="**/__init__.py",
150
+ elem_id="file_expr",
151
+ )
152
  with gr.Column(scale=2):
153
  with gr.Row():
154
  with gr.Column(scale=2):
 
258
  llm, gradio('ak','sk','llm_client'), None
259
  )
260
 
261
+ components['db_view'].select(
262
+ db_expr, gradio('db_view'), gradio('file_expr')
263
+ )
264
+
265
  def do_refernce(algo_type,input_image):
266
  # def do_refernce():
267
  print("input image",input_image)
 
438
  dfs = knowledgeBase.get_df_bases()
439
  return dfs,gr.CheckboxGroup(dbs,label="知识库", info="可选择1个或多个知识库"),gr.Dropdown(dbs,multiselect=True, label="知识库选择")
440
 
441
def db_expr(selected_index: gr.SelectData, dataframe_origin):
    """Look up the files of the knowledge base whose cell was clicked.

    Registered as the ``select`` handler of the ``db_view`` component; the
    return value is routed to the ``file_expr`` component.

    Args:
        selected_index: Selection event payload. Its ``index`` attribute is
            read as a ``[row, column]`` pair.
            NOTE(review): the ``gr.SelectData`` annotation is presumably what
            makes Gradio inject the event payload here - keep it; confirm
            against the Gradio event-data documentation.
        dataframe_origin: Current value of ``db_view``. It is indexed with
            ``.iloc``, so it is expected to behave like a pandas DataFrame.

    Returns:
        Whatever ``knowledgeBase.get_db_files`` returns for the selected
        cell value.
    """
    cell = selected_index.index
    print("db_expr",cell)

    # Positional lookup: the clicked cell's content is used as the
    # knowledge base name.
    row, col = cell[0], cell[1]
    dbname = dataframe_origin.iloc[row, col]
    print("db_expr",dbname)

    return knowledgeBase.get_db_files(dbname)
448
+
449
  def do_search(selected_dbs,user_input):
450
  print("do_search:",selected_dbs,user_input)
451
  context = knowledgeBase.retrieve_documents(selected_dbs,user_input)
retriever.py CHANGED
@@ -5,6 +5,7 @@ from langchain_community.docstore.in_memory import InMemoryDocstore
5
  import faiss
6
  import os
7
  import glob
 
8
  from typing import Any,List,Dict
9
  from embedding import Embedding
10
 
@@ -16,6 +17,7 @@ class KnowledgeBaseManager:
16
  self.batch_size = batch_size
17
  self.embeddings = Embedding()
18
  self.knowledge_bases: Dict[str, FAISS] = {}
 
19
  os.makedirs(self.base_path, exist_ok=True)
20
 
21
  faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
@@ -33,12 +35,14 @@ class KnowledgeBaseManager:
33
  return
34
 
35
  self.knowledge_bases[name] = kb
 
36
  self.save_knowledge_base(name)
37
  print(f"Knowledge base '{name}' created.")
38
 
39
  def delete_knowledge_base(self, name: str):
40
  if name in self.knowledge_bases:
41
  del self.knowledge_bases[name]
 
42
  os.remove(os.path.join(self.base_path, f"{name}.faiss"))
43
  print(f"Knowledge base '{name}' deleted.")
44
  else:
@@ -48,6 +52,15 @@ class KnowledgeBaseManager:
48
  kb_path = os.path.join(self.base_path, f"{name}.faiss")
49
  if os.path.exists(kb_path):
50
  self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
 
 
 
 
 
 
 
 
 
51
  print(f"Knowledge base '{name}' loaded.")
52
  else:
53
  print(f"Knowledge base '{name}' does not exist.")
@@ -55,6 +68,8 @@ class KnowledgeBaseManager:
55
  def save_knowledge_base(self, name: str):
56
  if name in self.knowledge_bases:
57
  self.knowledge_bases[name].save_local(self.base_path, name)
 
 
58
  print(f"Knowledge base '{name}' saved.")
59
  else:
60
  print(f"Knowledge base '{name}' does not exist.")
@@ -72,6 +87,7 @@ class KnowledgeBaseManager:
72
  self.create_knowledge_base(name)
73
 
74
  kb = self.knowledge_bases[name]
 
75
  documents = self.load_documents(file_paths)
76
  print(f"Loaded {len(documents)} documents.")
77
  print(documents)
@@ -138,6 +154,10 @@ class KnowledgeBaseManager:
138
 
139
  return results
140
 
 
 
 
 
141
  def get_bases(self):
142
  data = self.knowledge_bases.keys()
143
  return list(data)
 
5
  import faiss
6
  import os
7
  import glob
8
+ import json
9
  from typing import Any,List,Dict
10
  from embedding import Embedding
11
 
 
17
  self.batch_size = batch_size
18
  self.embeddings = Embedding()
19
  self.knowledge_bases: Dict[str, FAISS] = {}
20
+ self.db_files_map: Dict[str, list] = {}
21
  os.makedirs(self.base_path, exist_ok=True)
22
 
23
  faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
 
35
  return
36
 
37
  self.knowledge_bases[name] = kb
38
+ self.db_files_map[name] = []
39
  self.save_knowledge_base(name)
40
  print(f"Knowledge base '{name}' created.")
41
 
42
  def delete_knowledge_base(self, name: str):
43
  if name in self.knowledge_bases:
44
  del self.knowledge_bases[name]
45
+ del self.db_files_map[name]
46
  os.remove(os.path.join(self.base_path, f"{name}.faiss"))
47
  print(f"Knowledge base '{name}' deleted.")
48
  else:
 
52
  kb_path = os.path.join(self.base_path, f"{name}.faiss")
53
  if os.path.exists(kb_path):
54
  self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
55
+ # 加载文件中的数据
56
+ try:
57
+ with open('db.json', 'r+') as f:
58
+ self.db_files_map = json.load(f)
59
+ except FileNotFoundError:
60
+ # 如果文件不存在,则创建一个空的文件并初始化 self.db_files_map
61
+ with open('db.json', 'w+') as f:
62
+ self.db_files_map = {}
63
+ json.dump(self.db_files_map, f)
64
  print(f"Knowledge base '{name}' loaded.")
65
  else:
66
  print(f"Knowledge base '{name}' does not exist.")
 
68
  def save_knowledge_base(self, name: str):
69
  if name in self.knowledge_bases:
70
  self.knowledge_bases[name].save_local(self.base_path, name)
71
+ with open('db.json', 'w') as f:
72
+ json.dump(self.db_files_map, f)
73
  print(f"Knowledge base '{name}' saved.")
74
  else:
75
  print(f"Knowledge base '{name}' does not exist.")
 
87
  self.create_knowledge_base(name)
88
 
89
  kb = self.knowledge_bases[name]
90
+ self.db_files_map[name].extend([os.path.basename(file_path) for file_path in file_paths])
91
  documents = self.load_documents(file_paths)
92
  print(f"Loaded {len(documents)} documents.")
93
  print(documents)
 
154
 
155
  return results
156
 
157
def get_db_files(self, name):
    """Return the file names recorded for knowledge base ``name``.

    Args:
        name: Knowledge base name, used as the key into
            ``self.db_files_map``.

    Returns:
        The list stored for ``name`` (the same list object, not a copy),
        or a new empty list when the map has no entry for ``name``.
    """
    # The map can lack an entry for a name the UI still offers, e.g. an
    # index whose files were never recorded in db.json. Answer with "no
    # files" instead of raising KeyError into the select handler.
    return self.db_files_map.get(name, [])
160
+
161
  def get_bases(self):
162
  data = self.knowledge_bases.keys()
163
  return list(data)