File size: 3,979 Bytes
d97a6fa
 
bad3833
d97a6fa
 
 
 
 
 
 
 
 
 
085b39c
 
 
 
 
 
 
 
 
 
 
 
 
 
d97a6fa
 
 
 
085b39c
bad3833
085b39c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d97a6fa
 
 
bad3833
d97a6fa
 
 
e44f2dc
d97a6fa
e44f2dc
 
 
 
bad3833
085b39c
d97a6fa
 
bad3833
d97a6fa
bad3833
 
 
 
e44f2dc
 
 
bad3833
e44f2dc
 
bad3833
 
e44f2dc
 
 
d97a6fa
 
bad3833
d97a6fa
 
e44f2dc
 
bad3833
e44f2dc
 
 
 
d97a6fa
e44f2dc
 
085b39c
d97a6fa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import threading
import glob
import os
import queue


class Query:
    """Bundle a question with an LLM and a vector index for retrieval QA.

    The answer is produced lazily by calling :meth:`query`.
    """

    def __init__(self, question, llm, index):
        # Keep everything needed to answer later; no work happens here.
        self.question = question
        self.llm = llm
        self.index = index

    def query(self):
        """Run a RetrievalQA chain over the stored index and return the answer."""
        # Fall back to a default chat model when no LLM was supplied.
        model = self.llm if self.llm else ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        retriever = self.index.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(model, retriever=retriever)
        return qa_chain.run(self.question)


class SearchableIndex:
    """Build, merge, and load FAISS vector indexes from local .txt/.pdf files."""

    def __init__(self, path):
        # Path to the document or directory this index is built from.
        self.path = path

    @classmethod
    def get_splits(cls, path):
        """Split the file at *path* into text chunks.

        Supports ``.txt`` (read whole file) and ``.pdf`` (per-page load).

        Raises:
            ValueError: if the extension is neither .txt nor .pdf.
        """
        extension = os.path.splitext(path)[1].lower()
        doc_list = None
        if extension == ".txt":
            # Explicit encoding so behavior does not depend on the platform default.
            with open(path, 'r', encoding='utf-8') as txt:
                data = txt.read()
            text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                        chunk_overlap=0,
                                                        length_function=len)
            doc_list = text_split.split_text(data)
        elif extension == ".pdf":
            loader = PyPDFLoader(path)
            pages = loader.load_and_split()
            text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                        chunk_overlap=0,
                                                        length_function=len)
            doc_list = []
            for pg in pages:
                doc_list.extend(text_split.split_text(pg.page_content))
        if doc_list is None:
            raise ValueError("Unsupported file format")
        return doc_list

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, loggers):
        """Merge *faiss_db* into an existing store at *index_store*, or create one.

        Returns the saved FAISS database.
        """
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            operation_info = "Merge"
        else:
            # No existing store: persist the provided database as a new one.
            local_db = faiss_db
            operation_info = "New store creation"

        local_db.save_local(index_store)
        loggers.info(f"{operation_info} index completed")
        return local_db

    @classmethod
    def load_or_check_index(cls, index_files, embeddings, loggers, result_queue):
        """Load the first index in *index_files* and deliver it on *result_queue*.

        BUG FIX: the previous version put nothing on the queue when
        *index_files* was empty (or loading raised), so the consumer in
        load_index_asynchronously blocked forever on result_queue.get().
        A result (possibly None) is now always delivered.
        """
        local_db = None
        try:
            if index_files:
                local_db = FAISS.load_local(index_files[0], embeddings)
            else:
                loggers.warning("Index store does not exist")
        finally:
            # Always unblock the consumer, even on failure.
            result_queue.put(local_db)
        return local_db

    @classmethod
    def load_index_asynchronously(cls, index_files, embeddings, loggers):
        """Load an index on a worker thread; return the database or None."""
        result_queue = queue.Queue()
        thread = threading.Thread(
            target=cls.load_or_check_index,
            args=(index_files, embeddings, loggers, result_queue)
        )
        thread.start()
        result = result_queue.get()
        thread.join()  # don't leave the worker dangling after the result arrives
        return result

    @classmethod
    def embed_index(cls, url, path, llm, prompt):
        """Build or load a FAISS index for *path* and wrap it in a Query.

        When *url* is not the sentinel 'NO_URL', *path* is treated as a single
        document to split, embed, and merge into "<path stem>_index".
        Otherwise *path* is treated as a directory containing '*_index' stores.
        Returns None when *path* is falsy.
        """
        embeddings = OpenAIEmbeddings()

        if not path:
            return None  # explicit: previously fell through and returned None implicitly

        if url != 'NO_URL':
            doc_list = cls.get_splits(path)
            faiss_db = FAISS.from_texts(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return Query(prompt, llm, local_db)

        index_files = glob.glob(os.path.join(path, '*_index'))
        local_db = cls.load_index_asynchronously(index_files, embeddings, logger)
        return Query(prompt, llm, local_db)


# Module is intended to be imported as a library; no CLI entry point is defined.
if __name__ == '__main__':
    pass