File size: 6,284 Bytes
d8c3a88 d2e3c7f b840efb 6c5c0ad b840efb d2e3c7f d43bb1b 633ac28 aff3a65 a526ade d2e3c7f 8af0aff 3f31c68 355b657 d2e3c7f 3f31c68 355b657 3f31c68 355b657 3f31c68 355b657 5e8e8f0 d2e3c7f f74eb2e d2e3c7f 8af0aff 873a6e6 b840efb 160264d b840efb 873a6e6 b840efb 873a6e6 b840efb 3f31c68 b3de9a3 e224a4b 873a6e6 e224a4b b3de9a3 e224a4b 160264d e224a4b b3de9a3 e224a4b 160264d e224a4b b3de9a3 e224a4b 160264d b3de9a3 160264d b3de9a3 e224a4b a261843 d2e3c7f 3f31c68 d2e3c7f 3f31c68 d2e3c7f ff0e62c 3f31c68 ff0e62c d2e3c7f 8af0aff d2e3c7f 8af0aff d2e3c7f f74eb2e ff0e62c d2e3c7f f74eb2e d2e3c7f 5e8e8f0 d2e3c7f 95c7827 d2e3c7f 5e8e8f0 d2e3c7f 8af0aff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import os
import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.base import Chain
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
# OpenAI key taken from the process environment; None if OPENAI_API_KEY is unset.
openai_api_key = os.getenv("OPENAI_API_KEY")
class AdvancedPdfChatbot:
    """Answer questions about an uploaded PDF using LangChain + OpenAI.

    The PDF is split into overlapping chunks, embedded with OpenAIEmbeddings
    and indexed in a FAISS store. Each user query is first rewritten by a
    "refinement" model (gpt-3.5-turbo) and then answered by a
    ConversationalRetrievalChain on gpt-4 that retrieves context from the
    index and keeps history in a ConversationBufferMemory.
    """

    def __init__(self, openai_api_key):
        """Create the models, text splitter, memory and prompts.

        Args:
            openai_api_key: OpenAI API key. When not None it is exported as the
                ``OPENAI_API_KEY`` environment variable before the OpenAI
                clients are constructed; when None the environment is left
                as-is.
        """
        # os.environ only accepts str values; assigning None would raise a
        # TypeError before the OpenAI clients could report the missing key.
        if openai_api_key is not None:
            os.environ["OPENAI_API_KEY"] = openai_api_key
        self.embeddings = OpenAIEmbeddings()
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.llm = ChatOpenAI(temperature=0, model_name='gpt-4')
        self.refinement_llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.overall_chain = None  # set by setup_conversation_chain()
        self.db = None  # FAISS index over the most recently loaded PDF
        self.pdf_path = None  # path of the most recently indexed PDF
        # The template must actually reference both declared input variables;
        # without the placeholders the refinement model never sees the query.
        self.refinement_prompt = PromptTemplate(
            input_variables=['query', 'chat_history'],
            template="""Given the user's query and the conversation history, refine the query to be more specific and detailed.
If the query is too vague, make reasonable assumptions based on the conversation context.
Output only the refined query, with no extra commentary.

Conversation history:
{chat_history}

Query: {query}

Refined query:"""
        )
        self.template = """
You are a study partner assistant, students give you pdfs and you help them to answer their questions.
Answer the question based on the most recent provided resources only.
Give the most relevant answer.
Context: {context}
Question: {question}
Answer:
(Note: YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
"""
        self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])

    def load_and_process_pdf(self, pdf_path):
        """Index the PDF at *pdf_path* and (re)build the conversation chain.

        The new FAISS index replaces any previously loaded one; the
        conversation memory is not cleared.

        Args:
            pdf_path: Filesystem path of the PDF to load.

        Raises:
            ValueError: If no text chunks could be extracted from the PDF.
        """
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        texts = self.text_splitter.split_documents(documents)
        if not texts:
            # Fail early with a clear message instead of whatever opaque error
            # building an index from zero documents would produce.
            raise ValueError("No extractable text found in the PDF.")
        self.db = FAISS.from_documents(texts, self.embeddings)
        self.pdf_path = pdf_path
        self.setup_conversation_chain()

    def setup_conversation_chain(self):
        """Build the query-refinement -> retrieval-QA pipeline over ``self.db``.

        Raises:
            ValueError: If no PDF has been indexed yet.
        """
        if self.db is None:
            raise ValueError("Database not initialized. Please upload a PDF first.")
        refinement_chain = LLMChain(
            llm=self.refinement_llm,
            prompt=self.refinement_prompt,
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            self.llm,
            retriever=self.db.as_retriever(),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.prompt},
        )
        self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)

    class CustomChain(Chain):
        """Refine the user's query with one chain, then answer it with another.

        Inputs: ``query`` and ``chat_history``. Output: ``answer``.
        """

        # Declared as model fields so the pydantic-based Chain accepts them as
        # constructor keywords rather than relying on undeclared attributes.
        refinement_chain: Chain
        qa_chain: Chain

        @property
        def input_keys(self):
            """Input keys this chain expects."""
            return ["query", "chat_history"]

        @property
        def output_keys(self):
            """Output keys this chain returns."""
            return ["answer"]

        def _call(self, inputs, run_manager=None):
            """Refine ``inputs['query']`` and return the QA chain's answer.

            ``run_manager`` is accepted for compatibility with Chain's calling
            convention and is not otherwise used.
            """
            query = inputs['query']
            chat_history = inputs.get('chat_history', [])
            # Step 1: rewrite the raw query into a more specific one.
            refined_query = self.refinement_chain.run({'query': query, 'chat_history': chat_history})
            # Step 2: answer the refined query against the indexed PDF.
            # NOTE(review): the QA chain was built with its own memory, which
            # presumably supplies chat_history itself and records the refined
            # (not the original) query — confirm this is intended.
            response = self.qa_chain({"question": refined_query, "chat_history": chat_history})
            return {"answer": response['answer']}

    def chat(self, query):
        """Answer *query* against the currently loaded PDF.

        Returns:
            The answer text, or "Please upload a PDF first." if no PDF has been
            processed yet.
        """
        if self.overall_chain is None:
            return "Please upload a PDF first."
        chat_history = self.memory.load_memory_variables({})['chat_history']
        result = self.overall_chain({'query': query, 'chat_history': chat_history})
        return result['answer']

    def get_pdf_path(self):
        """Return the path of the indexed PDF, or a notice if none is loaded."""
        # The FAISS store does not expose the source file path (the former
        # ``self.db.path`` lookup raised AttributeError), so the path is
        # recorded on the chatbot when the PDF is indexed.
        if self.db is not None and self.pdf_path is not None:
            return self.pdf_path
        return "No PDF uploaded yet."
# Single chatbot instance used by the Gradio callback functions in this module.
pdf_chatbot = AdvancedPdfChatbot(openai_api_key=openai_api_key)
def upload_pdf(pdf_file):
    """Index an uploaded PDF with the shared chatbot and report the outcome.

    Args:
        pdf_file: Value from the upload component. If it has a ``name``
            attribute, that is used as the file path; otherwise the value
            itself is treated as the path. ``None`` means nothing was uploaded.

    Returns:
        A human-readable status string for the upload-status textbox.
    """
    if pdf_file is None:
        return "Please upload a PDF file."
    path = getattr(pdf_file, "name", pdf_file)
    try:
        pdf_chatbot.load_and_process_pdf(path)
        status = f"PDF processed successfully: {path}"
    except Exception as exc:  # UI boundary: report any failure instead of crashing
        status = f"Error processing PDF: {str(exc)}"
    return status
def respond(message, history):
    """Send *message* to the chatbot and append the exchange to *history*.

    Args:
        message: Text typed into the message box.
        history: Chat transcript as a list of ``(user, bot)`` pairs. ``None``
            is treated as an empty transcript.

    Returns:
        A ``(textbox_value, history)`` pair. The textbox is cleared on success.
        If the chatbot raises, the textbox is set to ``"Error: ..."`` and the
        history is returned unchanged.
    """
    # Defensive: an empty Chatbot may be delivered as None, which has no .append.
    if history is None:
        history = []
    if not message:
        return "", history
    try:
        bot_message = pdf_chatbot.chat(message)
    except Exception as e:  # UI boundary: show the failure instead of crashing
        return f"Error: {str(e)}", history
    history.append((message, bot_message))
    return "", history
def clear_chatbot():
    """Wipe the shared chatbot's conversation memory.

    Returns:
        An empty list, used as the new (empty) chat transcript.
    """
    conversation_memory = pdf_chatbot.memory
    conversation_memory.clear()
    empty_transcript = []
    return empty_transcript
def get_pdf_path():
    """Delegate to the shared chatbot's ``get_pdf_path()`` and return its result."""
    chatbot = pdf_chatbot
    return chatbot.get_pdf_path()
# Create the Gradio interface: a PDF upload/processing section, then the chat area.
with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot")
    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_button = gr.Button("Process PDF")
    upload_status = gr.Textbox(label="Upload Status")
    upload_button.click(upload_pdf, inputs=[pdf_upload], outputs=[upload_status])

    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox()
    # respond() returns (new textbox value, updated transcript).
    msg.submit(respond, inputs=[msg, chatbot_interface], outputs=[msg, chatbot_interface])

    # clear_chatbot() was previously unreachable from the UI, so the
    # conversation memory could never be reset; expose it via a button.
    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chatbot, outputs=[chatbot_interface])

if __name__ == "__main__":
    demo.launch()
|