# Knowledge-base management built on FAISS vector stores (see KnowledgeBaseManager).
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss
import os
import glob
import json
from typing import Any,List,Dict
from embedding import Embedding
class KnowledgeBaseManager:
    """Manage named FAISS vector stores persisted under ``base_path``.

    Each knowledge base is saved by LangChain's FAISS wrapper as a pair of
    files ``<name>.faiss`` / ``<name>.pkl`` inside ``base_path``.  The list
    of source files that were ingested into each base is tracked in
    ``db.json``.

    NOTE(review): ``db.json`` lives in the current working directory, not in
    ``base_path`` — presumably intentional for this app, but confirm; moving
    it would orphan existing deployments.
    """

    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
        """Create the manager and auto-load every index found in *base_path*.

        Args:
            base_path: Directory holding the persisted FAISS indexes.
            embedding_dim: Dimensionality of the embedding vectors
                (must match what ``Embedding`` produces).
            batch_size: Number of document chunks embedded per batch.
        """
        self.base_path = base_path
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.embeddings = Embedding()
        self.knowledge_bases: Dict[str, FAISS] = {}
        self.db_files_map: Dict[str, List[str]] = {}
        os.makedirs(self.base_path, exist_ok=True)
        # Auto-load every previously persisted index; the base name is the
        # file stem of each *.faiss file.
        for faiss_file in glob.glob(os.path.join(base_path, '*.faiss')):
            self.load_knowledge_base(os.path.splitext(os.path.basename(faiss_file))[0])

    def create_knowledge_base(self, name: str):
        """Create, register and persist an empty knowledge base *name*.

        No-op (with a message) if the base already exists.
        """
        # Check existence first so we do not build an index for nothing
        # (the original allocated the FAISS index before this check).
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return
        index = faiss.IndexFlatL2(self.embedding_dim)
        self.knowledge_bases[name] = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.db_files_map[name] = []
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")

    def delete_knowledge_base(self, name: str):
        """Drop *name* from memory and remove its on-disk artifacts."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return
        del self.knowledge_bases[name]
        # pop() instead of del: the entry may be absent when db.json was
        # missing or out of sync (the original raised KeyError here).
        self.db_files_map.pop(name, None)
        # FAISS.save_local writes both an index file and a docstore pickle;
        # remove both (the original left the .pkl behind).
        for ext in (".faiss", ".pkl"):
            try:
                os.remove(os.path.join(self.base_path, f"{name}{ext}"))
            except FileNotFoundError:
                pass
        # Keep db.json consistent with the in-memory map.
        with open('db.json', 'w') as f:
            json.dump(self.db_files_map, f)
        print(f"Knowledge base '{name}' deleted.")

    def load_knowledge_base(self, name: str):
        """Load the persisted index *name* and refresh the file map from db.json."""
        kb_path = os.path.join(self.base_path, f"{name}.faiss")
        if not os.path.exists(kb_path):
            print(f"Knowledge base '{name}' does not exist.")
            return
        # allow_dangerous_deserialization: the .pkl docstore is a pickle we
        # wrote ourselves, so unpickling it is accepted here.
        self.knowledge_bases[name] = FAISS.load_local(
            self.base_path, self.embeddings, name,
            allow_dangerous_deserialization=True,
        )
        try:
            with open('db.json', 'r') as f:
                self.db_files_map = json.load(f)
        except FileNotFoundError:
            # First run: start with an empty map and materialize the file.
            self.db_files_map = {}
            with open('db.json', 'w') as f:
                json.dump(self.db_files_map, f)
        # Guarantee an entry for this base so add_documents_to_kb can
        # .extend() it without a KeyError (bug in the original).
        self.db_files_map.setdefault(name, [])
        print(f"Knowledge base '{name}' loaded.")

    def save_knowledge_base(self, name: str):
        """Persist index *name* to base_path and rewrite db.json."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return
        self.knowledge_bases[name].save_local(self.base_path, name)
        with open('db.json', 'w') as f:
            json.dump(self.db_files_map, f)
        print(f"Knowledge base '{name}' saved.")

    def add_documents_to_kb(self, name: str, file_paths: List[str]):
        """Load, split, embed and index the given files into base *name*.

        Creates the base if it does not exist yet.  Returns the list of
        document ids assigned by the vector store.
        """
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            self.create_knowledge_base(name)
        kb = self.knowledge_bases[name]
        self.db_files_map[name].extend(os.path.basename(p) for p in file_paths)
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")
        print(documents)
        pages = self.split_documents(documents)
        print(f"Split documents into {len(pages)} pages.")
        # Embed in batches to bound memory / request size.
        doc_ids: List[str] = []
        for i in range(0, len(pages), self.batch_size):
            doc_ids.extend(kb.add_documents(pages[i:i + self.batch_size]))
        self.save_knowledge_base(name)
        return doc_ids

    def load_documents(self, file_paths: List[str]):
        """Load every file into LangChain Documents via its matching loader."""
        documents = []
        for file_path in file_paths:
            documents.extend(self.get_loader(file_path).load())
        return documents

    def get_loader(self, file_path: str):
        """Return the loader for *file_path* (.txt/.json/.pdf).

        Raises:
            ValueError: for any other extension.
        """
        if file_path.endswith('.txt'):
            return TextLoader(file_path)
        elif file_path.endswith('.json'):
            return JSONLoader(file_path)
        elif file_path.endswith('.pdf'):
            return PyPDFLoader(file_path)
        else:
            raise ValueError("Unsupported file format")

    def split_documents(self, documents):
        """Split documents into <=512-char chunks, honoring CJK punctuation."""
        text_splitter = RecursiveCharacterTextSplitter(separators=[
            "\n\n",
            "\n",
            " ",
            ".",
            ",",
            "\u200b",  # Zero-width space
            "\uff0c",  # Fullwidth comma
            "\u3001",  # Ideographic comma
            "\uff0e",  # Fullwidth full stop
            "\u3002",  # Ideographic full stop
            "",
        ],
            chunk_size=512, chunk_overlap=0)
        return text_splitter.split_documents(documents)

    def retrieve_documents(self, names: List[str], query: str):
        """Run an MMR retrieval for *query* across the named bases.

        Returns a list of dicts with keys ``name``, ``content`` and ``meta``.
        Unknown base names are skipped with a message.
        """
        results = []
        for name in names:
            if name not in self.knowledge_bases:
                print(f"Knowledge base '{name}' does not exist.")
                continue
            retriever = self.knowledge_bases[name].as_retriever(
                search_type="mmr",
                # NOTE(review): 'score_threshold' is likely ignored by MMR
                # search (it applies to similarity_score_threshold) — confirm.
                search_kwargs={"score_threshold": 0.5, "k": 3}
            )
            docs = retriever.get_relevant_documents(query)
            results.extend(
                {"name": name, "content": doc.page_content, "meta": doc.metadata}
                for doc in docs
            )
        return results

    def get_db_files(self, name):
        """Return the recorded file list for *name*, or None if unknown."""
        return self.db_files_map.get(name)

    def get_bases(self):
        """Return the names of all loaded knowledge bases."""
        return list(self.knowledge_bases.keys())

    def get_df_bases(self):
        """Return the base names as a one-column DataFrame (for UI display)."""
        import pandas as pd
        return pd.DataFrame(self.get_bases(), columns=['列表'])
# Module-level singleton: importing this module loads every persisted
# knowledge base from disk once and shares the manager across the app.
# (A stray '|' scrape artifact after this line was removed — it was a
# syntax error, not code.)
knowledgeBase = KnowledgeBaseManager()