Spaces:
Runtime error
Runtime error
# --- Application bootstrap: imports and environment configuration ---
import asyncio
import getpass
import os

import chainlit as cl  # importing chainlit for our app
import openai
from chainlit.prompt import Prompt, PromptMessage  # importing prompt tools
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment so the
# clients configured further down can read their credentials from os.environ.
load_dotenv()

# Pin the Pinecone environment. This assignment runs after load_dotenv(), so
# it replaces any PINECONE_ENV value that was already present.
os.environ["PINECONE_ENV"] = "gcp-starter"
# --- Paper discovery and PDF loading --------------------------------------
# Only the client and the two empty lists below are live. The search and
# download steps are kept as comments for reference (presumably they were
# run once to populate the index -- confirm before re-enabling).
import arxiv
from langchain.document_loaders import PyPDFLoader

arxiv_client = arxiv.Client()

# PDF URLs of the papers to load; stays empty while the search is disabled.
paper_urls = []

# Disabled: query arXiv for the five most relevant papers on the topic and
# collect their PDF URLs.
#
# search = arxiv.Search(
#     query="Retrieval Augmented Generation",
#     max_results=5,
#     sort_by=arxiv.SortCriterion.Relevance,
# )
# for result in arxiv_client.results(search):
#     paper_urls.append(result.pdf_url)
# print(paper_urls)

# One entry per paper, each entry being whatever PyPDFLoader.load() returns;
# stays empty while the loading step is disabled.
docs = []

# Disabled: load every paper found above.
#
# for paper_url in paper_urls:
#     loader = PyPDFLoader(paper_url)
#     docs.append(loader.load())
# print(docs[0][6])
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Chunking parameters: pieces of up to 1000 units with a 100-unit overlap,
# where size is measured with the built-in len (i.e. in characters).
_SPLITTER_OPTIONS = {
    "chunk_size": 1000,
    "chunk_overlap": 100,
    "length_function": len,
}

# Splitter used to break page text into chunks.
text_splitter = RecursiveCharacterTextSplitter(**_SPLITTER_OPTIONS)
# --- Pinecone connection and index bootstrap ------------------------------
import pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration

# Both lookups raise KeyError when the variable is missing from the
# environment, which stops start-up before any network call is made.
YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]

index_name = "arxiv-paper-index2"

pinecone.init(environment=YOUR_ENV, api_key=YOUR_API_KEY)

# Create the index only on first use; an existing index is left untouched.
_index_missing = index_name not in pinecone.list_indexes()
if _index_missing:
    # dimension=1536 presumably matches the size of the OpenAI embeddings
    # produced further down -- confirm if the embedding model changes.
    pinecone.create_index(
        name=index_name,
        dimension=1536,
        metric="cosine",
    )

# gRPC handle to the index.
index = pinecone.GRPCIndex(index_name)
# --- Embedding model with an on-disk cache --------------------------------
from langchain.embeddings import CacheBackedEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.storage import LocalFileStore

# File-backed byte store rooted at ./cache/ that holds the cached embeddings.
store = LocalFileStore("./cache/")

# The OpenAI embedding model that does the actual work on a cache miss.
core_embeddings_model = OpenAIEmbeddings()

# Wrap the model with the cache. The model name is used as the cache
# namespace, so entries written for one model are kept apart from another's.
embedder = CacheBackedEmbeddings.from_bytes_store(
    core_embeddings_model,
    store,
    namespace=core_embeddings_model.model,
)
# --- Index ingestion (disabled) -------------------------------------------
# Only the constant and the two empty buffers below are live. The ingestion
# loop is kept as comments for reference (presumably it was run once to fill
# the index -- confirm before re-enabling).
from uuid import uuid4

from tqdm.auto import tqdm

# Number of buffered chunks that triggers an embed-and-upsert flush.
BATCH_LIMIT = 100

# Buffers for chunk texts and their metadata, filled by the loop below.
texts = []
metadatas = []

# Disabled: split every loaded page into chunks, embed them in batches and
# upsert (id, embedding, metadata) triples into the index.
#
# for i in tqdm(range(len(docs))):
#     for doc in docs[i]:
#         metadata = {
#             'source_document': doc.metadata["source"],
#             'page_number': doc.metadata["page"],
#         }
#         record_texts = text_splitter.split_text(doc.page_content)
#         record_metadatas = [{
#             "chunk": j, "text": text, **metadata
#         } for j, text in enumerate(record_texts)]
#         texts.extend(record_texts)
#         metadatas.extend(record_metadatas)
#         if len(texts) >= BATCH_LIMIT:
#             ids = [str(uuid4()) for _ in range(len(texts))]
#             embeds = embedder.embed_documents(texts)
#             index.upsert(vectors=zip(ids, embeds, metadatas))
#             texts = []
#             metadatas = []
#
# # Flush whatever is left after the last full batch.
# if len(texts) > 0:
#     ids = [str(uuid4()) for _ in range(len(texts))]
#     embeds = embedder.embed_documents(texts)
#     index.upsert(vectors=zip(ids, embeds, metadatas))
# --- LangChain vector store over the Pinecone index -----------------------
from langchain.vectorstores import Pinecone

# Metadata key under which each record's chunk text is stored.
text_field = "text"

# Rebind `index` to a regular Pinecone index handle for the vector store.
index = pinecone.Index(index_name)

# Queries are embedded through the cached embedder before the lookup.
vectorstore = Pinecone(
    index,
    embedder.embed_query,
    text_field,
)

# Disabled: manual smoke test of the vector store, kept for reference.
#
# query = "What is dense vector retrieval?"
#
# vectorstore.similarity_search(
#     query,
#     k=3,
# )
# --- Chat model, retriever and prompt --------------------------------------
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

# temperature=0 requests the least random sampling from the model.
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")

# Prompt text with two placeholders: {context} for the retrieved material
# and {question} for the user's question.
system_template = """Answer the following question with the provided context only in the voice of hulk hogan. If you aren't able to get the answer from that, then please don't answer the question.
### CONTEXT
{context}
###QUESTION
{question}
"""

# Retriever view of the vector store, created with its default settings.
retriever = vectorstore.as_retriever()

prompt = ChatPromptTemplate.from_template(system_template)
# --- Retrieval-augmented QA chain ------------------------------------------
from operator import itemgetter

from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

# Stage 1: from the input {"question": ...}, fetch documents for the question
# through the retriever and pass the question itself along unchanged.
_gather_inputs = {
    "context": itemgetter("question") | retriever,
    "question": itemgetter("question"),
}

# Stage 2: re-assign "context" to the value it already has.
_carry_context = RunnablePassthrough.assign(context=itemgetter("context"))

# Stage 3: produce the model's reply and keep the retrieved context next to
# it, so the chain's result is a dict with "response" and "context" keys.
_answer_with_context = {
    "response": prompt | llm,
    "context": itemgetter("context"),
}

retrieval_augmented_qa_chain = (
    _gather_inputs | _carry_context | _answer_with_context
)
# --- LLM response cache ----------------------------------------------------
import langchain
from langchain.cache import InMemoryCache
from langchain.globals import set_llm_cache

# Install a process-local, in-memory cache as LangChain's global LLM cache.
_llm_cache = InMemoryCache()
set_llm_cache(_llm_cache)
@cl.on_chat_start  # register with Chainlit; without this the hook never runs
async def on_chat_start():
    """Chainlit hook for the start of a chat: print a start-up marker."""
    print("starting up")
@cl.on_message  # register with Chainlit; without this the handler never runs
async def on_message(message: cl.Message):
    """Answer an incoming chat message with the retrieval-augmented chain.

    Args:
        message: The incoming Chainlit message; its ``content`` is used as
            the question.
    """
    # The chain is synchronous and makes network calls, so run it in a
    # worker thread rather than blocking the event loop.
    run_chain = cl.make_async(retrieval_augmented_qa_chain.invoke)
    result = await run_chain({"question": message.content})

    # The chain returns {"response": <chat model message>, "context": [...]}.
    # Only the reply text belongs in the chat, not the whole dict.
    answer = result["response"].content
    await cl.Message(content=answer).send()