# ChatitoArXiv/utils/store.py
from typing import List
from uuid import uuid4

import arxiv
import chainlit as cl
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.documents import Document
from pinecone import GRPCIndex
from tqdm.auto import tqdm

# Maximum number of chunks to embed and upsert per Pinecone request.
BATCH_LIMIT = 100

def index_documents(
        docs: List[List[Document]],
        text_splitter: TextSplitter,
        embedder: CacheBackedEmbeddings,
        index: GRPCIndex) -> None:
    """Split, embed, and upsert loaded documents into a Pinecone index in batches."""
    texts = []
    metadatas = []
    for doc_pages in tqdm(docs):
        for doc in doc_pages:
            metadata = {
                'source_document': doc.metadata["source"],
                'page_number': doc.metadata["page"]
            }
            record_texts = text_splitter.split_text(doc.page_content)
            record_metadatas = [{
                "chunk": j, "text": text, **metadata
            } for j, text in enumerate(record_texts)]
            texts.extend(record_texts)
            metadatas.extend(record_metadatas)
            # Flush a full batch to Pinecone to stay under the upsert size limit.
            if len(texts) >= BATCH_LIMIT:
                ids = [str(uuid4()) for _ in range(len(texts))]
                embeds = embedder.embed_documents(texts)
                index.upsert(vectors=zip(ids, embeds, metadatas))
                texts = []
                metadatas = []
    # Upsert any remaining chunks that did not fill a complete batch.
    if len(texts) > 0:
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = embedder.embed_documents(texts)
        index.upsert(vectors=zip(ids, embeds, metadatas))
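
# Note on inputs: `docs` is a list of per-paper page lists (the shape produced
# by calling PyPDFLoader(url).load() once per paper), and each upserted vector
# is an (id, embedding, metadata) tuple whose "text" metadata key stores the
# chunk contents so retrieval can return them without a second lookup.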

def search_and_index(message: cl.Message, quantity: int, embedder: CacheBackedEmbeddings, index: GRPCIndex) -> None:
    """Search arXiv for papers matching the message, then load, split, and index them."""
    arxiv_client = arxiv.Client()
    search = arxiv.Search(
        query=message.content,
        max_results=quantity,
        sort_by=arxiv.SortCriterion.Relevance
    )
    paper_urls = [result.pdf_url for result in arxiv_client.results(search)]

    # Load each PDF (on message); skip papers whose download or parsing fails.
    docs = []
    for paper_url in paper_urls:
        try:
            loader = PyPDFLoader(paper_url)
            docs.append(loader.load())
        except Exception as e:
            print(f"Error loading {paper_url}: {e}")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=400,
        chunk_overlap=30,
        length_function=len
    )
    # Embed the chunks and upsert them into the Pinecone index (on message).
    index_documents(docs, text_splitter, embedder, index)
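

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of wiring up the embedder and index these
# helpers expect. The index name, cache path, namespace, and environment
# variables below are illustrative assumptions, not part of the app, and
# constructing a cl.Message outside a running Chainlit session is for
# demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    import pinecone
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.storage import LocalFileStore

    # Assumes a pre-created Pinecone index and pinecone-client 2.x-style init.
    pinecone.init(
        api_key=os.environ["PINECONE_API_KEY"],
        environment=os.environ["PINECONE_ENV"],
    )
    demo_index = GRPCIndex("arxiv-papers")  # hypothetical index name

    # Cache embeddings on disk so re-embedding identical chunks costs nothing.
    demo_embedder = CacheBackedEmbeddings.from_bytes_store(
        OpenAIEmbeddings(),
        LocalFileStore("./embedding_cache/"),
        namespace="openai-embeddings",  # illustrative cache namespace
    )

    demo_message = cl.Message(content="retrieval augmented generation")
    search_and_index(demo_message, quantity=3, embedder=demo_embedder, index=demo_index)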