# "Spaces: / Sleeping / Sleeping" — appears to be a Hugging Face Spaces status
# banner captured by a page scrape, not part of the program; commented out so
# the module parses.
# Standard library
import glob
import json
import os
from typing import Any, Dict, List

# Third-party
import faiss
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.document_loaders import JSONLoader, PyPDFLoader, TextLoader
from langchain_community.vectorstores import FAISS

# Local
from embedding import Embedding
class KnowledgeBaseManager:
    """Manage named FAISS vector stores persisted under ``base_path``.

    Each knowledge base ``name`` is stored by LangChain's
    ``FAISS.save_local`` as ``{name}.faiss`` plus ``{name}.pkl`` inside
    ``base_path``.  The mapping from base name to the source file names it
    contains is persisted in ``db.json`` (relative to the current working
    directory, as in the original code).
    """

    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
        self.base_path = base_path
        self.embedding_dim = embedding_dim      # dimensionality of the flat L2 index
        self.batch_size = batch_size            # pages embedded per add_documents call
        self.embeddings = Embedding()
        self.knowledge_bases: Dict[str, FAISS] = {}
        self.db_files_map: Dict[str, List[str]] = {}
        os.makedirs(self.base_path, exist_ok=True)
        # Auto-load every index already on disk; the base name is the file
        # stem without the ".faiss" extension.
        for faiss_file in glob.glob(os.path.join(base_path, '*.faiss')):
            name = os.path.splitext(os.path.basename(faiss_file))[0]
            self.load_knowledge_base(name)

    def create_knowledge_base(self, name: str):
        """Create and persist an empty knowledge base; no-op if it already exists."""
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return
        # Build the index only after the existence check (the original
        # allocated the index first and threw it away on duplicates).
        index = faiss.IndexFlatL2(self.embedding_dim)
        kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.knowledge_bases[name] = kb
        self.db_files_map[name] = []
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")

    def delete_knowledge_base(self, name: str):
        """Drop a knowledge base from memory and remove its on-disk files."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return
        del self.knowledge_bases[name]
        self.db_files_map.pop(name, None)
        # FAISS.save_local writes both a .faiss and a .pkl file — remove
        # both (the original leaked the .pkl).
        for ext in (".faiss", ".pkl"):
            path = os.path.join(self.base_path, f"{name}{ext}")
            if os.path.exists(path):
                os.remove(path)
        # Keep db.json in sync with the in-memory map after a deletion.
        with open('db.json', 'w') as f:
            json.dump(self.db_files_map, f)
        print(f"Knowledge base '{name}' deleted.")

    def load_knowledge_base(self, name: str):
        """Load ``{name}.faiss`` from ``base_path`` and refresh the file map."""
        kb_path = os.path.join(self.base_path, f"{name}.faiss")
        if not os.path.exists(kb_path):
            print(f"Knowledge base '{name}' does not exist.")
            return
        # NOTE(review): allow_dangerous_deserialization unpickles the
        # docstore — acceptable only because these files are produced
        # locally by save_knowledge_base, never downloaded.
        self.knowledge_bases[name] = FAISS.load_local(
            self.base_path, self.embeddings, name,
            allow_dangerous_deserialization=True,
        )
        try:
            # Read-only access is sufficient (the original opened 'r+').
            with open('db.json', 'r') as f:
                self.db_files_map = json.load(f)
        except FileNotFoundError:
            # First run: create an empty db.json so later saves have a file.
            self.db_files_map = {}
            with open('db.json', 'w') as f:
                json.dump(self.db_files_map, f)
        # Guard: an index can exist on disk without an entry in db.json;
        # without this, add_documents_to_kb would raise KeyError.
        self.db_files_map.setdefault(name, [])
        print(f"Knowledge base '{name}' loaded.")

    def save_knowledge_base(self, name: str):
        """Persist the FAISS index and the name->files map to disk."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return
        self.knowledge_bases[name].save_local(self.base_path, name)
        with open('db.json', 'w') as f:
            json.dump(self.db_files_map, f)
        print(f"Knowledge base '{name}' saved.")

    def add_documents_to_kb(self, name: str, file_paths: List[str]):
        """Load, split, embed and persist ``file_paths`` into base ``name``.

        Creates the knowledge base if it does not exist yet.  Returns the
        list of document ids assigned by FAISS.
        """
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            self.create_knowledge_base(name)
        kb = self.knowledge_bases[name]
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")
        pages = self.split_documents(documents)
        print(f"Split documents into {len(pages)} pages.")
        # Record file names only after loading succeeded, so a loader
        # failure does not leave phantom entries in db.json; setdefault
        # protects against a missing key for bases loaded from disk.
        self.db_files_map.setdefault(name, []).extend(
            os.path.basename(file_path) for file_path in file_paths
        )
        doc_ids = []
        # Embed in batches to bound per-call memory / embedding payload size.
        for i in range(0, len(pages), self.batch_size):
            batch = pages[i:i + self.batch_size]
            doc_ids.extend(kb.add_documents(batch))
        self.save_knowledge_base(name)
        return doc_ids

    def load_documents(self, file_paths: List[str]):
        """Run the extension-appropriate loader over each path; flat list out."""
        documents = []
        for file_path in file_paths:
            documents.extend(self.get_loader(file_path).load())
        return documents

    def get_loader(self, file_path: str):
        """Return a LangChain loader chosen by file extension.

        Raises:
            ValueError: for any extension other than .txt/.json/.pdf.
        """
        if file_path.endswith('.txt'):
            return TextLoader(file_path)
        if file_path.endswith('.json'):
            # NOTE(review): JSONLoader usually also requires a jq_schema
            # argument — confirm this call works for the JSON files used.
            return JSONLoader(file_path)
        if file_path.endswith('.pdf'):
            return PyPDFLoader(file_path)
        raise ValueError("Unsupported file format")

    def split_documents(self, documents):
        """Split documents into <=512-char chunks, preferring CJK-aware separators."""
        text_splitter = RecursiveCharacterTextSplitter(
            separators=[
                "\n\n",
                "\n",
                " ",
                ".",
                ",",
                "\u200b",  # Zero-width space
                "\uff0c",  # Fullwidth comma
                "\u3001",  # Ideographic comma
                "\uff0e",  # Fullwidth full stop
                "\u3002",  # Ideographic full stop
                "",
            ],
            chunk_size=512,
            chunk_overlap=0,
        )
        return text_splitter.split_documents(documents)

    def retrieve_documents(self, names: List[str], query: str):
        """MMR-retrieve up to 3 chunks per named base; skip unknown names."""
        results = []
        for name in names:
            if name not in self.knowledge_bases:
                print(f"Knowledge base '{name}' does not exist.")
                continue
            retriever = self.knowledge_bases[name].as_retriever(
                search_type="mmr",
                search_kwargs={"score_threshold": 0.5, "k": 3},
            )
            docs = retriever.get_relevant_documents(query)
            results.extend(
                {"name": name, "content": doc.page_content, "meta": doc.metadata}
                for doc in docs
            )
        return results

    def get_db_files(self, name):
        """List the source file names registered for ``name`` ([] if unknown)."""
        # .get avoids the original's KeyError for an unregistered base.
        return self.db_files_map.get(name, [])

    def get_bases(self):
        """Return the names of all loaded knowledge bases."""
        return list(self.knowledge_bases)

    def get_df_bases(self):
        """Return the base names as a one-column DataFrame for UI display."""
        import pandas as pd
        return pd.DataFrame(self.get_bases(), columns=['列表'])
# Module-level singleton: importing this module creates ./knowledge_bases
# (if absent) and loads every existing *.faiss index found there.
knowledgeBase = KnowledgeBaseManager()