from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
    PyPDFLoader,
    DataFrameLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import pandas as pd
import threading
import glob
import os
import queue


class Query:
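    """Bundle a question, an LLM, and a FAISS index into a runnable query."""
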
    def __init__(self, question, llm, index):
        self.question = question
        self.llm = llm
        self.index = index

    def query(self):
        """Query the vectorstore."""
        llm = self.llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        chain = RetrievalQA.from_chain_type(
            llm, retriever=self.index.as_retriever()
        )
        return chain.run(self.question)


class SearchableIndex:
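    """Build, merge, and query FAISS vector indexes over local documents."""
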
    def __init__(self, path):
        self.path = path

    @classmethod
    def get_splits(cls, path, target_col=None, sheet_name=None):
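        """Split the file at ``path`` into chunks suitable for embedding.

        .txt and .pdf files are chunked with RecursiveCharacterTextSplitter;
        .xlsx and .csv files are loaded row-wise as Documents.
        """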
        extension = os.path.splitext(path)[1].lower()
        doc_list = None
        if extension == ".txt":
            with open(path, 'r') as txt:
                data = txt.read()
                text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                            chunk_overlap=0,
                                                            length_function=len)
                doc_list = text_split.split_text(data)
        elif extension == ".pdf":
            loader = PyPDFLoader(path)
            pages = loader.load_and_split()
            text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                        chunk_overlap=0,
                                                        length_function=len)
            doc_list = []
            for pg in pages:
                pg_splits = text_split.split_text(pg.page_content)
                doc_list.extend(pg_splits)
        elif extension == ".xml":
            df = pd.read_excel(io=path, engine='openpyxl', sheet_name=sheet_name)
            df_loader = DataFrameLoader(df, page_content_column=target_col)
            doc_list = df_loader.load()
        elif extension == ".csv":
            csv_loader = CSVLoader(path)
            doc_list = csv_loader.load()
        if doc_list is None:
            raise ValueError(f"Unsupported file format: {extension}")
        return doc_list

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
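        """Merge ``faiss_db`` into an existing store at ``index_store``,
        or create a new store there, and return the loaded index."""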
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            local_db.save_local(index_store)
            logger.info("Merge index completed")
        else:
            faiss_db.save_local(folder_path=index_store)
            logger.info("New store created and loaded...")
            local_db = FAISS.load_local(index_store, embeddings)
        return local_db

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, result_queue):
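        """Load the first matching FAISS index and hand it back through
        ``result_queue`` so the calling thread can pick it up."""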
        if not index_files:
            # ``raise logger.warning(...)`` raised None (a TypeError) and left
            # the caller blocked on the queue; put a sentinel instead.
            logger.warning("Index store does not exist")
            result_queue.put(None)
            return
        local_db = FAISS.load_local(index_files[0], embeddings)
        result_queue.put(local_db)  # Put the result in the queue

    @classmethod
    def embed_index(cls, url, path, llm, prompt, target_col=None, sheet_name=None):
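        """Embed the document at ``path`` into a FAISS index (or load an
        existing index when ``url`` is 'NO_URL') and return a Query."""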
        embeddings = OpenAIEmbeddings()

        if url != 'NO_URL' and path:
            doc_list = cls.get_splits(path, target_col, sheet_name)
            # .txt/.pdf splits are plain strings; .xlsx/.csv loaders return
            # Document objects, which require from_documents instead.
            if doc_list and isinstance(doc_list[0], str):
                faiss_db = FAISS.from_texts(doc_list, embeddings)
            else:
                faiss_db = FAISS.from_documents(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return Query(prompt, llm, local_db)
        elif url == 'NO_URL' and path:
            index_files = glob.glob(os.path.join(path, '*_index'))

            result_queue = queue.Queue()  # Create a queue to store the result

            thread = threading.Thread(target=cls.check_and_load_index,
                                      args=(index_files, embeddings, logger, result_queue))
            thread.start()
            local_db = result_queue.get()  # Retrieve the result from the queue
            thread.join()
            if local_db is None:
                raise FileNotFoundError(f"No index store found under {path}")
            return Query(prompt, llm, local_db)
        else:
            raise ValueError("A document path or index directory is required")


if __name__ == '__main__':
    pass
    # Example: build (or merge into) an index from a document and query it.
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # search = SearchableIndex.embed_index(
    #     url='FILE',  # any value other than 'NO_URL' triggers indexing of ``path``
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents/combined_content.txt",
    #     llm=llm,
    #     prompt='show more detail about types of data collected',
    # )
    # print(search.query())
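    # Example: load a previously built index from a directory of *_index
    # stores (the 'NO_URL' branch; the directory below is illustrative):
    # search = SearchableIndex.embed_index(
    #     url='NO_URL',
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents",
    #     llm=llm,
    #     prompt='show more detail about types of data collected',
    # )
    # print(search.query())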