Pavan178 committed
Commit ccff99d · verified · 1 Parent(s): 91326a4

Update app.py

Files changed (1): app.py (+21, -74)

app.py CHANGED
@@ -4,112 +4,62 @@ from langchain.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.chains.base import Chain
 from langchain.chat_models import ChatOpenAI
-from langchain.chains import LLMChain, ConversationalRetrievalChain
+from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
 
-openai_api_key = os.environ.get("OPENAI_API_KEY")
-
 class AdvancedPdfChatbot:
     def __init__(self, openai_api_key):
         os.environ["OPENAI_API_KEY"] = openai_api_key
         self.embeddings = OpenAIEmbeddings()
         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
         self.llm = ChatOpenAI(temperature=0, model_name='gpt-4')
-        self.refinement_llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')
 
         self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-        self.overall_chain = None
         self.db = None
+        self.chain = None
-
-        self.refinement_prompt = PromptTemplate(
-            input_variables=['query', 'chat_history'],
-            template="""Given the user's query and the conversation history, refine the query to be more specific and detailed.
-            If the query is too vague, make reasonable assumptions based on the conversation context.
-            Output the refined query."""
-        )
 
         self.template = """
-        You are a study partner assistant, students give you pdfs and you help them to answer their questions.
+        You are a study partner assistant helping students analyze PDF documents.
 
-        Answer the question based on the most recent provided resources only.
-        Give the most relevant answer.
+        Answer the question based only on the most recent provided resources.
+        Provide the most relevant and concise answer possible.
 
         Context: {context}
         Question: {question}
         Answer:
-        (Note: YOUR OUTPUT IS RENDERED IN PROPER PARAGRAPHS or BULLET POINTS when needed, modify the response formats as needed, only choose the formats based on the type of question asked)
         """
-        self.prompt = PromptTemplate(template=self.template, input_variables=["context", "question"])
+        self.qa_prompt = PromptTemplate(
+            template=self.template,
+            input_variables=["context", "question"]
+        )
 
     def load_and_process_pdf(self, pdf_path):
         loader = PyPDFLoader(pdf_path)
         documents = loader.load()
         texts = self.text_splitter.split_documents(documents)
         self.db = FAISS.from_documents(texts, self.embeddings)
-        self.setup_conversation_chain()
-
-    def setup_conversation_chain(self):
-        if not self.db:
-            raise ValueError("Database not initialized. Please upload a PDF first.")
-
-        refinement_chain = LLMChain(
-            llm=self.refinement_llm,
-            prompt=self.refinement_prompt
-        )
 
-        qa_chain = ConversationalRetrievalChain.from_llm(
-            self.llm,
+        self.chain = ConversationalRetrievalChain.from_llm(
+            llm=self.llm,
             retriever=self.db.as_retriever(),
             memory=self.memory,
-            combine_docs_chain_kwargs={"prompt": self.prompt}
+            combine_docs_chain_kwargs={"prompt": self.qa_prompt}
         )
-
-        self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
-
-    class CustomChain(Chain):
-        def __init__(self, refinement_chain, qa_chain):
-            super().__init__()
-            self._refinement_chain = refinement_chain
-            self._qa_chain = qa_chain
-
-        @property
-        def input_keys(self):
-            return ["query", "chat_history"]
-
-        @property
-        def output_keys(self):
-            return ["answer"]
-
-        def _call(self, inputs):
-            query = inputs['query']
-            chat_history = inputs.get('chat_history', [])
-
-            refinement_inputs = {'query': query, 'chat_history': chat_history}
-            refined_query = self._refinement_chain.run(refinement_inputs)
-
-            qa_inputs = {"question": refined_query, "chat_history": chat_history}
-            response = self._qa_chain(qa_inputs)
-
-            return {"answer": response['answer']}
 
     def chat(self, query):
-        if not self.overall_chain:
+        if not self.chain:
             return "Please upload a PDF first."
-        chat_history = self.memory.load_memory_variables({})['chat_history']
-        result = self.overall_chain({'query': query, 'chat_history': chat_history})
+
+        result = self.chain({"question": query})
         return result['answer']
 
-    def get_pdf_path(self):
-        if self.db:
-            return self.db.path
-        else:
-            return "No PDF uploaded yet."
+    def clear_memory(self):
+        self.memory.clear()
 
-# Initialize the chatbot
-pdf_chatbot = AdvancedPdfChatbot(openai_api_key)
+# Gradio interface setup remains mostly the same
+pdf_chatbot = AdvancedPdfChatbot(os.environ.get("OPENAI_API_KEY"))
 
 def upload_pdf(pdf_file):
     if pdf_file is None:
@@ -132,13 +82,10 @@ def respond(message, history):
         return f"Error: {str(e)}", history
 
 def clear_chatbot():
-    pdf_chatbot.memory.clear()
+    pdf_chatbot.clear_memory()
    return []
 
-def get_pdf_path():
-    return pdf_chatbot.get_pdf_path()
-
-# Create the Gradio interface
+# Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# PDF Chatbot")
 
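
The dropped gpt-3.5-turbo refinement step is largely redundant: `ConversationalRetrievalChain` already condenses each follow-up question plus the chat history into a standalone question before retrieval. If a cheaper model is still wanted for that rewrite, here is a sketch assuming the installed legacy `langchain` exposes the `condense_question_llm` parameter of `from_llm`:

```python
import os

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

# Assumes OPENAI_API_KEY is already set in the environment.
db = FAISS.from_texts(["example passage"], OpenAIEmbeddings())  # toy index
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

chain = ConversationalRetrievalChain.from_llm(
    llm=ChatOpenAI(temperature=0, model_name="gpt-4"),  # answers the question
    # The cheaper model only rewrites follow-ups into standalone questions:
    condense_question_llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"),
    retriever=db.as_retriever(),
    memory=memory,
)
print(chain({"question": "What does the passage say?"})["answer"])
```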