Update app.py
app.py CHANGED
@@ -4,14 +4,11 @@ from langchain.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.chains import ConversationalRetrievalChain
+from langchain.chains import ConversationalRetrievalChain, Chain
 from langchain.chat_models import ChatOpenAI
 from langchain.memory import ConversationBufferMemory
-
 from langchain.prompts import PromptTemplate

-
-
 openai_api_key = os.environ.get("OPENAI_API_KEY")

 class AdvancedPdfChatbot:
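A note on the new import line: `LLMChain` is used later in this commit but never imported, and in classic (0.0.x) LangChain releases the `Chain` base class is exported from `langchain.chains.base` rather than from `langchain.chains`, so the added import may fail depending on the installed version. A minimal sketch of an import block that should resolve under those assumptions; `import os` is presumed to already sit in the unshown first lines of the file, since `os.environ` is used below:

```python
import os  # presumed already present in the unshown lines 1-3 of app.py

from langchain.chains import ConversationalRetrievalChain, LLMChain  # LLMChain added: used below but not imported in the diff
from langchain.chains.base import Chain  # base-class location in classic LangChain
```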
@@ -19,33 +16,30 @@ class AdvancedPdfChatbot:
         os.environ["OPENAI_API_KEY"] = openai_api_key
         self.embeddings = OpenAIEmbeddings()
         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-        self.llm =
+        self.llm = ChatOpenAI(temperature=0, model_name='gpt-4')  # Corrected model name
+        self.refinement_llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')

         self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-        self.
-        self.
+        self.overall_chain = None
+        self.db = None
+
+        self.refinement_prompt = PromptTemplate(
+            input_variables=['query', 'chat_history'],
+            template="""Given the user's query and the conversation history, refine the query to be more specific and detailed.
+            If the query is too vague, make reasonable assumptions based on the conversation context.
+            Output the refined query."""
+        )
+
         self.template = """
-        You are a study partner assistant, students give you pdfs
-        and you help them to answer their questions.
+        You are a study partner assistant, students give you pdfs and you help them to answer their questions.

         Answer the question based on the most recent provided resources only.
         Give the most relevant answer.
-        Instructions:
-
-        Use given source for Context: Generate responses using only the provided content.
-        Cite Sources: Reference content using [page: paragraph] or [page: line] format.
-        Address Multiple Subjects: If the query relates to multiple subjects with the same name, provide distinct responses for each.
-        Relevance Only: Exclude irrelevant or outlier information.
-        Keep it Concise: Provide clear, direct, and descriptive answers, answer in great details when needed and keep short responses when needed.
-        No Guesswork: Do not generate information beyond the given content.
-        No Match: If no relevant content is found, reply with: "No relevant information found.
-        Add comprehensive details and break down the responses into parts whenever needed.

         Context: {context}
         Question: {question}
         Answer:
-
-        (Note :YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
+        (Note: YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
         """
         self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])

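One catch with the refinement prompt added above: classic `PromptTemplate` validates the template by default (`validate_template=True` in 0.0.x releases), and this template declares `query` and `chat_history` as input variables without interpolating either, so construction should fail. A sketch that threads the declared variables through; the exact wording is illustrative only:

```python
from langchain.prompts import PromptTemplate

# Sketch only: placeholders added so the declared input_variables actually
# appear in the template, which classic PromptTemplate validation requires.
refinement_prompt = PromptTemplate(
    input_variables=["query", "chat_history"],
    template="""Given the user's query and the conversation history, refine the query
to be more specific and detailed. If the query is too vague, make reasonable
assumptions based on the conversation context. Output only the refined query.

Chat history:
{chat_history}

Query: {query}

Refined query:""",
)
```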
@@ -54,27 +48,64 @@ class AdvancedPdfChatbot:
         documents = loader.load()
         texts = self.text_splitter.split_documents(documents)
         self.db = FAISS.from_documents(texts, self.embeddings)
-        self.pdf_path = pdf_path
         self.setup_conversation_chain()

     def setup_conversation_chain(self):
-
-
+        class CustomChain(Chain):
+            refinement_chain: Chain
+            qa_chain: Chain
+
+            @classmethod
+            def from_llms(cls, refinement_llm, qa_llm, retriever, memory, prompt):
+                refinement_chain = Chain(
+                    llm_chain=LLMChain(
+                        llm=refinement_llm,
+                        prompt=self.refinement_prompt,
+                        output_key='refined_query'
+                    )
+                )
+                qa_chain = ConversationalRetrievalChain.from_llm(
+                    qa_llm,
+                    retriever=retriever,
+                    memory=memory,
+                    combine_docs_chain_kwargs={"prompt": prompt}
+                )
+                return cls(refinement_chain=refinement_chain, qa_chain=qa_chain)
+
+            def _call(self, inputs):
+                query = inputs['query']
+                chat_history = inputs.get('chat_history', [])
+                refined_query = self.refinement_chain.run(query=query, chat_history=chat_history)
+                response = self.qa_chain({"question": refined_query, "chat_history": chat_history})
+                self.qa_chain.memory.save_context({"input": query}, {"output": response['answer']})
+                return {"answer": response['answer']}
+
+            @property
+            def input_keys(self):
+                return ['query', 'chat_history']
+
+            @property
+            def output_keys(self):
+                return ['answer']
+
+        self.overall_chain = CustomChain.from_llms(
+            refinement_llm=self.refinement_llm,
+            qa_llm=self.llm,
             retriever=self.db.as_retriever(),
             memory=self.memory,
-
+            prompt=self.prompt
         )

     def chat(self, query):
-        if not self.
+        if not self.overall_chain:
             return "Please upload a PDF first."
-
+        chat_history = self.memory.load_memory_variables({})['chat_history']
+        result = self.overall_chain({'query': query, 'chat_history': chat_history})
        return result['answer']

     def get_pdf_path(self):
-
-
-        return self.pdf_path
+        if self.db:
+            return self.db.path
         else:
             return "No PDF uploaded yet."

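The `CustomChain` added above has three construction-time problems: `Chain(llm_chain=...)` instantiates the abstract base class, `self.refinement_prompt` is referenced inside a `@classmethod` where no `self` exists, and `LLMChain` is never imported. At call time, `_call` also saves the turn to memory even though `qa_chain` was built with that same memory and already records each exchange, so every turn would be stored twice. A sketch of a version that should construct, assuming classic LangChain; the `refinement_prompt` parameter is an addition not present in the diff:

```python
from typing import Any, Dict, List

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.base import Chain


class CustomChain(Chain):
    """Refine the user's query, then answer it over the retriever."""

    refinement_chain: LLMChain
    qa_chain: ConversationalRetrievalChain

    @classmethod
    def from_llms(cls, refinement_llm, qa_llm, retriever, memory, prompt,
                  refinement_prompt):
        # The refinement step is a plain LLMChain; Chain itself is abstract
        # and cannot be instantiated the way the diff attempts.
        refinement_chain = LLMChain(
            llm=refinement_llm,
            prompt=refinement_prompt,  # passed in; a classmethod has no `self`
            output_key="refined_query",
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            qa_llm,
            retriever=retriever,
            memory=memory,
            combine_docs_chain_kwargs={"prompt": prompt},
        )
        return cls(refinement_chain=refinement_chain, qa_chain=qa_chain)

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        refined_query = self.refinement_chain.run(
            query=inputs["query"],
            chat_history=inputs.get("chat_history", []),
        )
        # qa_chain loads and saves chat_history through its own memory, so no
        # extra save_context call is needed here.
        response = self.qa_chain({"question": refined_query})
        return {"answer": response["answer"]}

    @property
    def input_keys(self) -> List[str]:
        return ["query", "chat_history"]

    @property
    def output_keys(self) -> List[str]:
        return ["answer"]
```

One behavioral difference in this sketch: the shared memory then records the refined question rather than the raw one. Separately, `get_pdf_path` now returns `self.db.path`, but the FAISS vectorstore does not expose a `path` attribute, and this commit deletes the `self.pdf_path = pdf_path` assignment the method previously relied on; keeping that assignment and returning `self.pdf_path` looks like the simpler fix.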
@@ -98,7 +129,6 @@ def clear_chatbot():
     return []

 def get_pdf_path():
-    # Call the method to return the current PDF path
     return pdf_chatbot.get_pdf_path()

 # Create the Gradio interface
@@ -122,4 +152,4 @@ with gr.Blocks() as demo:
     path_button.click(get_pdf_path, outputs=[pdf_path_display])

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
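For completeness, a minimal smoke test of the revised class under the fixes sketched above. Neither the constructor signature nor the PDF-loading method name is visible in these hunks, so the single-argument constructor, `setup_qa_chain`, and `notes.pdf` are all assumptions:

```python
import os

from app import AdvancedPdfChatbot

# Assumed constructor and loader names; neither signature is shown in the diff.
chatbot = AdvancedPdfChatbot(os.environ["OPENAI_API_KEY"])
print(chatbot.chat("What is this PDF about?"))  # expected: "Please upload a PDF first."

chatbot.setup_qa_chain("notes.pdf")             # hypothetical method name and path
print(chatbot.chat("Summarize the introduction."))
```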