File size: 3,380 Bytes
c2af1e5
d2e3c7f
 
 
 
 
28f9d4d
d2e3c7f
 
bd0ebd6
d43bb1b
633ac28
d43bb1b
 
aff3a65
a526ade
d2e3c7f
 
 
 
 
0a080de
355b657
d2e3c7f
 
355b657
 
 
 
 
 
 
 
 
 
 
 
5e8e8f0
d2e3c7f
 
 
 
f74eb2e
d2e3c7f
 
 
 
 
 
510767a
355b657
d2e3c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f74eb2e
 
 
 
d2e3c7f
 
 
f74eb2e
d2e3c7f
 
 
5e8e8f0
d2e3c7f
 
5e8e8f0
d2e3c7f
 
 
5e8e8f0
d2e3c7f
f74eb2e
5e8e8f0
d2e3c7f
bd0ebd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory

from langchain.prompts import PromptTemplate



# API key read once at import time; None when OPENAI_API_KEY is not set.
openai_api_key = os.getenv("OPENAI_API_KEY")

class AdvancedPdfChatbot:
    """Conversational question answering over an uploaded PDF.

    The PDF is split into overlapping chunks, embedded with OpenAI
    embeddings and indexed in a FAISS vector store. Questions go through a
    ``ConversationalRetrievalChain`` that uses a study-partner prompt and a
    shared conversation memory (``self.memory``).

    Loading a new PDF replaces the previous index. The conversation memory
    is NOT reset on reload; call ``self.memory.clear()`` for that.
    """

    def __init__(self, openai_api_key):
        """Create the embedder, text splitter, LLM, memory and prompt.

        Args:
            openai_api_key: OpenAI API key. When falsy (e.g. ``None`` because
                the environment variable is unset), ``OPENAI_API_KEY`` in the
                environment is left untouched. The OpenAI clients are then
                expected to report the missing key themselves.
        """
        # os.environ only accepts str values. Assigning None would raise an
        # unhelpful TypeError before the real "missing API key" problem
        # could be reported.
        if openai_api_key:
            os.environ["OPENAI_API_KEY"] = openai_api_key
        self.embeddings = OpenAIEmbeddings()
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini')

        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        # The vector store and chain are built once a PDF has been loaded.
        self.db = None
        self.qa_chain = None
        self.template = """
        You are a study partner assistant, students give you pdfs
        and you help them to answer their questions.
        
        Answer the question based on the most recent provided resources only.
        Give the most relevant answer.
        
        Context: {context}
        Question: {question}
        Answer:
        """
        self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])

    def load_and_process_pdf(self, pdf_path):
        """Index the PDF at *pdf_path* and (re)build the QA chain.

        Any previously loaded document's index is replaced.

        Raises:
            ValueError: if the PDF yields no text chunks (presumably an
                image-only/scanned PDF). Without this check the failure would
                surface deep inside ``FAISS.from_documents``.
        """
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        texts = self.text_splitter.split_documents(documents)
        if not texts:
            raise ValueError(f"No extractable text found in {pdf_path!r}.")
        self.db = FAISS.from_documents(texts, self.embeddings)
        self.setup_conversation_chain()

    def setup_conversation_chain(self):
        """Build the retrieval chain over ``self.db`` with the shared memory and prompt.

        Raises:
            RuntimeError: if no document has been indexed yet.
        """
        if self.db is None:
            raise RuntimeError("No PDF has been loaded; call load_and_process_pdf first.")
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            self.llm,
            retriever=self.db.as_retriever(),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.prompt}
        )

    def chat(self, query):
        """Answer *query* using the loaded PDF and the conversation so far.

        Returns a short instruction string instead of calling the chain when
        no PDF has been processed yet or when *query* is blank.
        """
        if self.qa_chain is None:
            return "Please upload a PDF first."
        # Avoid spending an LLM round-trip (and a memory entry) on empty input.
        if not query or not query.strip():
            return "Please enter a question."
        result = self.qa_chain({"question": query})
        return result['answer']

# Module-level chatbot used by every Gradio handler in this file.
# NOTE: being a global, all browser sessions share one index and one memory.
pdf_chatbot = AdvancedPdfChatbot(openai_api_key=openai_api_key)

def upload_pdf(pdf_file):
    """Gradio handler: index the uploaded PDF and return a status message.

    Args:
        pdf_file: The value of the ``gr.File`` component. Depending on the
            Gradio version/configuration this is either a temp-file object
            exposing ``.name`` or a plain path string. ``None`` means
            nothing was uploaded.

    Returns:
        A human-readable status string for the "Upload Status" textbox.
    """
    if pdf_file is None:
        return "Please upload a PDF file."
    # A str path has no ``.name`` attribute, so fall back to the value itself.
    file_path = getattr(pdf_file, "name", pdf_file)
    try:
        pdf_chatbot.load_and_process_pdf(file_path)
    except Exception as exc:  # UI boundary: report the failure instead of crashing the handler
        return f"Failed to process PDF: {exc}"
    return "PDF uploaded and processed successfully. You can now start chatting!"

def respond(message, history):
    """Gradio handler: answer *message* and append the exchange to the chat log.

    Args:
        message: The user's text from the input box.
        history: The current chat log as a list of ``(user, bot)`` pairs.
            ``None`` is treated as an empty log.

    Returns:
        ``("", new_history)``. The empty string clears the input box, and
        *new_history* is a fresh list with the new exchange appended.
    """
    bot_message = pdf_chatbot.chat(message)
    # Copy instead of mutating the caller's list in place.
    new_history = list(history or [])
    new_history.append((message, bot_message))
    return "", new_history

def clear_chatbot():
    """Gradio handler: forget the conversation and empty the visible chat log."""
    conversation_memory = pdf_chatbot.memory
    conversation_memory.clear()
    empty_log = []
    return empty_log

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot")
    
    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_button = gr.Button("Process PDF")

    upload_status = gr.Textbox(label="Upload Status")
    # Indexing runs when the button is clicked, not when a file is selected.
    upload_button.click(upload_pdf, inputs=[pdf_upload], outputs=[upload_status])

    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # respond() returns "" (clearing the textbox) plus the updated history.
    msg.submit(respond, inputs=[msg, chatbot_interface], outputs=[msg, chatbot_interface])
    # Clears the LLM conversation memory and resets the visible chat log.
    clear.click(clear_chatbot, outputs=[chatbot_interface])

# Start the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()