import os
from typing import List

from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.pinecone import Pinecone
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.docstore.document import Document
import pinecone
import chainlit as cl
from chainlit.types import AskFileResponse
from dotenv import load_dotenv

load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

# Read the Pinecone API key from the environment (e.g. a .env file)
# rather than hardcoding the credential in source
pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),
    environment="gcp-starter",
)

index_name = "skandhaar"
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
namespaces = set()

welcome_message = """Welcome to the Chainlit PDF QA demo! To get started:
1. Upload a PDF or text file
2. Ask a question about the file
"""


def process_file(file: AskFileResponse):
    import tempfile

    if file.type == "text/plain":
        Loader = TextLoader
    elif file.type == "application/pdf":
        Loader = PyPDFLoader
    else:
        raise ValueError(f"Unsupported file type: {file.type}")

    # Write the uploaded bytes to a temporary file so the loader can read it
    # from disk once the handle is closed
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
        tmp.write(file.content)
        tmp_path = tmp.name

    loader = Loader(tmp_path)
    documents = loader.load()
    docs = text_splitter.split_documents(documents)
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
    return docs


def get_docsearch(file: AskFileResponse):
    docs = process_file(file)

    # Save data in the user session
    cl.user_session.set("docs", docs)

    # Create a unique namespace for the file
    namespace = str(hash(file.content))

    if namespace in namespaces:
        # The file was already embedded: reuse its namespace in the index
        docsearch = Pinecone.from_existing_index(
            index_name=index_name, embedding=embeddings, namespace=namespace
        )
    else:
        docsearch = Pinecone.from_documents(
            docs, embeddings, index_name=index_name, namespace=namespace
        )
        namespaces.add(namespace)

    return docsearch
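
# Note: `hash(file.content)` is salted per Python process (PYTHONHASHSEED), so
# the same file maps to a different namespace after a restart and gets
# re-embedded. A stable alternative (a sketch, not wired into the app above)
# is a content digest:
#
#     import hashlib
#     namespace = hashlib.sha256(file.content).hexdigest()
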
@cl.on_chat_start
async def start():
    await cl.Avatar(
        name="Chatbot",
        url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4",
    ).send()

    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content=welcome_message,
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
            disable_human_feedback=True,
        ).send()

    for file in files:
        msg = cl.Message(
            content=f"Processing `{file.name}`...", disable_human_feedback=True
        )
        await msg.send()

        # No async implementation in the Pinecone client, fall back to sync
        docsearch = await cl.make_async(get_docsearch)(file)

        message_history = ChatMessageHistory()
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="result",
            chat_memory=message_history,
            return_messages=True,
        )

        chain = RetrievalQA.from_chain_type(
            ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0,
                streaming=True,
                openai_api_key=openai_api_key,
            ),
            chain_type="stuff",
            retriever=docsearch.as_retriever(),
            # Attach the memory; its output_key="result" matches the chain's output
            memory=memory,
            return_source_documents=True,
        )

        # Let the user know that the system is ready
        msg.content = f"`{file.name}` processed. You can now ask questions!"
        await msg.update()

        cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")  # type: RetrievalQA
    cb = cl.AsyncLangchainCallbackHandler()
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    source_documents = res["source_documents"]  # type: List[Document]

    text_elements = []  # type: List[cl.Text]

    if source_documents:
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx}"
            # Create the text element referenced in the message
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]

        if source_names:
            answer += f"\nSources: {', '.join(source_names)}"
        else:
            answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()
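
# To try the app (assuming this file is saved as app.py and OPENAI_API_KEY /
# PINECONE_API_KEY are set in the environment or a .env file), start the
# Chainlit dev server; `-w` reloads on source changes:
#
#     chainlit run app.py -w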