File size: 5,570 Bytes
6c5c0ad
d2e3c7f
 
 
 
 
b840efb
6c5c0ad
b840efb
d2e3c7f
d43bb1b
633ac28
aff3a65
a526ade
d2e3c7f
 
 
 
 
3f31c68
 
355b657
d2e3c7f
3f31c68
 
 
 
 
 
 
 
 
 
355b657
3f31c68
355b657
 
 
 
 
 
 
3f31c68
355b657
 
5e8e8f0
d2e3c7f
 
 
 
f74eb2e
d2e3c7f
 
 
b840efb
 
 
 
 
 
 
 
 
 
 
 
3f31c68
b840efb
 
 
 
 
3f31c68
b840efb
 
 
3f31c68
b840efb
 
 
3f31c68
b840efb
 
 
 
 
 
d2e3c7f
 
3f31c68
d2e3c7f
3f31c68
 
d2e3c7f
 
ff0e62c
3f31c68
 
ff0e62c
 
 
d2e3c7f
 
 
 
 
 
 
 
5afc751
d2e3c7f
 
 
 
 
 
f74eb2e
 
 
 
ff0e62c
 
 
d2e3c7f
 
 
f74eb2e
d2e3c7f
 
 
5e8e8f0
d2e3c7f
 
596dcf4
0ae5df7
d2e3c7f
 
 
5e8e8f0
d2e3c7f
f74eb2e
ff0e62c
5e8e8f0
d2e3c7f
3f31c68
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os

import gradio as gr
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

# OpenAI key from the environment; None when the variable is not set.
openai_api_key = os.getenv("OPENAI_API_KEY")

class AdvancedPdfChatbot:
    """Question-answering chatbot over the text of an uploaded PDF.

    Each user query is first rewritten by a "refinement" LLM, then answered by
    a ``ConversationalRetrievalChain`` that retrieves chunks from a FAISS index
    built from the PDF. Conversation turns are stored in ``self.memory``.
    """

    def __init__(self, openai_api_key):
        """Create the models, memory and prompts; no PDF is loaded yet.

        Args:
            openai_api_key: OpenAI API key, or ``None``. A non-``None`` value is
                exported as the ``OPENAI_API_KEY`` environment variable before
                the OpenAI-backed objects are built. ``None`` leaves the
                environment untouched (assigning ``None`` into ``os.environ``
                would raise ``TypeError``).
        """
        if openai_api_key is not None:
            os.environ["OPENAI_API_KEY"] = openai_api_key
        self.embeddings = OpenAIEmbeddings()
        # Overlapping chunks keep text that straddles a chunk boundary retrievable.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        # Answers come from GPT-4; the cheaper GPT-3.5 model only rewrites queries.
        self.llm = ChatOpenAI(temperature=0, model_name='gpt-4')
        self.refinement_llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')

        # return_messages=True keeps history as message objects rather than one string.
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.overall_chain = None  # built by setup_conversation_chain() once a PDF is indexed
        self.db = None  # FAISS index of the current PDF
        self.pdf_path = None  # path of the most recently indexed PDF

        # The template must actually reference both declared input variables;
        # otherwise the refinement model never sees the query or the history.
        self.refinement_prompt = PromptTemplate(
            input_variables=['query', 'chat_history'],
            template="""Given the user's query and the conversation history, refine the query to be more specific and detailed.
            If the query is too vague, make reasonable assumptions based on the conversation context.

            Conversation history:
            {chat_history}

            Query: {query}

            Output the refined query."""
        )

        self.template = """
        You are a study partner assistant, students give you pdfs and you help them to answer their questions.
        
        Answer the question based on the most recent provided resources only.
        Give the most relevant answer.
        
        Context: {context}
        Question: {question}
        Answer:
        (Note: YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
        """
        self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])

    def load_and_process_pdf(self, pdf_path):
        """Load *pdf_path*, index its text chunks in FAISS and rebuild the chain.

        Args:
            pdf_path: Filesystem path of the PDF to index.

        Raises:
            ValueError: if no text chunks could be extracted from the PDF
                (for example an image-only scan), since FAISS cannot build an
                index from an empty document list.
        """
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        texts = self.text_splitter.split_documents(documents)
        if not texts:
            raise ValueError(f"No extractable text found in PDF: {pdf_path}")
        self.db = FAISS.from_documents(texts, self.embeddings)
        # Record the source path only after indexing succeeded.
        self.pdf_path = pdf_path
        self.setup_conversation_chain()

    def setup_conversation_chain(self):
        """Wire the refinement chain and the retrieval QA chain into ``overall_chain``.

        Requires ``self.db`` to be populated (see ``load_and_process_pdf``).
        """
        refinement_chain = LLMChain(
            llm=self.refinement_llm,
            prompt=self.refinement_prompt,
            output_key='refined_query'
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            self.llm,
            retriever=self.db.as_retriever(),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.prompt}
        )
        self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)

    class CustomChain(Chain):
        """Chain that rewrites the user's query before running retrieval QA.

        ``Chain`` is a pydantic model, so sub-chains must be declared as fields
        (and passed as keyword arguments) rather than assigned in ``__init__``.
        """

        refinement_chain: Chain  # rewrites the raw query (an LLMChain)
        qa_chain: Chain  # answers the refined query (a ConversationalRetrievalChain)

        @property
        def input_keys(self):
            return ["query", "chat_history"]

        @property
        def output_keys(self):
            return ["answer"]

        def _call(self, inputs, run_manager=None):
            """Refine ``inputs['query']`` and return the QA chain's answer.

            ``run_manager`` is accepted to match the ``Chain._call`` signature;
            it is not used here.
            """
            query = inputs['query']
            chat_history = inputs.get('chat_history', [])
            refined_query = self.refinement_chain.run(query=query, chat_history=chat_history)
            response = self.qa_chain({"question": refined_query, "chat_history": chat_history})
            return {"answer": response['answer']}

    def chat(self, query):
        """Answer *query* about the loaded PDF, or prompt for an upload first."""
        if self.overall_chain is None:
            return "Please upload a PDF first."
        chat_history = self.memory.load_memory_variables({})['chat_history']
        result = self.overall_chain({'query': query, 'chat_history': chat_history})
        return result['answer']

    def get_pdf_path(self):
        """Return the path of the indexed PDF, or a placeholder message if none.

        The FAISS store does not record its source file, so the path is
        tracked separately in ``self.pdf_path`` by ``load_and_process_pdf``.
        """
        if self.pdf_path is None:
            return "No PDF uploaded yet."
        return self.pdf_path

# Module-level chatbot instance used by the Gradio callback functions below.
pdf_chatbot = AdvancedPdfChatbot(openai_api_key=openai_api_key)

def upload_pdf(pdf_file):
    """Index the uploaded PDF and return the path that was processed.

    Args:
        pdf_file: Value delivered by the ``gr.File`` component, or ``None``
            when nothing was uploaded. Depending on the Gradio version this is
            either an object whose ``.name`` is the temp-file path or the path
            string itself; both are accepted.

    Returns:
        The processed file path, or a prompt message when no file was given.
    """
    if pdf_file is None:
        return "Please upload a PDF file."
    # Newer Gradio passes a plain filepath string, which has no ``.name``.
    file_path = getattr(pdf_file, "name", pdf_file)
    pdf_chatbot.load_and_process_pdf(file_path)
    return file_path

def respond(message, history):
    """Send *message* to the chatbot and append the exchange to the chat log.

    Args:
        message: Text typed by the user.
        history: Current chat log as a list of ``(user, bot)`` pairs; ``None``
            is treated as an empty log.

    Returns:
        ``("", new_history)``. The empty string clears the input textbox.
        Blank messages are not sent to the model and leave the log unchanged.
    """
    # Work on a copy so the caller's list is not mutated; tolerate None.
    history = list(history or [])
    if not message or not message.strip():
        return "", history
    bot_message = pdf_chatbot.chat(message)
    history.append((message, bot_message))
    return "", history

def clear_chatbot():
    """Forget the stored conversation and return an empty log for the chat widget."""
    conversation_memory = pdf_chatbot.memory
    conversation_memory.clear()
    empty_log = []
    return empty_log

def get_pdf_path():
    """Return whatever the chatbot reports as the current PDF path (or its placeholder text)."""
    reported = pdf_chatbot.get_pdf_path()
    return reported

# Build the Gradio UI. Components are created in display order, so the order
# of the statements below also defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown(value="# PDF Chatbot")

    # Upload row: file picker beside the button that triggers indexing.
    with gr.Row():
        pdf_upload = gr.File(file_types=[".pdf"], label="Upload PDF")
        upload_button = gr.Button(value="Process PDF")

    upload_status = gr.Textbox(label="Upload Status")
    upload_button.click(fn=upload_pdf, inputs=[pdf_upload], outputs=[upload_status])

    path_button = gr.Button(value="Get PDF Path")
    pdf_path_display = gr.Textbox(label="Current PDF Path")

    # Chat area: conversation display, message box and a reset button.
    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button(value="Clear")

    # Pressing Enter in the message box updates both the box and the chat log.
    msg.submit(fn=respond, inputs=[msg, chatbot_interface], outputs=[msg, chatbot_interface])
    clear.click(fn=clear_chatbot, outputs=[chatbot_interface])
    path_button.click(fn=get_pdf_path, outputs=[pdf_path_display])

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()