from typing import List
from uuid import uuid4

import arxiv
import chainlit as cl
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.documents import Document
from pinecone import GRPCIndex
from tqdm.auto import tqdm

# Maximum number of chunks to embed and upsert to Pinecone per batch.
BATCH_LIMIT = 100

def index_documents(
        docs: List[Document],
        text_splitter: TextSplitter,
        embedder: CacheBackedEmbeddings,
        index: GRPCIndex) -> None:
    """Chunk each document, embed the chunks, and upsert them to Pinecone in batches."""
    texts = []
    metadatas = []
    for doc in tqdm(docs):
        metadata = {
            'source_document': doc.metadata["source"],
            'page_number': doc.metadata["page"]
        }
        record_texts = text_splitter.split_text(doc.page_content)
        record_metadatas = [{
            "chunk": j, "text": text, **metadata
        } for j, text in enumerate(record_texts)]
        texts.extend(record_texts)
        metadatas.extend(record_metadatas)
        # Flush a full batch to Pinecone before accumulating more chunks.
        if len(texts) >= BATCH_LIMIT:
            ids = [str(uuid4()) for _ in range(len(texts))]
            embeds = embedder.embed_documents(texts)
            index.upsert(vectors=list(zip(ids, embeds, metadatas)))
            texts = []
            metadatas = []
    # Upsert any remaining chunks that did not fill a complete batch.
    if texts:
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = embedder.embed_documents(texts)
        index.upsert(vectors=list(zip(ids, embeds, metadatas)))
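

# A minimal setup sketch (not part of the original module) showing how the
# `embedder` and `index` arguments above might be constructed. It assumes
# pinecone-client 2.x (where `pinecone.GRPCIndex` is exposed), OpenAI
# embeddings, and a local file cache; the API key, environment, index name,
# and cache path are hypothetical placeholders.
def build_dependencies():
    import pinecone
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.storage import LocalFileStore

    # Hypothetical Pinecone credentials and index name; replace with your own.
    pinecone.init(api_key="YOUR_API_KEY", environment="us-east-1-aws")
    index = pinecone.GRPCIndex("arxiv-papers")

    # Cache embeddings on disk so identical chunks are only embedded once.
    underlying = OpenAIEmbeddings()
    store = LocalFileStore("./embedding_cache/")
    embedder = CacheBackedEmbeddings.from_bytes_store(
        underlying, store, namespace=underlying.model
    )
    return embedder, index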


def search_and_index(message: cl.Message, quantity: int, embedder: CacheBackedEmbeddings, index: GRPCIndex) -> None:
    """Search arXiv for papers matching the user's message, then load, chunk, and index them."""
    arxiv_client = arxiv.Client()
    search = arxiv.Search(
        query=message.content,
        max_results=quantity,
        sort_by=arxiv.SortCriterion.Relevance
    )
    paper_urls = [result.pdf_url for result in arxiv_client.results(search)]

    # Load each PDF into Documents (triggered on each user message), skipping
    # papers that fail to download or parse.
    docs = []
    for paper_url in paper_urls:
        try:
            loader = PyPDFLoader(paper_url)
            docs.extend(loader.load())
        except Exception as e:
            print(f"Error loading {paper_url}: {e}")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=400,
        chunk_overlap=30,
        length_function=len
    )
    # Chunk, embed, and upsert the loaded documents into the Pinecone index.
    index_documents(docs, text_splitter, embedder, index)
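

# A minimal Chainlit wiring sketch: on each user message, search arXiv, index
# the top results, and confirm. `build_dependencies` is the hypothetical helper
# sketched above, and `quantity=5` is likewise an assumption. Left commented
# out so importing this module has no side effects with placeholder credentials.
#
# embedder, index = build_dependencies()
#
# @cl.on_message
# async def on_message(message: cl.Message):
#     # search_and_index is synchronous, so run it off the event loop.
#     await cl.make_async(search_and_index)(message, 5, embedder, index)
#     await cl.Message(content="Indexed the top arXiv results for your query.").send()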