"""Chainlit app: retrieval-augmented QA over arXiv papers stored in Pinecone.

On import this module loads ``.env``, connects to Pinecone (creating the
index if it does not exist), builds a LangChain retrieval chain around
``gpt-3.5-turbo`` and registers the Chainlit chat handlers.

Populating the index is a separate, explicit step: call
``ingest_arxiv_papers()`` once (e.g. from a REPL) before chatting.
"""
# NOTE(review): several imports below (asyncio, getpass, openai, Prompt,
# PromptMessage, OpenApiConfiguration, format_document, StrOutputParser,
# PromptTemplate, RunnableLambda, langchain) are unused here; kept in case
# other code relies on them.
import asyncio
import getpass
import os
from operator import itemgetter
from uuid import uuid4

import arxiv
import chainlit as cl  # importing chainlit for our app
import langchain
import openai
import pinecone
from chainlit.prompt import Prompt, PromptMessage  # importing prompt tools
from dotenv import load_dotenv
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.globals import set_llm_cache
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
from tqdm.auto import tqdm

load_dotenv()
# Respect a PINECONE_ENV supplied by the shell or .env; fall back to the
# free-tier "gcp-starter" environment only when none is configured.
os.environ.setdefault("PINECONE_ENV", "gcp-starter")

arxiv_client = arxiv.Client()
# Module-level containers kept for backward compatibility; ingestion below
# works on local lists instead of mutating these.
paper_urls = []
docs = []

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    length_function=len,
)

YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]
index_name = "arxiv-paper-index2"

pinecone.init(api_key=YOUR_API_KEY, environment=YOUR_ENV)

if index_name not in pinecone.list_indexes():
    # we create a new index
    # NOTE(review): 1536 presumably matches the default OpenAIEmbeddings
    # model's vector size -- confirm if the embedding model changes.
    pinecone.create_index(name=index_name, metric="cosine", dimension=1536)

# gRPC client, used for bulk upserts during ingestion.
grpc_index = pinecone.GRPCIndex(index_name)

store = LocalFileStore("./cache/")
core_embeddings_model = OpenAIEmbeddings()
# Embeddings are cached on disk under a namespace equal to the model name,
# so vectors from different models do not collide in the cache.
embedder = CacheBackedEmbeddings.from_bytes_store(
    core_embeddings_model, store, namespace=core_embeddings_model.model
)

BATCH_LIMIT = 100
texts = []
metadatas = []


def _upsert_batch(batch_texts, batch_metadatas):
    """Embed one batch of text chunks and upsert them with fresh UUIDs."""
    ids = [str(uuid4()) for _ in batch_texts]
    embeds = embedder.embed_documents(batch_texts)
    # Materialize the zip so the client receives a concrete sequence.
    grpc_index.upsert(vectors=list(zip(ids, embeds, batch_metadatas)))


def ingest_arxiv_papers(query="Retrieval Augmented Generation", max_results=5):
    """Download arXiv PDFs matching *query*, chunk, embed and upsert them.

    Not run at import time. Every call upserts vectors under new random IDs,
    so re-running the same query duplicates content in the index.

    Args:
        query: arXiv search string.
        max_results: maximum number of papers to fetch (sorted by relevance).

    Returns:
        The list of PDF URLs that were ingested.
    """
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    urls = [result.pdf_url for result in arxiv_client.results(search)]

    batch_texts, batch_metadatas = [], []
    for url in tqdm(urls):
        for page in PyPDFLoader(url).load():
            metadata = {
                "source_document": page.metadata["source"],
                "page_number": page.metadata["page"],
            }
            chunks = text_splitter.split_text(page.page_content)
            batch_texts.extend(chunks)
            # The chunk text is stored under "text" so the LangChain
            # vectorstore (text_field below) can return it at query time.
            batch_metadatas.extend(
                {"chunk": j, "text": chunk, **metadata}
                for j, chunk in enumerate(chunks)
            )
            if len(batch_texts) >= BATCH_LIMIT:
                _upsert_batch(batch_texts, batch_metadatas)
                batch_texts, batch_metadatas = [], []

    if batch_texts:
        _upsert_batch(batch_texts, batch_metadatas)
    return urls


# Metadata key holding each chunk's raw text (written by ingest_arxiv_papers).
text_field = "text"
index = pinecone.Index(index_name)
vectorstore = Pinecone(index, embedder.embed_query, text_field)
# Manual smoke test:
#   vectorstore.similarity_search("What is dense vector retrieval?", k=3)

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

system_template = """Answer the following question with the provided context only. If you aren't able to get the answer from the provided context only, then please don't answer the question.

### CONTEXT
{context}

###QUESTION
{question}
"""

retriever = vectorstore.as_retriever()
prompt = ChatPromptTemplate.from_template(system_template)

# Input: {"question": str}
# Output: {"response": <prompt | llm output>, "context": <retrieved documents>}
# The first mapping is a plain dict. Piping it into RunnablePassthrough.assign
# (which merely re-assigns "context" to itself) is what makes `|` build an
# LCEL pipeline; `dict | dict` would instead be a Python dict merge.
retrieval_augmented_qa_chain = (
    {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {
        "response": prompt | llm,
        "context": itemgetter("context"),
    }
)

set_llm_cache(InMemoryCache())


@cl.on_chat_start
async def on_chat_start():
    """Log that a new chat session has started."""
    print("starting up")


@cl.on_message
async def on_message(message: cl.Message):
    """Answer the user's message with the RAG chain and send back the answer."""
    # ainvoke awaits retrieval and the LLM call; the synchronous invoke()
    # would block the event loop for the whole round-trip.
    result = await retrieval_augmented_qa_chain.ainvoke({"question": message.content})
    # Send only the model's answer text, not the repr of the output dict.
    await cl.Message(content=result["response"].content).send()