# Author: Pavan178
# Last commit: 6c5c0ad "Update app.py" (verified); file size 5.57 kB
import os
import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.base import Chain
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
# OpenAI key taken from the process environment; None when the variable is unset.
openai_api_key = os.getenv("OPENAI_API_KEY")
class AdvancedPdfChatbot:
    """Retrieval-augmented chatbot that answers questions about an uploaded PDF.

    ``load_and_process_pdf`` splits the PDF into overlapping chunks and indexes
    them in a FAISS vector store. ``chat`` then rewrites each user query with a
    cheaper "refinement" model and answers the rewritten query with a
    ``ConversationalRetrievalChain`` over that index. The conversation is kept
    in a ``ConversationBufferMemory``.
    """

    def __init__(self, openai_api_key):
        """Create the models, memory and prompts; no PDF is indexed yet.

        Args:
            openai_api_key: OpenAI API key. When falsy (e.g. ``None`` because the
                environment variable is unset), the environment is left untouched.
                The OpenAI clients can then use an existing ``OPENAI_API_KEY`` or
                report the missing key themselves.
        """
        # os.environ only accepts str values; assigning None raises TypeError.
        if openai_api_key:
            os.environ["OPENAI_API_KEY"] = openai_api_key
        self.embeddings = OpenAIEmbeddings()
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        # Model that writes the final answers.
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4")
        # Cheaper model used only to rewrite the user's query before retrieval.
        self.refinement_llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.overall_chain = None
        self.db = None
        # Path of the most recently indexed PDF; None until one is processed.
        self.pdf_path = None
        # The template must reference {query} and {chat_history}. Otherwise the
        # declared input variables never reach the model, and older LangChain
        # versions reject the mismatched schema at construction time.
        self.refinement_prompt = PromptTemplate(
            input_variables=["query", "chat_history"],
            template="""Given the user's query and the conversation history, refine the query to be more specific and detailed.
If the query is too vague, make reasonable assumptions based on the conversation context.
Output the refined query.

Conversation history:
{chat_history}

Query: {query}

Refined query:""",
        )
        self.template = """
You are a study partner assistant, students give you pdfs and you help them to answer their questions.
Answer the question based on the most recent provided resources only.
Give the most relevant answer.
Context: {context}
Question: {question}
Answer:
(Note: YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
"""
        self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])

    def load_and_process_pdf(self, pdf_path):
        """Load, chunk and index the PDF at ``pdf_path``, then build the chains.

        Raises:
            ValueError: if the PDF yields no text chunks (e.g. a scanned,
                image-only file). FAISS cannot build an index from zero
                documents.
        """
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        texts = self.text_splitter.split_documents(documents)
        if not texts:
            raise ValueError(f"No extractable text found in PDF: {pdf_path}")
        self.db = FAISS.from_documents(texts, self.embeddings)
        # Record the path only after indexing succeeded.
        self.pdf_path = pdf_path
        self.setup_conversation_chain()

    def setup_conversation_chain(self):
        """Wire the refinement chain and the retrieval QA chain into ``overall_chain``.

        Raises:
            RuntimeError: if called before a PDF has been indexed.
        """
        if self.db is None:
            raise RuntimeError("No PDF has been indexed; call load_and_process_pdf first.")
        refinement_chain = LLMChain(
            llm=self.refinement_llm,
            prompt=self.refinement_prompt,
            output_key="refined_query",
        )
        # NOTE(review): the memory is attached to qa_chain, so each stored human
        # turn is presumably the *refined* query rather than the user's original
        # text. Confirm this is intended.
        qa_chain = ConversationalRetrievalChain.from_llm(
            self.llm,
            retriever=self.db.as_retriever(),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.prompt},
        )
        self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)

    class CustomChain(Chain):
        """Chain that refines the query first, then runs retrieval QA on the result.

        ``Chain`` is a pydantic model, so the sub-chains are declared as fields.
        Assigning undeclared attributes inside ``__init__`` is rejected by pydantic.
        """

        refinement_chain: Chain
        qa_chain: Chain

        def __init__(self, refinement_chain, qa_chain, **kwargs):
            # Keep the (refinement_chain, qa_chain) call signature while letting
            # pydantic validate and store both as declared fields.
            super().__init__(refinement_chain=refinement_chain, qa_chain=qa_chain, **kwargs)

        @property
        def input_keys(self):
            return ["query", "chat_history"]

        @property
        def output_keys(self):
            return ["answer"]

        def _call(self, inputs):
            """Refine ``inputs['query']``, answer it with the QA chain, and return ``{'answer': ...}``."""
            query = inputs["query"]
            chat_history = inputs.get("chat_history", [])
            refined_query = self.refinement_chain.run(query=query, chat_history=chat_history)
            # Fall back to the user's own wording if the refiner returns nothing.
            refined_query = refined_query.strip() or query
            response = self.qa_chain({"question": refined_query, "chat_history": chat_history})
            return {"answer": response["answer"]}

    def chat(self, query):
        """Answer ``query`` about the indexed PDF.

        Returns a reminder string if no PDF has been processed yet.
        """
        if self.overall_chain is None:
            return "Please upload a PDF first."
        chat_history = self.memory.load_memory_variables({})["chat_history"]
        result = self.overall_chain({"query": query, "chat_history": chat_history})
        return result["answer"]

    def get_pdf_path(self):
        """Return the path of the indexed PDF, or a notice if none was processed.

        FAISS vector stores do not expose a ``path`` attribute, so the path is
        recorded by ``load_and_process_pdf`` instead.
        """
        if self.pdf_path is not None:
            return self.pdf_path
        return "No PDF uploaded yet."
# Single shared chatbot instance used by the Gradio callback functions below.
pdf_chatbot = AdvancedPdfChatbot(openai_api_key=openai_api_key)
def upload_pdf(pdf_file):
    """Index the uploaded PDF so the chatbot can answer questions about it.

    Args:
        pdf_file: Value from the ``gr.File`` component. It is either a filesystem
            path string or a tempfile-like object whose ``.name`` is the path,
            depending on the Gradio version and configuration. ``None`` when
            nothing has been uploaded.

    Returns:
        The processed file's path on success, or a prompt asking for a file.
    """
    if pdf_file is None:
        return "Please upload a PDF file."
    # Newer Gradio passes a path string (``type="filepath"``); older versions
    # pass a tempfile wrapper exposing the path as ``.name``. Accept both.
    if isinstance(pdf_file, (str, os.PathLike)):
        file_path = os.fspath(pdf_file)
    else:
        file_path = pdf_file.name
    pdf_chatbot.load_and_process_pdf(file_path)
    return file_path
def respond(message, history):
    """Answer ``message`` and record the exchange in the Gradio chat history.

    Returns an empty string (to clear the input box) and the same history list
    with the ``(message, reply)`` pair appended.
    """
    reply = pdf_chatbot.chat(message)
    exchange = (message, reply)
    history.append(exchange)
    cleared_input = ""
    return cleared_input, history
def clear_chatbot():
    """Forget the chatbot's conversation memory and return an empty transcript."""
    pdf_chatbot.memory.clear()
    empty_history = []
    return empty_history
def get_pdf_path():
    """Return whatever the shared chatbot reports as its current PDF path."""
    path_info = pdf_chatbot.get_pdf_path()
    return path_info
# Create the Gradio interface. Component creation order below determines the
# on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot")
    # NOTE(review): the original indentation was lost. Which components sit
    # inside this Row is reconstructed here; confirm the intended layout.
    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_button = gr.Button("Process PDF")
    upload_status = gr.Textbox(label="Upload Status")
    # "Process PDF": index the uploaded file. upload_pdf's return value (the
    # file path or an upload prompt) is shown in the status box.
    upload_button.click(upload_pdf, inputs=[pdf_upload], outputs=[upload_status])
    path_button = gr.Button("Get PDF Path")
    pdf_path_display = gr.Textbox(label="Current PDF Path")
    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    # Submitting a message: respond() returns "" (clearing the box) and the
    # updated chat history.
    msg.submit(respond, inputs=[msg, chatbot_interface], outputs=[msg, chatbot_interface])
    # "Clear": wipe the chatbot's memory and empty the chat display.
    clear.click(clear_chatbot, outputs=[chatbot_interface])
    path_button.click(get_pdf_path, outputs=[pdf_path_display])
if __name__ == "__main__":
    demo.launch()