File size: 6,284 Bytes
d8c3a88
d2e3c7f
 
 
 
 
b840efb
6c5c0ad
b840efb
d2e3c7f
d43bb1b
633ac28
aff3a65
a526ade
d2e3c7f
 
 
 
 
8af0aff
3f31c68
355b657
d2e3c7f
3f31c68
 
 
 
 
 
 
 
 
 
355b657
3f31c68
355b657
 
 
 
 
 
 
3f31c68
355b657
 
5e8e8f0
d2e3c7f
 
 
 
f74eb2e
d2e3c7f
 
 
8af0aff
 
873a6e6
b840efb
 
160264d
b840efb
873a6e6
b840efb
 
 
 
 
 
873a6e6
b840efb
3f31c68
b3de9a3
e224a4b
873a6e6
e224a4b
b3de9a3
 
 
e224a4b
 
160264d
e224a4b
b3de9a3
e224a4b
 
160264d
e224a4b
b3de9a3
e224a4b
 
 
 
160264d
 
b3de9a3
160264d
 
 
b3de9a3
e224a4b
 
a261843
d2e3c7f
3f31c68
d2e3c7f
3f31c68
 
d2e3c7f
 
ff0e62c
3f31c68
 
ff0e62c
 
 
d2e3c7f
 
 
 
 
 
8af0aff
 
 
 
 
 
d2e3c7f
 
8af0aff
 
 
 
 
 
 
 
d2e3c7f
f74eb2e
 
 
 
ff0e62c
 
 
d2e3c7f
 
 
f74eb2e
d2e3c7f
 
 
5e8e8f0
d2e3c7f
 
 
95c7827
d2e3c7f
5e8e8f0
d2e3c7f
8af0aff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.base import Chain
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# API key is read from the environment; it is None when the variable is unset.
openai_api_key = os.getenv("OPENAI_API_KEY")

class AdvancedPdfChatbot:
    """Chat over the contents of an uploaded PDF.

    A PDF is split into chunks, embedded with OpenAI embeddings and indexed in
    a FAISS vector store. Each user query is first rewritten by a cheaper
    "refinement" model using the conversation history, then answered by a
    conversational retrieval chain over the index.
    """

    def __init__(self, openai_api_key):
        """Create the models, splitter, memory and prompts.

        Args:
            openai_api_key: OpenAI API key; exported to the
                ``OPENAI_API_KEY`` environment variable for the OpenAI clients.

        Raises:
            ValueError: If ``openai_api_key`` is empty or ``None``.
        """
        if not openai_api_key:
            # Assigning None into os.environ raises an opaque TypeError;
            # fail early with an actionable message instead.
            raise ValueError(
                "OpenAI API key is missing. Set the OPENAI_API_KEY environment variable."
            )
        os.environ["OPENAI_API_KEY"] = openai_api_key
        self.embeddings = OpenAIEmbeddings()
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.llm = ChatOpenAI(temperature=0, model_name='gpt-4')
        self.refinement_llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')

        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.overall_chain = None
        self.db = None
        # Path of the most recently processed PDF (None until one is loaded).
        self.pdf_path = None

        # The template must reference every declared input variable; otherwise
        # the refinement model never sees the query or the history.
        self.refinement_prompt = PromptTemplate(
            input_variables=['query', 'chat_history'],
            template=(
                "Given the user's query and the conversation history, refine the query "
                "to be more specific and detailed.\n"
                "If the query is too vague, make reasonable assumptions based on the "
                "conversation context.\n"
                "Output the refined query.\n\n"
                "Conversation history:\n{chat_history}\n\n"
                "User query: {query}\n\n"
                "Refined query:"
            ),
        )

        self.template = """
        You are a study partner assistant, students give you pdfs and you help them to answer their questions.
        
        Answer the question based on the most recent provided resources only.
        Give the most relevant answer.
        
        Context: {context}
        Question: {question}
        Answer:
        (Note: YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
        """
        self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])

    def load_and_process_pdf(self, pdf_path):
        """Load a PDF, index its chunks in FAISS and (re)build the chain.

        Args:
            pdf_path: Filesystem path of the PDF to load.

        Raises:
            ValueError: If the PDF yields no text chunks to index.
        """
        documents = PyPDFLoader(pdf_path).load()
        texts = self.text_splitter.split_documents(documents)
        if not texts:
            # Building an index from zero documents fails with an unhelpful
            # error; report the actual problem to the user instead.
            raise ValueError("No extractable text found in the PDF.")
        self.db = FAISS.from_documents(texts, self.embeddings)
        self.pdf_path = pdf_path
        self.setup_conversation_chain()

    def setup_conversation_chain(self):
        """Build the refinement + retrieval chain over the current index.

        Raises:
            ValueError: If no PDF has been indexed yet.
        """
        if self.db is None:
            raise ValueError("Database not initialized. Please upload a PDF first.")

        refinement_chain = LLMChain(
            llm=self.refinement_llm,
            prompt=self.refinement_prompt
        )

        qa_chain = ConversationalRetrievalChain.from_llm(
            self.llm,
            retriever=self.db.as_retriever(),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.prompt}
        )

        self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)

    class CustomChain(Chain):
        """Two-step chain: refine the query, then answer it from the PDF."""

        # Declared as model fields rather than assigned in __init__: Chain is a
        # pydantic model, so setting undeclared attributes on an instance is
        # rejected at construction time.
        refinement_chain: LLMChain
        qa_chain: ConversationalRetrievalChain

        @property
        def input_keys(self):
            """Define the input keys that this chain expects."""
            return ["query", "chat_history"]

        @property
        def output_keys(self):
            """Define the output keys that this chain returns."""
            return ["answer"]

        def _call(self, inputs, run_manager=None):
            """Refine ``inputs['query']`` and answer it with the QA chain.

            Args:
                inputs: Mapping with ``query`` and optionally ``chat_history``.
                run_manager: Optional callback manager supplied by the base
                    class; unused here.

            Returns:
                ``{"answer": <answer text>}``.
            """
            query = inputs['query']
            chat_history = inputs.get('chat_history', [])

            # Step 1: rewrite the query using the conversation so far.
            refined_query = self.refinement_chain.run(
                query=query, chat_history=chat_history
            ).strip()

            # Step 2: answer the refined query against the indexed PDF.
            # NOTE(review): the QA chain was built with its own memory, which
            # presumably records the refined (not the original) question -
            # confirm this is the intended history.
            response = self.qa_chain(
                {"question": refined_query, "chat_history": chat_history}
            )

            return {"answer": response['answer']}

    def chat(self, query):
        """Answer ``query`` about the loaded PDF.

        Returns:
            The answer text, or a prompt to upload a PDF when none is loaded.
        """
        if self.overall_chain is None:
            return "Please upload a PDF first."
        chat_history = self.memory.load_memory_variables({})['chat_history']
        result = self.overall_chain({'query': query, 'chat_history': chat_history})
        return result['answer']

    def get_pdf_path(self):
        """Return the path of the loaded PDF, or a notice if none is loaded."""
        # The vector store does not carry the source path, so the path
        # recorded by load_and_process_pdf is used.
        if self.db is not None and self.pdf_path is not None:
            return self.pdf_path
        return "No PDF uploaded yet."

# Module-level chatbot instance shared by all Gradio callbacks below.
# Constructed at import time from the key read from OPENAI_API_KEY.
pdf_chatbot = AdvancedPdfChatbot(openai_api_key)

def upload_pdf(pdf_file):
    """Index the uploaded PDF and return a status message for the UI.

    ``pdf_file`` may be an object exposing a ``name`` attribute (the path) or
    the path itself. Any processing failure is reported as text rather than
    raised.
    """
    if pdf_file is None:
        return "Please upload a PDF file."

    # Fall back to the value itself when there is no ``name`` attribute.
    file_path = getattr(pdf_file, 'name', pdf_file)
    try:
        pdf_chatbot.load_and_process_pdf(file_path)
        status = f"PDF processed successfully: {file_path}"
    except Exception as err:
        status = f"Error processing PDF: {str(err)}"
    return status

def respond(message, history):
    """Handle a submitted chat message.

    Returns a ``(textbox_value, history)`` pair. On success the textbox is
    cleared and the exchange is appended to ``history``; on failure the error
    text is returned as the textbox value and ``history`` is left as it was.
    Empty messages are ignored.
    """
    if message:
        try:
            reply = pdf_chatbot.chat(message)
            history.append((message, reply))
        except Exception as err:
            return f"Error: {str(err)}", history
    return "", history

def clear_chatbot():
    """Wipe the shared chatbot's conversation memory and return an empty history."""
    conversation_memory = pdf_chatbot.memory
    conversation_memory.clear()
    return list()

def get_pdf_path():
    """Return whatever the shared chatbot reports as its current PDF path."""
    current_path = pdf_chatbot.get_pdf_path()
    return current_path

# Create the Gradio interface: an upload row, a status box and a chat area.
with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot")
    
    # File picker (restricted to .pdf) next to the button that triggers indexing.
    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_button = gr.Button("Process PDF")

    # Shows the message returned by upload_pdf (success or error text).
    upload_status = gr.Textbox(label="Upload Status")
    upload_button.click(upload_pdf, inputs=[pdf_upload], outputs=[upload_status])
    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox()
    # Submitting the textbox calls respond, whose two return values update
    # the textbox and the chat history respectively.
    msg.submit(respond, inputs=[msg, chatbot_interface], outputs=[msg, chatbot_interface])

# Launch the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()