Pavan178 commited on
Commit
a261843
·
verified ·
1 Parent(s): 95c7827

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -24
app.py CHANGED
@@ -69,31 +69,51 @@ class AdvancedPdfChatbot:
69
  self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
70
 
71
  class CustomChain(Chain):
72
- def __init__(self, refinement_chain, qa_chain):
73
- super().__init__()
74
- self.refinement_chain = refinement_chain
75
- self.qa_chain = qa_chain
76
-
77
- @property
78
- def input_keys(self):
79
- return ["query", "chat_history"]
80
-
81
- @property
82
- def output_keys(self):
83
- return ["answer"]
84
-
85
- def _call(self, inputs):
86
- query = inputs['query']
87
- chat_history = inputs.get('chat_history', [])
88
-
89
- # Run the refinement chain to refine the query
90
- refined_query = self.refinement_chain.run({'query': query, 'chat_history': chat_history})
91
-
92
- # Run the QA chain using the refined query and the chat history
93
- response = self.qa_chain.run({"question": refined_query, "chat_history": chat_history})
 
 
 
 
 
 
 
 
 
94
 
95
- # Return the answer
96
- return {"answer": response}
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def chat(self, query):
99
  if not self.overall_chain:
 
69
  self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
70
 
71
  class CustomChain(Chain):
72
+ def __init__(self, refinement_chain, qa_chain):
73
+ super().__init__()
74
+ self.refinement_chain = refinement_chain
75
+ self.qa_chain = qa_chain
76
+
77
+ @property
78
+ def input_keys(self):
79
+ return ["query", "chat_history"]
80
+
81
+ @property
82
+ def output_keys(self):
83
+ return ["answer"]
84
+
85
+ def _call(self, inputs):
86
+ query = inputs['query']
87
+ chat_history = inputs.get('chat_history', [])
88
+
89
+ # Run the refinement chain to refine the query
90
+ refined_query = self.refinement_chain.run({'query': query, 'chat_history': chat_history})
91
+
92
+ # Run the QA chain using the refined query and the chat history
93
+ response = self.qa_chain({'question': refined_query, 'chat_history': chat_history})
94
+
95
+ # Return the answer
96
+ return {"answer": response['answer']}
97
+
98
+
99
+
100
+ def setup_conversation_chain(self):
101
+ if not self.db:
102
+ raise ValueError("Database not initialized. Please upload a PDF first.")
103
 
104
+ refinement_chain = LLMChain(
105
+ llm=self.refinement_llm,
106
+ prompt=self.refinement_prompt,
107
+ output_key='refined_query'
108
+ )
109
+ qa_chain = ConversationalRetrievalChain.from_llm(
110
+ self.llm,
111
+ retriever=self.db.as_retriever(),
112
+ memory=self.memory,
113
+ combine_docs_chain_kwargs={"prompt": self.prompt}
114
+ )
115
+ self.overall_chain = CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
116
+
117
 
118
  def chat(self, query):
119
  if not self.overall_chain: