QnA / src /control /control.py
YvesP's picture
changes in the word doc + minor edits
f2681b8
raw
history blame
6.73 kB
import re

import pandas as pd

from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block
class Controller:
    """Answer user questions with retrieval-augmented generation.

    For a conversation, the controller retrieves candidate source blocks, keeps the
    most relevant ones, optionally enriches them with table-based "special" content
    (per-country rates), asks the LLM for an answer and finally tidies the answer text.
    """

    def __init__(self, retriever: "Retriever", llm: "LlmAgent", plan_language: str, content_language: str,
                 specials: dict):
        """
        :param retriever: object providing ``similarity_search(query=...)``.
        :param llm: object providing ``translate``, ``generate_paragraph`` and ``generate_answer``.
        :param plan_language: when ``'en'``, the history queries are translated with the LLM
            before retrieval.
        :param content_language: language passed to ``generate_paragraph``; when ``'en'`` the
            paragraph is post-processed with ``generate_answer``.
        :param specials: resources read by the special block expansions: the
            ``'countries_extensions'`` alias mapping plus, per special table, a DataFrame and
            the "known"/"unknown" text fragments.
        """
        self.plan_language = plan_language
        self.content_language = content_language
        self.retriever = retriever
        self.specials = specials
        self.llm = llm

    def get_response(self, query_fr: str, histo_fr: "list[tuple[str, str]]") -> "tuple[str, list[Block]]":
        """Produce an answer for the conversation together with the source blocks used.

        :param query_fr: the user's current question; it is only forwarded to
            ``generate_answer`` (when ``content_language == 'en'``).
        :param histo_fr: ``(query, answer)`` pairs; only the last five are used and their
            queries drive retrieval and block expansion.
        :return: ``(cleaned answer, selected source blocks)``.
        """
        # NOTE(review): retrieval uses only the queries found in ``histo_fr``, never
        # ``query_fr`` itself; this presumably relies on the caller having already appended
        # the current turn to the history - confirm.
        histo_conversation, histo_queries = self._get_histo(histo_fr)
        queries = self.llm.translate(text=histo_queries) if self.plan_language == 'en' else histo_queries

        block_sources = self.retriever.similarity_search(query=queries)
        block_sources = self._select_best_sources(block_sources)
        # NOTE(review): the expansion appends to ``block.content`` in place; if the retriever
        # hands out cached Block instances the extra text would accumulate across calls - confirm.
        for block in block_sources:
            self._expand_block_with_specials(block, histo_queries)

        context = '\n'.join(source.content for source in block_sources)
        answer = self.llm.generate_paragraph(query=queries, histo=histo_conversation, context=context,
                                             language=self.content_language)
        if self.content_language == 'en':
            # Only the two best sources are given (via ``content_fr``) to this second pass.
            context_fr = '\n'.join(source.content_fr for source in block_sources[:2])
            answer = self.llm.generate_answer(answer_en=answer, query=query_fr,
                                              histo_fr=histo_conversation, context_fr=context_fr)
        return self._clean_answer(answer), block_sources

    @staticmethod
    def _get_histo(histo: "list[tuple[str, str]]") -> "tuple[str, str]":
        """Format the last five turns of ``histo``.

        :return: ``(conversation, queries)``. ``conversation`` renders each turn as
            ``user: <query>``, a newline, then `` bot: <answer>``, with turns separated by a
            newline (no trailing newline). ``queries`` lists every query, each followed by a
            newline.
        """
        recent_turns = histo[-5:]
        histo_conversation = '\n'.join(f'user: {query} \n bot: {answer}' for query, answer in recent_turns)
        histo_queries = ''.join(query + '\n' for query, _ in recent_turns)
        return histo_conversation, histo_queries

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """Tidy an LLM answer for display.

        Repeatedly removes surrounding whitespace, quotes and backticks together with a
        leading ``bot:`` speaker tag (case-insensitive; the history sent to the LLM uses that
        tag), then ends a non-empty answer with a period unless it already ends with
        ``.``, ``!`` or ``?``.
        """
        # Remove the tag as a *prefix*: ``str.strip('bot:')`` would strip any of the
        # characters b/o/t/: from both ends and truncate answers such as "... about it".
        noise = " \t\r\n'\"`"
        tag = 'bot:'
        previous = None
        while answer != previous:  # loop until stable: quotes and the tag may be nested
            previous = answer
            answer = answer.strip(noise)
            if answer[:len(tag)].lower() == tag:
                answer = answer[len(tag):]
        if answer and answer[-1] not in '.!?':
            answer += '.'
        return answer

    @staticmethod
    def _select_best_sources(sources: "list[Block]", delta_1_2=0.15, delta_1_n=0.3, absolute=1.2,
                             alpha=0.9) -> "list[Block]":
        """Select the best sources: not far from the very best, not far from the last selected,
        and not too bad per se.

        ``sources`` is presumably sorted by increasing ``distance`` (TODO confirm against the
        retriever). The first source is always kept. Each following source is kept when it is
        within ``delta_1_2`` of its predecessor and within ``delta_1_n`` of the first source, or
        when its own distance is below ``absolute``. After every kept source the three
        thresholds are multiplied by ``alpha`` (tightening them when ``alpha < 1``). Selection
        stops at the first rejected source.
        """
        best_sources = []
        for idx, source in enumerate(sources):
            keep = (idx == 0
                    or (source.distance - sources[idx - 1].distance < delta_1_2
                        and source.distance - sources[0].distance < delta_1_n)
                    or source.distance < absolute)
            if not keep:
                break
            best_sources.append(source)
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    def _expand_block_with_specials(self, block: "Block", query: str) -> "Block":
        """
        Performs special treatments for blocks expanding the text in the block.
        For example, it may add specific content extracted from a table based on elements of the query.

        Every name in ``block.specials`` that has a handler ('RemotenessRateTable',
        'AccommodationMealTable') is applied in order; other names are ignored.
        ``block.content`` is extended in place and the block is returned.
        """
        routing_table = {
            'RemotenessRateTable': self._expand_remote_rate,
            'AccommodationMealTable': self._expand_accommodation_meal,
        }
        for special in block.specials:
            expand = routing_table.get(special)
            if expand is not None:
                block = expand(block, query)
        return block

    def _expand_remote_rate(self, block: "Block", query: str) -> "Block":
        """Append the remoteness rate of every country mentioned in ``query`` (or the
        'unknown' text when none is mentioned) to ``block.content``."""
        remote_rate_df = self.specials['remote_rate_df']
        remote_rate_known = self.specials['remote_rate_known']
        remote_rate_unknown = self.specials['remote_rate_unknown']
        countries_of_interest = self._mentioned_countries(remote_rate_df, query)
        for country in countries_of_interest:
            rate = remote_rate_df[remote_rate_df['pays'] == country]['rate'].values[0]
            # f-string rather than '+' so numeric table cells do not raise TypeError.
            block.content += f"{remote_rate_known}{country} is {rate}\n"
        if not countries_of_interest:
            block.content += remote_rate_unknown
        return block

    def _expand_accommodation_meal(self, block: "Block", query: str) -> "Block":
        """Append the meal and accommodation amounts of every country mentioned in ``query``
        (or the 'unknown' text when none is mentioned) to ``block.content``."""
        accommodation_meal_df = self.specials['accommodation_meal_df']
        accommodation_meal_known = self.specials['accommodation_meal_known']
        accommodation_meal_unknown = self.specials['accommodation_meal_unknown']
        countries_of_interest = self._mentioned_countries(accommodation_meal_df, query)
        for country in countries_of_interest:
            rows = accommodation_meal_df[accommodation_meal_df['pays'] == country]
            meal, accommodation = rows[['meal', 'accommodation']].values[0]
            block.content += (f"{accommodation_meal_known}{country} is {meal} for meals and "
                              f"{accommodation} for accommodation\n")
        if not countries_of_interest:
            block.content += accommodation_meal_unknown
        return block

    def _mentioned_countries(self, df: pd.DataFrame, query: str) -> list:
        """Return the ``df['pays']`` values (in table order) having an alias mentioned in ``query``."""
        countries_names = self._country_aliases(df)
        return [country for country in df['pays'] if self._any_alias_in(countries_names[country], query)]

    def _country_aliases(self, df: pd.DataFrame) -> "dict[str, list[str]]":
        """Map each ``df['pays']`` value to its spellings: itself, the ``df['country']`` value on
        the same row, plus any extra spellings from ``specials['countries_extensions']``
        (e.g. Etats-Unis = USA = Etats Unis)."""
        countries_names = {c_fr: [c_fr, c_en] for c_fr, c_en in zip(df['pays'], df['country'])}
        countries_extensions = self.specials['countries_extensions']
        for country in set(countries_extensions).intersection(countries_names):
            countries_names[country].extend(countries_extensions[country])
        return countries_names

    @staticmethod
    def _any_alias_in(names: "list[str]", text: str) -> bool:
        """Tell whether any non-empty string in ``names`` occurs in ``text`` as a whole word.

        Matching is case-sensitive. A name must not be glued to other word characters, so
        'Niger' does not match inside 'Nigeria'. Lookarounds are used instead of ``\\b`` so
        that names starting or ending with punctuation still match. Empty or non-string
        aliases (e.g. missing table cells) are ignored.
        """
        for name in names:
            if isinstance(name, str) and name:
                if re.search(r'(?<!\w)' + re.escape(name) + r'(?!\w)', text):
                    return True
        return False