import os

import gradio as gr
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.base import Chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

openai_api_key = os.environ.get("OPENAI_API_KEY")


class AdvancedPdfChatbot:
    def __init__(self, openai_api_key):
        if openai_api_key:  # avoid setting the env var to None when the key is missing
            os.environ["OPENAI_API_KEY"] = openai_api_key
        self.embeddings = OpenAIEmbeddings()
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4")
        self.refinement_llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.overall_chain = None
        self.db = None
        self.pdf_path = None

        # The template must reference every declared input variable, otherwise
        # PromptTemplate raises a validation error at construction time.
        self.refinement_prompt = PromptTemplate(
            input_variables=["query", "chat_history"],
            template=(
                "Given the conversation history and the user's query, refine the query "
                "to be more specific and detailed. If the query is too vague, make "
                "reasonable assumptions based on the conversation context. "
                "Output only the refined query.\n\n"
                "Conversation history:\n{chat_history}\n\n"
                "Query: {query}\n\n"
                "Refined query:"
            ),
        )

        self.template = """
        You are a study-partner assistant: students give you PDFs, and you help them answer questions about them.
        Answer the question based only on the most recently provided resources, giving the most relevant answer.

        Context: {context}
        Question: {question}
        Answer: (Format your output as proper paragraphs, or as bullet points where they fit; choose the format that suits the type of question asked.)
        """
        self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])

    def load_and_process_pdf(self, pdf_path):
        """Load a PDF, split it into chunks, and index the chunks in FAISS."""
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        texts = self.text_splitter.split_documents(documents)
        self.db = FAISS.from_documents(texts, self.embeddings)
        self.pdf_path = pdf_path  # FAISS has no .path attribute, so track the file here
        self.setup_conversation_chain()

    def setup_conversation_chain(self):
        if not self.db:
            raise ValueError("Database not initialized. Please upload a PDF first.")
        refinement_chain = LLMChain(llm=self.refinement_llm, prompt=self.refinement_prompt)
        qa_chain = ConversationalRetrievalChain.from_llm(
            self.llm,
            retriever=self.db.as_retriever(),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.prompt},
        )
        self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)

    class CustomChain(Chain):
        # Chain is a pydantic model, so the sub-chains are declared as fields
        # rather than assigned in a custom __init__ (which pydantic would reject).
        refinement_chain: LLMChain
        qa_chain: ConversationalRetrievalChain

        @property
        def input_keys(self):
            """Input keys this chain expects."""
            return ["query", "chat_history"]

        @property
        def output_keys(self):
            """Output keys this chain returns."""
            return ["answer"]

        def _call(self, inputs):
            query = inputs["query"]
            chat_history = inputs.get("chat_history", [])
            # First refine the raw query against the conversation history,
            # then answer the refined query over the retrieved documents.
            refined_query = self.refinement_chain.run(
                {"query": query, "chat_history": chat_history}
            )
            response = self.qa_chain({"question": refined_query, "chat_history": chat_history})
            return {"answer": response["answer"]}

    def chat(self, query):
        if not self.overall_chain:
            return "Please upload a PDF first."
        chat_history = self.memory.load_memory_variables({})["chat_history"]
        result = self.overall_chain({"query": query, "chat_history": chat_history})
        return result["answer"]

    def get_pdf_path(self):
        return self.pdf_path if self.pdf_path else "No PDF uploaded yet."


# Initialize the chatbot
pdf_chatbot = AdvancedPdfChatbot(openai_api_key)


def upload_pdf(pdf_file):
    if pdf_file is None:
        return "Please upload a PDF file."
    file_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
    try:
        pdf_chatbot.load_and_process_pdf(file_path)
        return f"PDF processed successfully: {file_path}"
    except Exception as e:
        return f"Error processing PDF: {e}"


def respond(message, history):
    if not message:
        return "", history
    try:
        bot_message = pdf_chatbot.chat(message)
        history.append((message, bot_message))
    except Exception as e:
        # Surface errors in the chat window instead of the message input box.
        history.append((message, f"Error: {e}"))
    return "", history


def clear_chatbot():
    pdf_chatbot.memory.clear()
    return []


def get_pdf_path():
    return pdf_chatbot.get_pdf_path()


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot")
    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_button = gr.Button("Process PDF")
    upload_status = gr.Textbox(label="Upload Status")
    upload_button.click(upload_pdf, inputs=[pdf_upload], outputs=[upload_status])

    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(respond, inputs=[msg, chatbot_interface], outputs=[msg, chatbot_interface])

    # Wire up the previously unused clear_chatbot helper.
    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chatbot, outputs=[chatbot_interface])


if __name__ == "__main__":
    demo.launch()
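# A minimal headless usage sketch (no Gradio UI), assuming a valid
# OPENAI_API_KEY in the environment and a hypothetical local file "notes.pdf":
#
#   bot = AdvancedPdfChatbot(openai_api_key)
#   bot.load_and_process_pdf("notes.pdf")
#   print(bot.chat("Summarize the key points of the introduction."))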