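"""Build FAISS vector indexes from local documents (txt, pdf, xlsx, csv) and
answer questions over them with a LangChain RetrievalQA chain."""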
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
    PyPDFLoader,
    DataFrameLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import pandas as pd
import threading
import glob
import os
import queue


class Query:
    def __init__(self, question, llm, index):
        self.question = question
        self.llm = llm
        self.index = index

    def query(self):
        """Query the vectorstore."""
        llm = self.llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        chain = RetrievalQA.from_chain_type(
            llm, retriever=self.index.as_retriever()
        )
        return chain.run(self.question)


class SearchableIndex:
    def __init__(self, path):
        self.path = path

    @classmethod
    def get_splits(cls, path, target_col=None, sheet_name=None):
        """Load a document and split it into text chunks ready for embedding."""
        extension = os.path.splitext(path)[1].lower()
        doc_list = None
        if extension == ".txt":
            with open(path, 'r') as txt:
                data = txt.read()
            text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                        chunk_overlap=0,
                                                        length_function=len)
            doc_list = text_split.split_text(data)
        elif extension == ".pdf":
            loader = PyPDFLoader(path)
            pages = loader.load_and_split()
            text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                        chunk_overlap=0,
                                                        length_function=len)
            doc_list = []
            for pg in pages:
                pg_splits = text_split.split_text(pg.page_content)
                doc_list.extend(pg_splits)
        elif extension == ".xlsx":
            df = pd.read_excel(io=path, engine='openpyxl', sheet_name=sheet_name)
            df_loader = DataFrameLoader(df, page_content_column=target_col)
            # Keep plain strings so the result can be passed to FAISS.from_texts
            doc_list = [doc.page_content for doc in df_loader.load()]
        elif extension == ".csv":
            csv_loader = CSVLoader(path)
            doc_list = [doc.page_content for doc in csv_loader.load()]
        if doc_list is None:
            raise ValueError("Unsupported file format")
        return doc_list

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            local_db.save_local(index_store)
            logger.info("Merge index completed")
        else:
            faiss_db.save_local(folder_path=index_store)
            logger.info("New store created and loaded...")
            local_db = FAISS.load_local(index_store, embeddings)
        return local_db

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, result_queue):
        if index_files:
            local_db = FAISS.load_local(index_files[0], embeddings)
        else:
            logger.warning("Index store does not exist")
            raise FileNotFoundError("Index store does not exist")
        result_queue.put(local_db)  # Put the result in the queue

    @classmethod
    def embed_index(cls, url, path, llm, prompt, target_col=None, sheet_name=None):
        embeddings = OpenAIEmbeddings()
        if url != 'NO_URL' and path:
            # Build a new index from the document at `path` and merge it into
            # any existing store saved next to that file.
            doc_list = cls.get_splits(path, target_col, sheet_name)
            faiss_db = FAISS.from_texts(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return Query(prompt, llm, local_db)
        elif url == 'NO_URL' and path:
            # Load an existing "*_index" store found under `path` in a worker thread.
            index_files = glob.glob(os.path.join(path, '*_index'))
            result_queue = queue.Queue()  # Create a queue to store the result
            thread = threading.Thread(target=cls.check_and_load_index,
                                      args=(index_files, embeddings, logger, result_queue))
            thread.start()
            local_db = result_queue.get()  # Retrieve the result from the queue
            return Query(prompt, llm, local_db)


if __name__ == '__main__':
    pass
    # Example search query (the path below is illustrative):
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # prompt = 'show more detail about types of data collected'
    # query = SearchableIndex.embed_index(
    #     url='FILE',  # any value other than 'NO_URL' builds the index from `path`
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents/combined_content.txt",
    #     llm=llm, prompt=prompt)
    # result = query.query()
    # print(result)
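    # A hedged sketch (the file path, sheet and column names below are hypothetical)
    # showing how the same entry point indexes one sheet of an Excel workbook,
    # with `target_col` naming the text column to embed:
    # xlsx_query = SearchableIndex.embed_index(
    #     url='FILE', path='learning_documents/data.xlsx',
    #     llm=llm, prompt='summarise the collected data',
    #     target_col='content', sheet_name='Sheet1')
    # print(xlsx_query.query())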