Pavan178 commited on
Commit
160264d
·
verified ·
1 Parent(s): fad38ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -21
app.py CHANGED
@@ -57,8 +57,7 @@ class AdvancedPdfChatbot:
57
 
58
  refinement_chain = LLMChain(
59
  llm=self.refinement_llm,
60
- prompt=self.refinement_prompt,
61
- output_key='refined_query'
62
  )
63
  qa_chain = ConversationalRetrievalChain.from_llm(
64
  self.llm,
@@ -68,20 +67,6 @@ class AdvancedPdfChatbot:
68
  )
69
  self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
70
 
71
- def setup_conversation_chain(self):
72
- refinement_chain = LLMChain(
73
- llm=self.refinement_llm,
74
- prompt=self.refinement_prompt,
75
- output_key='refined_query'
76
- )
77
- qa_chain = ConversationalRetrievalChain.from_llm(
78
- self.llm,
79
- retriever=self.db.as_retriever(),
80
- memory=self.memory,
81
- combine_docs_chain_kwargs={"prompt": self.prompt}
82
- )
83
- self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
84
-
85
  class CustomChain(Chain):
86
  def __init__(self, refinement_chain, qa_chain):
87
  super().__init__()
@@ -90,24 +75,28 @@ class AdvancedPdfChatbot:
90
 
91
  @property
92
  def input_keys(self):
 
93
  return ["query", "chat_history"]
94
 
95
  @property
96
  def output_keys(self):
 
97
  return ["answer"]
98
 
99
  def _call(self, inputs):
100
  query = inputs['query']
101
  chat_history = inputs.get('chat_history', [])
102
 
103
- refined_query = self.refinement_chain.run({'query': query, 'chat_history': chat_history})
104
- response = self.qa_chain({'question': refined_query, 'chat_history': chat_history})
 
 
 
 
 
105
 
106
  return {"answer": response['answer']}
107
 
108
-
109
-
110
-
111
  def chat(self, query):
112
  if not self.overall_chain:
113
  return "Please upload a PDF first."
 
57
 
58
  refinement_chain = LLMChain(
59
  llm=self.refinement_llm,
60
+ prompt=self.refinement_prompt
 
61
  )
62
  qa_chain = ConversationalRetrievalChain.from_llm(
63
  self.llm,
 
67
  )
68
  self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  class CustomChain(Chain):
71
  def __init__(self, refinement_chain, qa_chain):
72
  super().__init__()
 
75
 
76
  @property
77
  def input_keys(self):
78
+ """Define the input keys that this chain expects."""
79
  return ["query", "chat_history"]
80
 
81
  @property
82
  def output_keys(self):
83
+ """Define the output keys that this chain returns."""
84
  return ["answer"]
85
 
86
  def _call(self, inputs):
87
  query = inputs['query']
88
  chat_history = inputs.get('chat_history', [])
89
 
90
+ # Run the refinement chain to refine the query
91
+ refinement_inputs = {'query': query, 'chat_history': chat_history}
92
+ refined_query = self.refinement_chain.run(refinement_inputs)
93
+
94
+ # Run the QA chain using the refined query and the chat history
95
+ qa_inputs = {"question": refined_query, "chat_history": chat_history}
96
+ response = self.qa_chain(qa_inputs)
97
 
98
  return {"answer": response['answer']}
99
 
 
 
 
100
  def chat(self, query):
101
  if not self.overall_chain:
102
  return "Please upload a PDF first."