File size: 5,287 Bytes
d8c3a88
d2e3c7f
 
 
 
 
6c5c0ad
6a6fbcd
d2e3c7f
d43bb1b
633ac28
6a6fbcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2e3c7f
 
 
 
 
8af0aff
355b657
d2e3c7f
6a6fbcd
3f31c68
ccff99d
3f31c68
ccff99d
6a6fbcd
 
 
 
 
 
 
ccff99d
 
5e8e8f0
d2e3c7f
 
 
 
f74eb2e
873a6e6
ccff99d
 
6a6fbcd
b840efb
ccff99d
b840efb
a261843
d2e3c7f
ccff99d
d2e3c7f
ccff99d
6a6fbcd
 
 
 
d2e3c7f
 
6a6fbcd
 
 
 
 
 
 
 
 
 
ccff99d
 
ff0e62c
6a6fbcd
ccff99d
d2e3c7f
 
 
 
8af0aff
 
 
 
 
 
d2e3c7f
 
8af0aff
 
 
 
 
 
 
 
d2e3c7f
f74eb2e
ccff99d
f74eb2e
 
6a6fbcd
d2e3c7f
6a6fbcd
f74eb2e
d2e3c7f
 
 
5e8e8f0
d2e3c7f
 
6a6fbcd
d2e3c7f
6a6fbcd
d2e3c7f
6a6fbcd
 
 
5e8e8f0
d2e3c7f
91326a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

class QueryRefiner:
    """Rewrite a user question into a clearer query before retrieval.

    Wraps a ``gpt-3.5-turbo`` chat model (temperature 0.2) in an ``LLMChain``
    driven by a fixed refinement prompt that takes the raw query and a short
    description of the document.
    """

    def __init__(self):
        """Build the refinement model, prompt and chain."""
        self.refinement_llm = ChatOpenAI(temperature=0.2, model_name='gpt-3.5-turbo')
        self.refinement_prompt = PromptTemplate(
            input_variables=['query', 'context'],
            template="""Refine and enhance the following query for maximum clarity and precision:

Original Query: {query}
Document Context: {context}

Enhanced Query Requirements:
- Clarify any ambiguous terms
- Add specific context-driven details
- Ensure precise information retrieval
- Restructure for optimal comprehension

Refined Query:"""
        )
        self.refinement_chain = LLMChain(
            llm=self.refinement_llm,
            prompt=self.refinement_prompt
        )

    def refine_query(self, original_query, context_hints=''):
        """Return an LLM-refined version of ``original_query``.

        Refinement is best-effort: the original query is returned unchanged
        when the chain raises or when it produces only whitespace.

        Args:
            original_query: The question as typed by the user.
            context_hints: Optional description of the document; when empty,
                "General academic document" is sent instead.

        Returns:
            The stripped refined query, or ``original_query`` as a fallback.
        """
        try:
            refined_query = self.refinement_chain.run({
                'query': original_query,
                'context': context_hints or "General academic document"
            }).strip()
        except Exception as e:
            print(f"Query refinement error: {e}")
            return original_query
        # A blank completion would otherwise be forwarded as an empty
        # question, so fall back to what the user actually asked.
        return refined_query or original_query

class AdvancedPdfChatbot:
    """Conversational question answering over a single loaded PDF.

    A PDF is split into overlapping chunks, embedded with ``OpenAIEmbeddings``
    and indexed in a FAISS store.  Each question is first rewritten by
    ``QueryRefiner`` and then answered by a GPT-4
    ``ConversationalRetrievalChain`` that shares one conversation memory.
    """

    def __init__(self, openai_api_key):
        """Create the models, splitter, memory and prompt; no PDF is loaded.

        Args:
            openai_api_key: Key to export as ``OPENAI_API_KEY``.  When it is
                ``None`` or empty the environment is left untouched.
        """
        # os.environ only accepts strings: assigning None (which is what
        # os.environ.get returns for an unset variable) raises TypeError.
        if openai_api_key:
            os.environ["OPENAI_API_KEY"] = openai_api_key
        self.embeddings = OpenAIEmbeddings()
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.llm = ChatOpenAI(temperature=0, model_name='gpt-4')

        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.query_refiner = QueryRefiner()
        # Both stay None until load_and_process_pdf() succeeds.
        self.db = None
        self.chain = None

        self.qa_prompt = PromptTemplate(
            template="""You are an expert academic assistant analyzing a document.

Context: {context}
Question: {question}

Provide a comprehensive, precise answer based strictly on the document's content.
If the answer isn't directly available, explain why.""",
            input_variables=["context", "question"]
        )

    def load_and_process_pdf(self, pdf_path):
        """Load a PDF, index its chunks and (re)build the retrieval chain.

        Args:
            pdf_path: Path of the PDF file to load.

        Raises:
            ValueError: If splitting the PDF yields no text chunks.
        """
        documents = PyPDFLoader(pdf_path).load()
        texts = self.text_splitter.split_documents(documents)
        if not texts:
            # Fail with a readable message before handing FAISS an empty
            # list; the previously loaded document (if any) stays active.
            raise ValueError("No extractable text found in the PDF.")
        self.db = FAISS.from_documents(texts, self.embeddings)

        self.chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.db.as_retriever(search_kwargs={"k": 3}),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.qa_prompt}
        )

    def chat(self, query):
        """Answer ``query`` against the loaded PDF.

        Args:
            query: The user's question.

        Returns:
            The chain's answer, or a prompt to upload a PDF when none has
            been processed yet.
        """
        if not self.chain:
            return "Please upload a PDF first."

        context_hints = self._extract_document_type()
        refined_query = self.query_refiner.refine_query(query, context_hints)

        result = self.chain({"question": refined_query})
        return result['answer']

    def _extract_document_type(self):
        """Return a short hint describing the document for query refinement.

        Returns:
            ``""`` when no document is indexed, a preview built from the first
            100 characters of the first stored chunk, or a generic label when
            that chunk cannot be read.
        """
        if not self.db:
            return ""
        try:
            # Only the first stored chunk is needed, so avoid copying the
            # whole docstore into a list.
            first_doc = next(iter(self.db.docstore._dict.values()))
            return f"Document appears to cover: {first_doc.page_content[:100]}..."
        except Exception:
            # Best-effort: an empty or differently shaped docstore must not
            # break chat(), but Ctrl-C / SystemExit should still propagate.
            return "Academic/technical document"

    def clear_memory(self):
        """Forget the conversation history kept for the retrieval chain."""
        self.memory.clear()

# Gradio Interface
# Module-level chatbot instance shared by the callbacks below; the API key is
# read from the OPENAI_API_KEY environment variable (None when it is unset).
pdf_chatbot = AdvancedPdfChatbot(os.environ.get("OPENAI_API_KEY"))

def upload_pdf(pdf_file):
    """Process an uploaded PDF and report the outcome as a status string.

    Args:
        pdf_file: ``None``, an object exposing a ``name`` attribute holding
            the file path, or the path itself.

    Returns:
        A success message containing the path, an error message when
        processing raises, or a prompt to upload a file when nothing was given.
    """
    if pdf_file is None:
        return "Please upload a PDF file."

    # Use the object's ``name`` when it has one, otherwise the value itself.
    file_path = getattr(pdf_file, 'name', pdf_file)
    try:
        pdf_chatbot.load_and_process_pdf(file_path)
        status = f"PDF processed successfully: {file_path}"
    except Exception as e:
        status = f"Error processing PDF: {str(e)}"
    return status

def respond(message, history):
    """Answer one submitted message and return the updated chat state.

    Args:
        message: Text submitted from the query textbox.
        history: Chat history as a list of ``(user, bot)`` pairs; ``None`` is
            tolerated and treated as an empty history.

    Returns:
        A ``(textbox_value, history)`` tuple.  The textbox value is always
        ``""`` so the input box is cleared after every submit.
    """
    if history is None:
        history = []
    if not message:
        return "", history
    try:
        bot_message = pdf_chatbot.chat(message)
    except Exception as e:
        # Show the failure as the bot's reply in the chat instead of writing
        # it into the user's input box, where it would replace the question.
        bot_message = f"Error: {str(e)}"
    history.append((message, bot_message))
    return "", history

def clear_chatbot():
    """Clear the chatbot's conversation memory and return an empty history.

    Returns:
        An empty list, used as the new value of the chat display.
    """
    pdf_chatbot.clear_memory()
    empty_history = []
    return empty_history

# Gradio UI
# Components are created in display order inside the Blocks context, so the
# statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Advanced PDF Chatbot")
    
    # File picker and its "process" button share one row.
    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_button = gr.Button("Process PDF")

    # Shows the status string returned by upload_pdf.
    upload_status = gr.Textbox(label="Upload Status")
    upload_button.click(upload_pdf, inputs=[pdf_upload], outputs=[upload_status])
    
    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox(placeholder="Enter your query...")
    # respond returns (textbox value, history), matching the two outputs.
    msg.submit(respond, inputs=[msg, chatbot_interface], outputs=[msg, chatbot_interface])
    
    # clear_chatbot clears the conversation memory and returns an empty history.
    clear_button = gr.Button("Clear Conversation")
    clear_button.click(clear_chatbot, outputs=[chatbot_interface])

# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()