# raqa-arxiv-app / app.py
# garyg-ai's picture
# synced
# aa669cb
import asyncio
import openai
import chainlit as cl # importing chainlit for our app
from chainlit.prompt import Prompt, PromptMessage # importing prompt tools
import os
import getpass
from dotenv import load_dotenv
# Merge variables from a local .env file (if any) into the process environment.
load_dotenv()
# Set the Pinecone environment name. This runs after load_dotenv(), so it
# replaces any PINECONE_ENV value that was already present.
os.environ.update(PINECONE_ENV="gcp-starter")
import arxiv
# arXiv API client. In this file it is referenced only by the disabled search
# block below.
arxiv_client = arxiv.Client()
# PDF URLs of the papers to ingest; stays empty while the search is disabled.
paper_urls = []
# NOTE: the quoted block below is a bare string literal, not executed code.
# It opens with four quotes, so it is a triple-quoted string whose first
# character is a literal "'". Its text is the one-off search step: query arXiv
# for up to 5 "Retrieval Augmented Generation" papers sorted by relevance,
# collect each result's pdf_url into paper_urls, and print the list.
''''
search = arxiv.Search(
query = "Retrieval Augmented Generation",
max_results = 5,
sort_by = arxiv.SortCriterion.Relevance
)
for result in arxiv_client.results(search):
paper_urls.append(result.pdf_url)
print(paper_urls)
'''
from langchain.document_loaders import PyPDFLoader
# Loaded documents, one entry per paper; stays empty while loading is disabled.
docs = []
# NOTE: bare string literal (opened with four quotes), not executed code.
# Its text is the one-off loading step: run PyPDFLoader on every URL in
# paper_urls, append each loader.load() result to docs, then print docs[0][6]
# as a spot check.
''''
for paper_url in paper_urls:
loader = PyPDFLoader(paper_url)
docs.append(loader.load())
print(docs[0][6])
'''
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Chunking policy for page text: pieces of at most 1000 units with 100 units
# of overlap, where size is measured with len() (i.e. in characters).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    length_function=len,
)
import pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
# Pinecone credentials are read from the environment; a missing variable
# raises KeyError here, at import time.
YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]
index_name = 'arxiv-paper-index2'

pinecone.init(api_key=YOUR_API_KEY, environment=YOUR_ENV)

# Create the index only when it does not exist yet.
# NOTE(review): dimension=1536 presumably matches the vector size produced by
# the OpenAI embedding model used elsewhere in this file - confirm.
if index_name not in pinecone.list_indexes():
    pinecone.create_index(name=index_name, metric='cosine', dimension=1536)

# gRPC handle to the index; within this file only the disabled ingestion
# block refers to it before `index` is rebound further down.
index = pinecone.GRPCIndex(index_name)
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
# File-backed byte store rooted at ./cache/, used as the embedding cache.
store = LocalFileStore("./cache/")
# Underlying embedding model, constructed with its default settings.
core_embeddings_model = OpenAIEmbeddings()
# Cache-backed wrapper around the model; cache entries are namespaced by the
# model's name so different models do not share entries.
embedder = CacheBackedEmbeddings.from_bytes_store(
    core_embeddings_model, store, namespace=core_embeddings_model.model
)
from tqdm.auto import tqdm
from uuid import uuid4
# Number of buffered chunks that triggers an upsert in the disabled ingestion
# block below.
BATCH_LIMIT = 100
# Buffers for chunk texts and their metadata; they stay empty while the
# ingestion block is disabled.
texts = []
metadatas = []
# NOTE: bare string literal (opened with four quotes), not executed code.
# Its text is the one-off ingestion step: for every page of every loaded
# paper, split the page text with text_splitter, attach source/page/chunk
# metadata (including the chunk text under the "text" key), and once at least
# BATCH_LIMIT chunks are buffered, embed them and upsert (id, vector,
# metadata) tuples into the index with fresh uuid4 ids. A final flush handles
# whatever remains in the buffers.
''''
for i in tqdm(range(len(docs))):
for doc in docs[i]:
metadata = {
'source_document' : doc.metadata["source"],
'page_number' : doc.metadata["page"]
}
record_texts = text_splitter.split_text(doc.page_content)
record_metadatas = [{
"chunk": j, "text": text, **metadata
} for j, text in enumerate(record_texts)]
texts.extend(record_texts)
metadatas.extend(record_metadatas)
if len(texts) >= BATCH_LIMIT:
ids = [str(uuid4()) for _ in range(len(texts))]
embeds = embedder.embed_documents(texts)
index.upsert(vectors=zip(ids, embeds, metadatas))
texts = []
metadatas = []
if len(texts) > 0:
ids = [str(uuid4()) for _ in range(len(texts))]
embeds = embedder.embed_documents(texts)
index.upsert(vectors=zip(ids, embeds, metadatas))
'''
from langchain.vectorstores import Pinecone
# Metadata key under which each chunk's text is stored (the disabled
# ingestion code writes the chunk text under "text").
text_field = "text"
# Rebind `index` to a regular (non-gRPC) handle for the same index name.
index = pinecone.Index(index_name)
# LangChain vector store over that index; queries are embedded with the
# cache-backed embedder's embed_query.
vectorstore = Pinecone(index, embedder.embed_query, text_field)
# NOTE: the two quoted blocks below are bare string literals (each opened
# with four quotes), not executed code. They hold a disabled manual smoke
# test: define a sample query, then run a k=3 similarity search against
# vectorstore.
''''
query = "What is dense vector retrieval?"
'''
''''
vectorstore.similarity_search(
query,
k=3
)
'''
from langchain.chat_models import ChatOpenAI
# Chat model that writes the final answer; temperature 0 keeps sampling as
# deterministic as the API allows.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
from langchain.prompts import ChatPromptTemplate
# Prompt text for the QA step. {context} and {question} are placeholders; the
# template tells the model to answer from the supplied context only and to
# decline otherwise. This is runtime text sent to the model - edit with care.
system_template = """Answer the following question with the provided context only. If you aren't able to get the answer from the provided context only, then please don't answer the question.
### CONTEXT
{context}
###QUESTION
{question}
"""
# Retriever view over the Pinecone vector store, created with no arguments
# (library default search settings).
retriever = vectorstore.as_retriever()
from langchain.prompts import ChatPromptTemplate
# Wrap system_template in a ChatPromptTemplate so it can be piped into the
# LLM inside the chain.
prompt = ChatPromptTemplate.from_template(system_template)
from operator import itemgetter
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.schema import format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.prompts.prompt import PromptTemplate
# Stage 1: from an input dict {"question": ...}, produce the question itself
# plus the documents the retriever returns for it.
_question_with_context = {
    "context": itemgetter("question") | retriever,
    "question": itemgetter("question"),
}

# Stage 2: pass everything through, re-assigning "context" to its own value.
_carry_context = RunnablePassthrough.assign(context=itemgetter("context"))

# Stage 3: run prompt -> LLM for the answer and keep the retrieved documents
# next to it so callers can inspect the sources.
_answer_with_sources = {
    "response": prompt | llm,
    "context": itemgetter("context"),
}

# Full pipeline; invoking it yields {"response": <LLM output>, "context": <documents>}.
retrieval_augmented_qa_chain = (
    _question_with_context | _carry_context | _answer_with_sources
)
import langchain
from langchain.cache import InMemoryCache
from langchain.globals import set_llm_cache
# Register an InMemoryCache as LangChain's global LLM cache for this process.
set_llm_cache(InMemoryCache())
@cl.on_chat_start
async def on_chat_start():
    """Print a startup marker to stdout; registered as Chainlit's chat-start hook."""
    print("starting up")
@cl.on_message
async def on_message(message: cl.Message):
    """Answer an incoming chat message with the retrieval-augmented QA chain.

    Args:
        message: The user's Chainlit message; its ``content`` is used as the
            question.

    The chain returns a dict with the LLM output under ``"response"`` and the
    retrieved documents under ``"context"``. Only the answer text is sent
    back; passing the whole dict as message content would show the user a
    stringified dict instead of the answer.
    """
    # Use the async entry point so the model/retrieval round-trip does not
    # block the event loop that serves every other chat session.
    result = await retrieval_augmented_qa_chain.ainvoke(
        {"question": message.content}
    )
    await cl.Message(content=result["response"].content).send()