import os
from typing import List

import chainlit as cl
import pinecone
from chainlit.types import AskFileResponse
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.pinecone import Pinecone

load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

# Pinecone credentials are read from the environment (PINECONE_API_KEY)
# to keep secrets out of source control.
pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),
    environment="gcp-starter",
)

index_name = "skandhaar"

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Namespaces already indexed in Pinecone during this process's lifetime.
namespaces = set()

welcome_message = """Welcome to the Chainlit PDF QA demo! To get started:
1. Upload a PDF or text file
"""


def process_file(file: AskFileResponse):
    import tempfile

    if file.type == "text/plain":
        Loader = TextLoader
    elif file.type == "application/pdf":
        Loader = PyPDFLoader
    else:
        raise ValueError(f"Unsupported file type: {file.type}")

    # Write the uploaded bytes to a temporary file so the loader can read
    # it from disk; both loaders take a file path.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
        tmp.write(file.content)
        tmp_path = tmp.name

    loader = Loader(tmp_path)
    documents = loader.load()
    docs = text_splitter.split_documents(documents)
    # Tag each chunk so answers can reference their source.
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
    return docs


def get_docsearch(file: AskFileResponse):
    docs = process_file(file)

    # Save data in the user session
    cl.user_session.set("docs", docs)

    # Create a unique namespace for the file so re-uploading the same
    # content reuses the existing vectors instead of re-indexing them.
    namespace = str(hash(file.content))

    if namespace in namespaces:
        docsearch = Pinecone.from_existing_index(
            index_name=index_name, embedding=embeddings, namespace=namespace
        )
    else:
        docsearch = Pinecone.from_documents(
            docs, embeddings, index_name=index_name, namespace=namespace
        )
        namespaces.add(namespace)

    return docsearch


@cl.on_chat_start
async def start():
    await cl.Avatar(
        name="Chatbot",
        url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4",
    ).send()

    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content=welcome_message,
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
            disable_human_feedback=True,
        ).send()

    for file in files:
        msg = cl.Message(
            content=f"Processing `{file.name}`...", disable_human_feedback=True
        )
        await msg.send()

        # No async implementation in the Pinecone client, fall back to sync
        docsearch = await cl.make_async(get_docsearch)(file)

        PROMPT = PromptTemplate(
            template="""Your name is Skandhaar docchat and you work for the Skandhaar org. Your job is to answer the user's question from the given context. You are not allowed to make up an answer or invent anything that is not in the context. Strictly follow the context and give extractive answers. Respond to user greetings. If you encounter an out-of-context question, reply with "I'm here to help you with the given knowledge source, I can't assist with that."

Context: {context}
Question: {question}

Answer in Markdown.
""", input_variables=["context", "question"] ) chain_type_kwargs = {"prompt": PROMPT} chain = RetrievalQA.from_chain_type( ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True, openai_api_key=openai_api_key), chain_type="stuff", retriever=docsearch.as_retriever(), return_source_documents=True, chain_type_kwargs=chain_type_kwargs ) # Let the user know that the system is ready msg.content = f"`{file.name}` processed. You can now ask questions!" await msg.update() cl.user_session.set("chain", chain) @cl.on_message async def main(message: cl.Message): chain = cl.user_session.get("chain") # type: ConversationalRetrievalChain cb = cl.AsyncLangchainCallbackHandler() res = await chain.acall(message.content, callbacks=[cb]) answer = res["result"] source_documents = res["source_documents"] # type: List[Document] text_elements = [] # type: List[cl.Text] if source_documents: for source_idx, source_doc in enumerate(source_documents): source_name = f"source_{source_idx}" # Create the text element referenced in the message text_elements.append( cl.Text(content=source_doc.page_content, name=source_name) ) source_names = [text_el.name for text_el in text_elements] if source_names: answer += f"\nSources: {', '.join(source_names)}" else: answer += "\nNo sources found" await cl.Message(content=answer, elements=text_elements).send()