Spaces:
Runtime error
Runtime error
import os | |
from typing import List | |
from langchain.document_loaders import PyPDFLoader, TextLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.vectorstores.pinecone import Pinecone | |
from langchain.chains import RetrievalQA | |
from langchain.chat_models import ChatOpenAI | |
from langchain.memory import ChatMessageHistory, ConversationBufferMemory | |
from langchain.docstore.document import Document | |
import pinecone | |
import chainlit as cl | |
from chainlit.types import AskFileResponse | |
from langchain.prompts import PromptTemplate | |
from dotenv import load_dotenv | |
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Credentials are read from the environment so that no secret is committed to
# source control. Set OPENAI_API_KEY and PINECONE_API_KEY in the environment or
# in the .env file.
openai_api_key = os.getenv("OPENAI_API_KEY")
pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),
    environment=os.getenv("PINECONE_ENVIRONMENT", "gcp-starter"),
)

# Name of the Pinecone index that stores the embedded document chunks.
index_name = "skandhaar"

# Shared splitter and embedding model used for every uploaded file.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Keys of the files already indexed by this process; filled by get_docsearch().
namespaces = set()

# First message shown to the user, asking for a file upload.
welcome_message = """Welcome to the Chainlit PDF QA demo! To get started:
1. Upload a PDF or text file
"""
def process_file(file: AskFileResponse):
    """Load an uploaded text or PDF file and split it into chunks.

    The uploaded bytes are written to a temporary file because the loaders
    are constructed from a file path. The temporary file is closed before it
    is read (so all bytes are on disk) and removed once loading is done.

    Args:
        file: The upload; its ``type`` (MIME type) and ``content`` (bytes)
            attributes are used.

    Returns:
        The chunks produced by ``text_splitter``, each with
        ``metadata["source"]`` set to ``"source_<index>"``.

    Raises:
        ValueError: If the MIME type is neither ``text/plain`` nor
            ``application/pdf``.
    """
    import tempfile

    loaders = {"text/plain": TextLoader, "application/pdf": PyPDFLoader}
    try:
        loader_cls = loaders[file.type]
    except KeyError:
        raise ValueError(f"Unsupported file type: {file.type!r}") from None

    # delete=False keeps the file on disk after the handle is closed, so the
    # loader can reopen it by path; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
        tmp.write(file.content)
        tmp_path = tmp.name

    try:
        documents = loader_cls(tmp_path).load()
    finally:
        os.remove(tmp_path)

    docs = text_splitter.split_documents(documents)
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
    return docs
def get_docsearch(file: AskFileResponse):
    """Return a Pinecone vector store for the uploaded file.

    The file is split into chunks, which are also stored in the user session
    under ``"docs"``. A key derived from the file content records, in the
    module-level ``namespaces`` set, which files this process has already
    embedded; the key itself is not passed to Pinecone.

    Args:
        file: The uploaded file.

    Returns:
        A store built from the chunks the first time a file is seen,
        otherwise a store attached to the existing index.
    """
    chunks = process_file(file)

    # Keep the chunks available to the rest of the session.
    cl.user_session.set("docs", chunks)

    file_key = str(hash(file.content))

    # First time this content is seen: embed and upload the chunks.
    if file_key not in namespaces:
        store = Pinecone.from_documents(chunks, embeddings, index_name=index_name)
        namespaces.add(file_key)
        return store

    # Already uploaded by this process: attach to the existing index.
    return Pinecone.from_existing_index(index_name=index_name, embedding=embeddings)
# Prompt that tells the model to answer only from the retrieved context.
_QA_PROMPT_TEMPLATE = """Your name is Skandhaar docchat and you are working for Skandhaar org. and your job is to answer the user question from the given context. You are not allowed make an answer and create something that's not there in the context. You strictly follow the context and give extractive answers.
Respond for user greetings. If you encounter with out of context questions reply with I'm here to help you with given knowledge source, i can't assist with that.
context:{context}
question:{question}
Answer in the Markdown.
"""


def _build_qa_chain(docsearch):
    """Build the RetrievalQA chain that answers questions from ``docsearch``."""
    prompt = PromptTemplate(
        template=_QA_PROMPT_TEMPLATE,
        input_variables=["context", "question"],
    )
    return RetrievalQA.from_chain_type(
        ChatOpenAI(
            model_name="gpt-3.5-turbo",
            temperature=0,
            streaming=True,
            openai_api_key=openai_api_key,
        ),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )


@cl.on_chat_start
async def start():
    """Set up a chat session: ask for a file, index it, and store the QA chain.

    Registered as the Chainlit chat-start hook. The user is asked for a text
    or PDF file until one is provided; each file is indexed and a
    RetrievalQA chain is stored in the user session under ``"chain"``.
    """
    await cl.Avatar(
        name="Chatbot",
        url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4",
    ).send()

    # The ask returns None when it times out, so keep asking.
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content=welcome_message,
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
            disable_human_feedback=True,
        ).send()

    for file in files:
        msg = cl.Message(
            content=f"Processing `{file.name}`...", disable_human_feedback=True
        )
        await msg.send()

        # No async implementation in the Pinecone client, fallback to sync
        docsearch = await cl.make_async(get_docsearch)(file)
        chain = _build_qa_chain(docsearch)

        # Let the user know that the system is ready
        msg.content = f"`{file.name}` processed. You can now ask questions!"
        await msg.update()

        cl.user_session.set("chain", chain)
@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with the session's QA chain and list its sources.

    Registered as the Chainlit message hook. The chain stored under
    ``"chain"`` in the user session is called with the message text; the
    reply is the chain's ``"result"`` followed by the names of the source
    chunks, which are attached to the reply as text elements.

    Args:
        message: The incoming user message; only ``content`` is used.
    """
    chain = cl.user_session.get("chain")  # RetrievalQA chain stored by start()
    if chain is None:
        # No file has been processed for this session yet.
        await cl.Message(
            content="Please upload a file before asking questions."
        ).send()
        return

    cb = cl.AsyncLangchainCallbackHandler()
    res = await chain.acall(message.content, callbacks=[cb])

    answer = res["result"]
    source_documents = res["source_documents"] or []  # type: List[Document]

    # One text element per retrieved chunk, named by its position in the result.
    text_elements = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents)
    ]  # type: List[cl.Text]

    if text_elements:
        source_names = ", ".join(text_el.name for text_el in text_elements)
        answer += f"\nSources: {source_names}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()