import os
from typing import List
from uuid import uuid4

import arxiv
import pinecone
from tqdm.auto import tqdm
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.vectorstores import Pinecone

INDEX_BATCH_LIMIT = 100


class CharacterTextSplitter:
    """Thin wrapper around RecursiveCharacterTextSplitter with sensible defaults."""

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        assert (
            chunk_size > chunk_overlap
        ), "Chunk size must be greater than chunk overlap"
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,        # character length of each chunk
            chunk_overlap=self.chunk_overlap,  # character overlap between consecutive chunks
            length_function=len,               # measure length in characters (plain len())
        )

    def split(self, text: str) -> List[str]:
        return self.text_splitter.split_text(text)


class ArxivLoader:
    """Searches arXiv for a query and loads the matching papers page by page."""

    def __init__(self, query: str = "Nuclear Fission", max_results: int = 5, encoding: str = "utf-8"):
        self.query = query
        self.max_results = max_results
        self.paper_urls = []
        self.documents = []
        self.splitter = CharacterTextSplitter()

    def retrieve_urls(self):
        """Search arXiv and collect the PDF URL of each matching paper."""
        arxiv_client = arxiv.Client()
        search = arxiv.Search(
            query=self.query,
            max_results=self.max_results,
            sort_by=arxiv.SortCriterion.Relevance
        )
        for result in arxiv_client.results(search):
            self.paper_urls.append(result.pdf_url)

    def load_documents(self):
        """Load each PDF; every entry in self.documents is a list of page-level Documents."""
        for paper_url in self.paper_urls:
            loader = PyPDFLoader(paper_url)
            self.documents.append(loader.load())

    def format_document(self, document):
        """Split a single page into chunks and attach per-chunk metadata."""
        metadata = {
            'source_document': document.metadata["source"],
            'page_number': document.metadata["page"]
        }
        record_texts = self.splitter.split(document.page_content)
        record_metadatas = [{
            "chunk": j, "text": text, **metadata
        } for j, text in enumerate(record_texts)]
        return record_texts, record_metadatas

    def main(self):
        """Convenience entry point: fetch the paper URLs, then load the documents."""
        self.retrieve_urls()
        self.load_documents()


class PineconeIndexer:
    """Embeds document chunks and upserts them into a Pinecone index."""

    def __init__(self, index_name: str = "arxiv-paper-index", metric: str = "cosine", n_dims: int = 1536):
        pinecone.init(
            api_key=os.environ["PINECONE_API_KEY"],
            environment=os.environ["PINECONE_ENV"]
        )
        if index_name not in pinecone.list_indexes():
            # create a new index on first use
            pinecone.create_index(
                name=index_name,
                metric=metric,
                dimension=n_dims
            )
        self.index = pinecone.Index(index_name)
        self.arxiv_loader = ArxivLoader()

    def load_embedder(self):
        """Wrap OpenAI embeddings in a local file-backed cache so identical text is not re-embedded."""
        store = LocalFileStore("./cache/")
        core_embeddings_model = OpenAIEmbeddings()
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            core_embeddings_model,
            store,
            namespace=core_embeddings_model.model
        )

    def upsert(self, texts, metadatas):
        """Embed a batch of texts and upsert (id, vector, metadata) tuples into the index."""
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = self.embedder.embed_documents(texts)
        self.index.upsert(vectors=list(zip(ids, embeds, metadatas)))

    def index_documents(self, documents, batch_limit: int = INDEX_BATCH_LIMIT):
        """Chunk every page of every document and upsert in batches of `batch_limit`."""
        texts = []
        metadatas = []
        # iterate through the top-level documents (one list of pages per paper)
        for i in tqdm(range(len(documents))):
            # each document is a list of page-level Document objects
            for page in documents[i]:
                record_texts, record_metadatas = self.arxiv_loader.format_document(page)
                texts.extend(record_texts)
                metadatas.extend(record_metadatas)
                if len(texts) >= batch_limit:
                    self.upsert(texts, metadatas)
                    texts = []
                    metadatas = []
        # flush whatever remains after the final batch
        if len(texts) > 0:
            self.upsert(texts, metadatas)

    def get_vectorstore(self):
        """Expose the populated index as a LangChain Pinecone vector store keyed on the "text" metadata field."""
        return Pinecone(self.index, self.embedder.embed_query, "text")


if __name__ == "__main__":
    print("-------------- Loading Arxiv --------------")
    axloader = ArxivLoader()
    axloader.retrieve_urls()
    axloader.load_documents()

    print("\n-------------- Splitting sample doc --------------")
    sample_doc = axloader.documents[0]
    sample_page = sample_doc[0]
    splitter = CharacterTextSplitter()
    chunks = splitter.split(sample_page.page_content)
    print(len(chunks))
    print(chunks[0])

    print("\n-------------- Testing Pinecone indexer --------------")
    pi = PineconeIndexer()
    pi.load_embedder()
    pi.index_documents(axloader.documents)
    print(pi.index.describe_index_stats())
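
    # Illustrative retrieval sketch (not part of the original script): once the
    # documents are indexed, the same index can be queried through the LangChain
    # vector store returned by get_vectorstore(). The query string is only an example.
    print("\n-------------- Sample similarity search --------------")
    vectorstore = pi.get_vectorstore()
    results = vectorstore.similarity_search("How does nuclear fission work?", k=3)
    for doc in results:
        print(doc.metadata.get("source_document"), "-", doc.page_content[:100])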