# Spaces:
# Runtime error
# Runtime error
import asyncio | |
import openai | |
import chainlit as cl # importing chainlit for our app | |
from chainlit.prompt import Prompt, PromptMessage # importing prompt tools | |
import os | |
import getpass | |
from dotenv import load_dotenv | |
# Load variables from a .env file into the process environment.
load_dotenv()
# Pin the Pinecone environment name; this assignment overwrites whatever
# value was already present (including one loaded from .env).
os.environ["PINECONE_ENV"] = "gcp-starter"

import arxiv

arxiv_client = arxiv.Client()

# PDF URLs of the papers to index. Stays empty while the arXiv search below
# is disabled.
paper_urls = []

# Disabled step (previously parked in a bare string literal): search arXiv
# and collect the PDF link of each result.
#
# search = arxiv.Search(
#     query="Retrieval Augmented Generation",
#     max_results=5,
#     sort_by=arxiv.SortCriterion.Relevance,
# )
# for result in arxiv_client.results(search):
#     paper_urls.append(result.pdf_url)
# print(paper_urls)
from langchain.document_loaders import PyPDFLoader

# One entry per paper, each entry being whatever PyPDFLoader.load() returned
# for that paper. Stays empty while the loading step below is disabled.
docs = []

# Disabled step (previously parked in a bare string literal): download and
# parse every PDF collected in paper_urls.
#
# for paper_url in paper_urls:
#     loader = PyPDFLoader(paper_url)
#     docs.append(loader.load())
# print(docs[0][6])
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Splitter configuration: chunk length is measured with len(), chunks are
# capped at 1000 and consecutive chunks overlap by 100.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    length_function=len,
)
import pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration

# Connection settings are taken from the process environment; a missing
# variable raises KeyError right here.
YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]

index_name = 'arxiv-paper-index2'

pinecone.init(api_key=YOUR_API_KEY, environment=YOUR_ENV)

# Create the index only when no index with this name exists yet.
# dimension=1536 presumably matches the embedding model's vector size --
# TODO confirm against the embeddings in use.
if index_name not in pinecone.list_indexes():
    pinecone.create_index(name=index_name, metric='cosine', dimension=1536)

# NOTE(review): `index` is rebound to pinecone.Index(index_name) further down,
# before the vector store is built; this gRPC handle is only referenced by the
# disabled upsert code.
index = pinecone.GRPCIndex(index_name)
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore

# File-backed byte store rooted at ./cache/ that holds the cached embeddings.
store = LocalFileStore("./cache/")

core_embeddings_model = OpenAIEmbeddings()

# Wrap the OpenAI embedder with the file cache; cache entries are namespaced
# by the embedding model's name.
embedder = CacheBackedEmbeddings.from_bytes_store(
    core_embeddings_model,
    store,
    namespace=core_embeddings_model.model,
)
from tqdm.auto import tqdm
from uuid import uuid4

# Number of pending chunks at which the disabled ingestion loop below flushes
# a batch to the index.
BATCH_LIMIT = 100

# Pending chunk texts and their metadata, filled by the ingestion loop.
texts = []
metadatas = []

# Disabled step (previously parked in a bare string literal): split every
# page into chunks, embed them and upsert them into the Pinecone index in
# batches of at least BATCH_LIMIT chunks, then flush whatever is left.
#
# for i in tqdm(range(len(docs))):
#     for doc in docs[i]:
#         metadata = {
#             'source_document': doc.metadata["source"],
#             'page_number': doc.metadata["page"],
#         }
#         record_texts = text_splitter.split_text(doc.page_content)
#         record_metadatas = [{
#             "chunk": j, "text": text, **metadata
#         } for j, text in enumerate(record_texts)]
#         texts.extend(record_texts)
#         metadatas.extend(record_metadatas)
#         if len(texts) >= BATCH_LIMIT:
#             ids = [str(uuid4()) for _ in range(len(texts))]
#             embeds = embedder.embed_documents(texts)
#             index.upsert(vectors=zip(ids, embeds, metadatas))
#             texts = []
#             metadatas = []
# if len(texts) > 0:
#     ids = [str(uuid4()) for _ in range(len(texts))]
#     embeds = embedder.embed_documents(texts)
#     index.upsert(vectors=zip(ids, embeds, metadatas))
from langchain.vectorstores import Pinecone

# Metadata key under which the chunk text is stored (the disabled ingestion
# code writes each chunk's text to the "text" metadata field).
text_field = "text"

# Rebind `index` to a plain (non-gRPC) handle for the vector store.
index = pinecone.Index(index_name)

vectorstore = Pinecone(index, embedder.embed_query, text_field)

# Disabled smoke test (previously parked in bare string literals): run one
# similarity search against the index.
#
# query = "What is dense vector retrieval?"
# vectorstore.similarity_search(
#     query,
#     k=3
# )
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

# Chat model used to produce answers; temperature is fixed at 0.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# Retriever view over the Pinecone vector store, using its default settings.
retriever = vectorstore.as_retriever()

# Prompt text with two placeholders: {context} (the retrieved documents) and
# {question} (the user's question).
system_template = """Answer the following question with the provided context only. If you aren't able to get the answer from the provided context only, then please don't answer the question.
### CONTEXT
{context}
###QUESTION
{question}
"""

prompt = ChatPromptTemplate.from_template(system_template)
from operator import itemgetter
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.schema import format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.prompts.prompt import PromptTemplate

# Stage 1: from {"question": ...} build {"context": <retrieved documents>,
# "question": <the same question>}.
_gather_inputs = {
    "context": itemgetter("question") | retriever,
    "question": itemgetter("question"),
}

# Stage 2: re-assign "context" to itself while passing the other keys through.
_carry_context = RunnablePassthrough.assign(context=itemgetter("context"))

# Stage 3: run prompt -> llm for the answer and keep the retrieved context
# alongside it in the output mapping.
_answer_with_context = {
    "response": prompt | llm,
    "context": itemgetter("context"),
}

# Full pipeline; its output is a mapping with "response" and "context" keys.
retrieval_augmented_qa_chain = (
    _gather_inputs | _carry_context | _answer_with_context
)
import langchain
from langchain.cache import InMemoryCache
from langchain.globals import set_llm_cache

# Install an in-memory cache as LangChain's global LLM cache.
_llm_cache = InMemoryCache()
set_llm_cache(_llm_cache)
@cl.on_chat_start  # register with Chainlit; without this the hook is never called
async def on_chat_start():
    """Log a start-up line when a new chat session begins."""
    print("starting up")
@cl.on_message  # register with Chainlit; without this the hook is never called
async def on_message(message: cl.Message):
    """Answer an incoming chat message with the retrieval-augmented QA chain.

    Args:
        message: The user's message; its ``content`` is used as the question.
    """
    # Use the async entry point so the retrieval and LLM calls do not block
    # the event loop while this handler is running.
    result = await retrieval_augmented_qa_chain.ainvoke(
        {"question": message.content}
    )
    # The chain returns a mapping with "response" (the chat model's message)
    # and "context" (the retrieved documents). Only the answer text is sent
    # back, not the whole mapping.
    answer = result["response"].content
    await cl.Message(content=answer).send()