Update app.py
app.py
CHANGED
@@ -4,112 +4,62 @@ from langchain.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.chains.base import Chain
 from langchain.chat_models import ChatOpenAI
-from langchain.chains import ConversationalRetrievalChain, LLMChain
+from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
 
-openai_api_key = os.environ.get("OPENAI_API_KEY")
-
 class AdvancedPdfChatbot:
     def __init__(self, openai_api_key):
         os.environ["OPENAI_API_KEY"] = openai_api_key
         self.embeddings = OpenAIEmbeddings()
         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
         self.llm = ChatOpenAI(temperature=0, model_name='gpt-4')
-        self.refinement_llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')
 
         self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-        self.overall_chain = None
         self.db = None
-
-        self.refinement_prompt = PromptTemplate(
-            input_variables=['query', 'chat_history'],
-            template="""Given the user's query and the conversation history, refine the query to be more specific and detailed.
-            If the query is too vague, make reasonable assumptions based on the conversation context.
-            Output the refined query."""
-        )
+        self.chain = None
 
         self.template = """
-        You are a study partner assistant
+        You are a study partner assistant helping students analyze PDF documents.
 
-        Answer the question based on the most recent provided resources
-
+        Answer the question based only on the most recent provided resources.
+        Provide the most relevant and concise answer possible.
 
         Context: {context}
        Question: {question}
         Answer:
-        (Note: YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
         """
-        self.
+        self.qa_prompt = PromptTemplate(
+            template=self.template,
+            input_variables=["context", "question"]
+        )
 
     def load_and_process_pdf(self, pdf_path):
         loader = PyPDFLoader(pdf_path)
         documents = loader.load()
         texts = self.text_splitter.split_documents(documents)
         self.db = FAISS.from_documents(texts, self.embeddings)
-        self.setup_conversation_chain()
-
-    def setup_conversation_chain(self):
-        if not self.db:
-            raise ValueError("Database not initialized. Please upload a PDF first.")
-
-        refinement_chain = LLMChain(
-            llm=self.refinement_llm,
-            prompt=self.refinement_prompt
-        )
 
-        qa_chain = ConversationalRetrievalChain.from_llm(
-            self.llm,
+        self.chain = ConversationalRetrievalChain.from_llm(
+            llm=self.llm,
             retriever=self.db.as_retriever(),
             memory=self.memory,
-            combine_docs_chain_kwargs={"prompt": self.
+            combine_docs_chain_kwargs={"prompt": self.qa_prompt}
         )
-
-        self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
-
-    class CustomChain(Chain):
-        def __init__(self, refinement_chain, qa_chain):
-            super().__init__()
-            self._refinement_chain = refinement_chain
-            self._qa_chain = qa_chain
-
-        @property
-        def input_keys(self):
-            return ["query", "chat_history"]
-
-        @property
-        def output_keys(self):
-            return ["answer"]
-
-        def _call(self, inputs):
-            query = inputs['query']
-            chat_history = inputs.get('chat_history', [])
-
-            refinement_inputs = {'query': query, 'chat_history': chat_history}
-            refined_query = self._refinement_chain.run(refinement_inputs)
-
-            qa_inputs = {"question": refined_query, "chat_history": chat_history}
-            response = self._qa_chain(qa_inputs)
-
-            return {"answer": response['answer']}
 
     def chat(self, query):
-        if not self.overall_chain:
+        if not self.chain:
             return "Please upload a PDF first."
-
-        result = self.
+
+        result = self.chain({"question": query})
         return result['answer']
 
-    def get_pdf_path(self):
-        if self.db:
-            return self.db.path
-        else:
-            return "No PDF uploaded yet."
+    def clear_memory(self):
+        self.memory.clear()
 
-#
-pdf_chatbot = AdvancedPdfChatbot(openai_api_key)
+# Gradio interface setup remains mostly the same
+pdf_chatbot = AdvancedPdfChatbot(os.environ.get("OPENAI_API_KEY"))
 
 def upload_pdf(pdf_file):
     if pdf_file is None:
@@ -132,13 +82,10 @@ def respond(message, history):
         return f"Error: {str(e)}", history
 
 def clear_chatbot():
-    pdf_chatbot.
+    pdf_chatbot.clear_memory()
     return []
 
-
-    return pdf_chatbot.get_pdf_path()
-
-# Create the Gradio interface
+# Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# PDF Chatbot")
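The change above collapses the custom two-step CustomChain (a gpt-3.5-turbo query-refinement pass followed by retrieval QA) into a single ConversationalRetrievalChain, which already condenses follow-up questions into standalone queries using the chat history. Below is a minimal sketch of the resulting pipeline, assuming the same legacy langchain 0.0.x imports used in app.py; the PDF path and the questions are placeholders, not values from this commit.

import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Build the same pipeline the refactored class assembles in load_and_process_pdf().
docs = PyPDFLoader("sample.pdf").load()  # placeholder path
texts = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)
db = FAISS.from_documents(texts, OpenAIEmbeddings())

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
chain = ConversationalRetrievalChain.from_llm(
    llm=ChatOpenAI(temperature=0, model_name="gpt-4"),
    retriever=db.as_retriever(),
    memory=memory,
)

# The chain condenses follow-ups into standalone questions internally,
# so the separate refinement LLM and CustomChain are no longer needed.
print(chain({"question": "What is the document about?"})["answer"])
print(chain({"question": "Summarize that in one sentence."})["answer"])
memory.clear()  # equivalent to the new clear_memory() method

Dropping the refinement chain also saves one LLM round-trip per message, since the question-condensing step built into ConversationalRetrievalChain covers the same need.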
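Both hunks reference Gradio handlers (upload_pdf, respond, clear_chatbot) whose wiring sits outside the changed lines. The following is a hypothetical sketch of how those handlers are typically connected in the gr.Blocks layout; the component names and labels, and the body of upload_pdf beyond its first guard, are assumptions rather than code from this commit. It assumes the pdf_chatbot instance defined in app.py.

import gradio as gr

def upload_pdf(pdf_file):
    if pdf_file is None:
        return "Please upload a PDF file."
    # gr.File exposes a temp-file path via .name (assumed handler body)
    pdf_chatbot.load_and_process_pdf(pdf_file.name)
    return "PDF processed successfully."

def respond(message, history):
    try:
        answer = pdf_chatbot.chat(message)
        history.append((message, answer))
        return "", history
    except Exception as e:
        return f"Error: {str(e)}", history  # matches the return shown in the second hunk

def clear_chatbot():
    pdf_chatbot.clear_memory()  # as added in this commit
    return []

with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot")
    pdf_input = gr.File(label="Upload your PDF")  # assumed component
    upload_btn = gr.Button("Process PDF")         # assumed component
    status = gr.Textbox(label="Status")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Ask a question")
    clear_btn = gr.Button("Clear chat")

    upload_btn.click(upload_pdf, inputs=pdf_input, outputs=status)
    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear_btn.click(clear_chatbot, outputs=chatbot)

demo.launch()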