from typing import List
from uuid import uuid4

import arxiv
import chainlit as cl
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.documents import Document
from pinecone import GRPCIndex
from tqdm.auto import tqdm

# Maximum number of vectors buffered before each upsert to Pinecone.
BATCH_LIMIT = 100

def index_documents(
        docs: List[List[Document]],
        text_splitter: TextSplitter,
        embedder: CacheBackedEmbeddings,
        index: GRPCIndex) -> None:
    """Chunk, embed, and upsert loaded papers into the Pinecone index.

    `docs` is a list of papers, where each paper is the list of per-page
    Documents returned by PyPDFLoader.load().
    """
    texts = []
    metadatas = []

    for paper in tqdm(docs):
        for doc in paper:
            # Carry the source PDF and page number into every chunk's metadata.
            metadata = {
                'source_document': doc.metadata["source"],
                'page_number': doc.metadata["page"]
            }

            # Split the page into chunks and attach per-chunk metadata.
            record_texts = text_splitter.split_text(doc.page_content)
            record_metadatas = [{
                "chunk": j, "text": text, **metadata
            } for j, text in enumerate(record_texts)]
            texts.extend(record_texts)
            metadatas.extend(record_metadatas)

            # Flush the buffer to Pinecone once it reaches the batch limit.
            if len(texts) >= BATCH_LIMIT:
                ids = [str(uuid4()) for _ in range(len(texts))]
                embeds = embedder.embed_documents(texts)
                index.upsert(vectors=zip(ids, embeds, metadatas))
                texts = []
                metadatas = []

    # Upsert any remaining chunks that did not fill a full batch.
    if len(texts) > 0:
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = embedder.embed_documents(texts)
        index.upsert(vectors=zip(ids, embeds, metadatas))
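
# Note: each upserted record is an (id, embedding, metadata) tuple; the values
# below are illustrative only:
#   ("9f3c…", [0.012, -0.034, ...],
#    {"chunk": 0, "text": "...", "source_document": "<pdf url>", "page_number": 0})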


def search_and_index(message: cl.Message, quantity: int, embedder: CacheBackedEmbeddings, index: GRPCIndex) -> None:
    """Search arXiv for papers matching the user's message and index their PDFs."""

    # Query arXiv for the most relevant papers for the user's message.
    arxiv_client = arxiv.Client()
    search = arxiv.Search(
        query=message.content,
        max_results=quantity,
        sort_by=arxiv.SortCriterion.Relevance
    )
    paper_urls = [result.pdf_url for result in arxiv_client.results(search)]

    # Load each paper's PDF into a list of per-page Documents (on message).
    docs = []
    for paper_url in paper_urls:
        try:
            loader = PyPDFLoader(paper_url)
            docs.append(loader.load())
        except Exception as e:
            print(f"Error loading {paper_url}: {e}")

    # Chunk pages into ~400-character pieces with a 30-character overlap.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=400,
        chunk_overlap=30,
        length_function=len
    )

    # Embed the chunks and upsert them into the Pinecone index (on message).
    index_documents(docs, text_splitter, embedder, index)
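
# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal example of how these helpers could be wired into a Chainlit handler.
# The embedding model, cache path, Pinecone environment, and index name below
# are assumptions, not part of this module; adapt them to your own setup.
#
# import pinecone
# from langchain.embeddings import OpenAIEmbeddings
# from langchain.storage import LocalFileStore
#
# pinecone.init(api_key="YOUR_API_KEY", environment="us-east-1-aws")  # assumed environment
# pinecone_index = pinecone.GRPCIndex("arxiv-papers")                 # assumed index name
#
# core_embeddings = OpenAIEmbeddings()
# store = LocalFileStore("./embedding_cache/")
# cached_embedder = CacheBackedEmbeddings.from_bytes_store(
#     core_embeddings, store, namespace=core_embeddings.model
# )
#
# @cl.on_message
# async def on_message(message: cl.Message):
#     # search_and_index is blocking; fine for a sketch, but consider
#     # cl.make_async for long-running indexing in a real handler.
#     search_and_index(message, quantity=3, embedder=cached_embedder, index=pinecone_index)
#     await cl.Message(content="Indexed the most relevant papers.").send()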