File size: 7,030 Bytes
cc74372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e15749
cc74372
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss
import os
import glob
import json
from typing import Any,List,Dict
from embedding import Embedding


class KnowledgeBaseManager:
    """Manage named FAISS vector stores persisted on disk.

    Each knowledge base is stored under ``base_path`` as ``{name}.faiss`` plus
    ``{name}.pkl`` (both written by ``FAISS.save_local``).  The mapping of
    knowledge-base name -> list of indexed source file names is persisted in
    ``db.json``.

    NOTE(review): ``db.json`` lives in the current working directory, not under
    ``base_path`` — kept as-is for compatibility with existing deployments.
    """

    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
        """Create the manager and eagerly load every KB found in *base_path*.

        Args:
            base_path: Directory holding the persisted ``*.faiss`` indexes.
            embedding_dim: Dimension of the embedding vectors (L2 flat index).
            batch_size: Number of chunks embedded per ``add_documents`` call.
        """
        self.base_path = base_path
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.embeddings = Embedding()
        self.knowledge_bases: Dict[str, FAISS] = {}
        self.db_files_map: Dict[str, list] = {}
        os.makedirs(self.base_path, exist_ok=True)

        # Read db.json exactly once here; the original re-read it inside
        # load_knowledge_base for every KB, clobbering in-memory state.
        self._load_db_files_map()

        # Discover persisted indexes and load each by its stem name.
        faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
        for faiss_file in faiss_files:
            name = os.path.splitext(os.path.basename(faiss_file))[0]
            self.load_knowledge_base(name)

    def _load_db_files_map(self):
        """Load the name -> file-list map from db.json, creating it if absent."""
        try:
            with open('db.json', 'r') as f:
                self.db_files_map = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # Missing or corrupt map file: start empty and persist the stub.
            self.db_files_map = {}
            with open('db.json', 'w') as f:
                json.dump(self.db_files_map, f)

    def create_knowledge_base(self, name: str):
        """Create and persist an empty knowledge base; no-op if it exists."""
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return
        # Build the empty index only after the existence check (the original
        # constructed it unconditionally and discarded it on the early return).
        index = faiss.IndexFlatL2(self.embedding_dim)
        kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.knowledge_bases[name] = kb
        self.db_files_map[name] = []
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")

    def delete_knowledge_base(self, name: str):
        """Remove a knowledge base from memory and delete its files on disk."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return
        del self.knowledge_bases[name]
        # pop() tolerates a KB that never had a db.json entry (the original
        # `del` would raise KeyError in that case).
        self.db_files_map.pop(name, None)
        # save_local writes both an index (.faiss) and a docstore pickle
        # (.pkl); the original removed only the .faiss and left the .pkl.
        for ext in ('.faiss', '.pkl'):
            path = os.path.join(self.base_path, f"{name}{ext}")
            if os.path.exists(path):
                os.remove(path)
        # Persist the updated map so the deletion survives a restart.
        with open('db.json', 'w') as f:
            json.dump(self.db_files_map, f)
        print(f"Knowledge base '{name}' deleted.")

    def load_knowledge_base(self, name: str):
        """Load a persisted knowledge base from *base_path* into memory."""
        kb_path = os.path.join(self.base_path, f"{name}.faiss")
        if not os.path.exists(kb_path):
            print(f"Knowledge base '{name}' does not exist.")
            return
        # allow_dangerous_deserialization unpickles the docstore — only safe
        # because these files are produced locally by save_knowledge_base.
        self.knowledge_bases[name] = FAISS.load_local(
            self.base_path, self.embeddings, name,
            allow_dangerous_deserialization=True,
        )
        # Guarantee a map entry even when db.json predates this KB; the
        # original left the key missing, making later indexing a KeyError.
        self.db_files_map.setdefault(name, [])
        print(f"Knowledge base '{name}' loaded.")

    def save_knowledge_base(self, name: str):
        """Persist one knowledge base's index plus the shared file map."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return
        self.knowledge_bases[name].save_local(self.base_path, name)
        with open('db.json', 'w') as f:
            json.dump(self.db_files_map, f)
        print(f"Knowledge base '{name}' saved.")

    def add_documents_to_kb(self, name: str, file_paths: List[str]):
        """Load, split, embed and index the given files into KB *name*.

        The knowledge base is created on the fly if it does not exist.

        Returns:
            The list of document ids assigned by the vector store.
        """
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            self.create_knowledge_base(name)

        kb = self.knowledge_bases[name]
        # setdefault guards a KB loaded from disk with no db.json entry
        # (direct indexing could KeyError in the original).
        self.db_files_map.setdefault(name, []).extend(
            os.path.basename(p) for p in file_paths
        )
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")
        pages = self.split_documents(documents)
        print(f"Split documents into {len(pages)} pages.")

        # Embed in batches to bound memory / request size per call.
        doc_ids = []
        for start in range(0, len(pages), self.batch_size):
            doc_ids.extend(kb.add_documents(pages[start:start + self.batch_size]))

        self.save_knowledge_base(name)
        return doc_ids

    def load_documents(self, file_paths: List[str]):
        """Load every file into langchain Documents via its matching loader."""
        documents = []
        for file_path in file_paths:
            documents.extend(self.get_loader(file_path).load())
        return documents

    def get_loader(self, file_path: str):
        """Return the document loader matching the file extension.

        Raises:
            ValueError: for extensions other than .txt / .json / .pdf.
        """
        if file_path.endswith('.txt'):
            return TextLoader(file_path)
        if file_path.endswith('.json'):
            return JSONLoader(file_path)
        if file_path.endswith('.pdf'):
            return PyPDFLoader(file_path)
        raise ValueError("Unsupported file format")

    def split_documents(self, documents):
        """Split documents into <=512-char chunks on CJK-aware separators."""
        text_splitter = RecursiveCharacterTextSplitter(
            separators=[
                "\n\n",
                "\n",
                " ",
                ".",
                ",",
                "\u200b",  # zero-width space
                "\uff0c",  # fullwidth comma
                "\u3001",  # ideographic comma
                "\uff0e",  # fullwidth full stop
                "\u3002",  # ideographic full stop
                "",
            ],
            chunk_size=512,
            chunk_overlap=0,
        )
        return text_splitter.split_documents(documents)

    def retrieve_documents(self, names: List[str], query: str):
        """Run an MMR retrieval for *query* against each named knowledge base.

        Returns:
            A flat list of {"name", "content", "meta"} dicts; unknown names
            are skipped with a message rather than raising.
        """
        results = []
        for name in names:
            if name not in self.knowledge_bases:
                print(f"Knowledge base '{name}' does not exist.")
                continue
            # NOTE(review): "score_threshold" is honored by the
            # "similarity_score_threshold" search type, not "mmr" — confirm
            # which retrieval behavior was intended here.
            retriever = self.knowledge_bases[name].as_retriever(
                search_type="mmr",
                search_kwargs={"score_threshold": 0.5, "k": 3},
            )
            docs = retriever.get_relevant_documents(query)
            results.extend(
                {"name": name, "content": doc.page_content, "meta": doc.metadata}
                for doc in docs
            )
        return results

    def get_db_files(self, name):
        """Return the file names indexed in *name*, or None if unknown."""
        return self.db_files_map.get(name)

    def get_bases(self):
        """Return the names of all loaded knowledge bases as a list."""
        return list(self.knowledge_bases.keys())

    def get_df_bases(self):
        """Return the knowledge-base names as a one-column pandas DataFrame."""
        import pandas as pd  # local import keeps pandas optional
        return pd.DataFrame(self.get_bases(), columns=['列表'])

# Module-level singleton: importing this module constructs the manager, which
# scans base_path and eagerly loads every persisted knowledge base from disk
# (i.e. import has filesystem side effects).
knowledgeBase = KnowledgeBaseManager()