Spaces:
Runtime error
Runtime error
Update qa_engine/qa_engine.py
Browse files- qa_engine/qa_engine.py +30 -0
qa_engine/qa_engine.py
CHANGED
@@ -227,6 +227,33 @@ class QAEngine():
|
|
227 |
self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
|
228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
232 |
"""
|
@@ -271,6 +298,9 @@ class QAEngine():
|
|
271 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
272 |
|
273 |
logger.info('Running LLM chain')
|
|
|
|
|
|
|
274 |
answer = self.llm_chain.run(question=question, context=context)
|
275 |
response.set_answer(answer)
|
276 |
logger.info('Received answer')
|
|
|
227 |
self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
|
228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
229 |
|
230 |
+
|
231 |
+
@staticmethod
|
232 |
+
def _preprocess_question(question: str) -> str:
|
233 |
+
if question[-1] != '?':
|
234 |
+
question += '?'
|
235 |
+
return question
|
236 |
+
|
237 |
+
|
238 |
+
@staticmethod
|
239 |
+
def _postprocess_answer(answer: str) -> str:
|
240 |
+
'''
|
241 |
+
Preprocess the answer by removing unnecessary sequences and stop sequences.
|
242 |
+
'''
|
243 |
+
REMOVE_SEQUENCES = [
|
244 |
+
'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
|
245 |
+
]
|
246 |
+
STOP_SEQUENCES = [
|
247 |
+
'\nUser:', '\nYou:'
|
248 |
+
]
|
249 |
+
for seq in REMOVE_SEQUENCES:
|
250 |
+
answer = answer.replace(seq, '')
|
251 |
+
for seq in STOP_SEQUENCES:
|
252 |
+
if seq in answer:
|
253 |
+
answer = answer[:answer.index(seq)]
|
254 |
+
answer = answer.strip()
|
255 |
+
return answer
|
256 |
+
|
257 |
|
258 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
259 |
"""
|
|
|
298 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
299 |
|
300 |
logger.info('Running LLM chain')
|
301 |
+
question_processed = QAEngine._preprocess_question(question)
|
302 |
+
answer = self.llm_chain.run(question=question_processed, context=context)
|
303 |
+
answer = QAEngine._postprocess_answer(answer)
|
304 |
answer = self.llm_chain.run(question=question, context=context)
|
305 |
response.set_answer(answer)
|
306 |
logger.info('Received answer')
|