|
from __future__ import annotations

import pandas as pd

from src.model.block import Block
from src.tools.llm import LlmAgent
from src.tools.retriever import Retriever
|
|
|
|
|
class Controller:
    """
    Answers a user query: retrieves source blocks, expands them with special table content,
    asks the LLM for an answer built on those sources and cleans the raw answer.
    """

    # Characters trimmed from both ends of a raw LLM answer.
    _WRAPPING_CHARS = "'\"` "
    # Speaker tag that may be echoed at the start of a raw LLM answer.
    _BOT_PREFIX = 'bot:'

    def __init__(self, retriever: Retriever, llm: LlmAgent, plan_language: str, content_language: str,
                 specials: dict):
        """
        :param retriever: object whose `similarity_search(query=...)` returns the candidate blocks
        :param llm: object providing `translate`, `generate_paragraph` and `generate_answer`
        :param plan_language: when equal to 'en', the queries are translated before retrieval
        :param content_language: language passed to the generation; when equal to 'en', the answer
            goes through an additional `generate_answer` step
        :param specials: resources for the special block treatments; the keys read by this class are
            'countries_extensions', 'remote_rate_df', 'remote_rate_known', 'remote_rate_unknown',
            'accommodation_meal_df', 'accommodation_meal_known' and 'accommodation_meal_unknown'
        """
        self.plan_language = plan_language
        self.content_language = content_language
        self.retriever = retriever
        self.specials = specials
        self.llm = llm

    def get_response(self, query_fr: str, histo_fr: list[tuple[str, str]]) -> tuple[str, list[Block]]:
        """
        Builds the answer to a query together with the blocks used as sources

        :param query_fr: the current query; only used by the final `generate_answer` step
        :param histo_fr: (query, answer) pairs of the conversation, oldest first
        :return: the cleaned answer and the selected source blocks
        """
        histo_conversation, histo_queries = self._get_histo(histo_fr)
        # NOTE(review): retrieval is driven by the queries of histo_fr only; presumably the caller
        # has already appended the current query to histo_fr - confirm against the caller.
        queries = self.llm.translate(text=histo_queries) if self.plan_language == 'en' else histo_queries
        block_sources = self._select_best_sources(self.retriever.similarity_search(query=queries))
        for block in block_sources:
            # the untranslated queries are used to look for special elements (e.g. country names)
            self._expand_block_with_specials(block, histo_queries)
        context = '\n'.join(source.content for source in block_sources)
        answer = self.llm.generate_paragraph(query=queries, histo=histo_conversation, context=context,
                                             language=self.content_language)
        if self.content_language == 'en':
            # only the two best sources feed the context of the final answer
            context_fr = '\n'.join(source.content_fr for source in block_sources[:2])
            answer = self.llm.generate_answer(answer_en=answer, query=query_fr,
                                              histo_fr=histo_conversation, context_fr=context_fr)
        return self._clean_answer(answer), block_sources

    @staticmethod
    def _get_histo(histo: list[tuple[str, str]]) -> tuple[str, str]:
        """
        Formats the five most recent exchanges of the history

        :param histo: (query, answer) pairs, oldest first
        :return: the conversation as 'user: ... \\n bot: ...' lines (no trailing newline) and the
            queries alone, each one followed by a newline
        """
        recent = histo[-5:]
        histo_conversation = '\n'.join(f'user: {query} \n bot: {answer}' for query, answer in recent)
        histo_queries = ''.join(f'{query}\n' for query, _ in recent)
        return histo_conversation, histo_queries

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """
        Removes the wrapping quotes, backticks and spaces and any leading 'bot:' tag from a raw
        answer, then makes sure a non-empty answer ends with a punctuation mark

        :param answer: the raw answer
        :return: the cleaned answer, '' when nothing is left
        """
        wrapping = Controller._WRAPPING_CHARS
        prefix = Controller._BOT_PREFIX
        previous = None
        # The tag may sit inside quotes and quotes may follow the tag, so both removals are
        # repeated until the text is stable. The tag is removed as a prefix: str.strip('bot:')
        # would drop any leading or trailing 'b', 'o', 't' or ':' and eat into the answer.
        while answer != previous:
            previous = answer
            answer = answer.strip(wrapping)
            if answer.startswith(prefix):
                answer = answer[len(prefix):]
        if answer and not answer.endswith(('.', '!', '?')):
            answer += "."
        return answer

    @staticmethod
    def _select_best_sources(sources: list[Block], delta_1_2=0.15, delta_1_n=0.3, absolute=1.2,
                             alpha=0.9) -> list[Block]:
        """
        Selects the leading sources of the list. The first source is always kept; a following source
        is kept when it is both closer than delta_1_2 to the previous source and closer than delta_1_n
        to the first one, or when its distance is below absolute. The selection stops at the first
        rejected source, and the three thresholds are multiplied by alpha after each kept source.

        Assumes the sources are sorted by increasing distance - TODO confirm against the retriever.
        """
        best_sources = []
        for idx, source in enumerate(sources):
            if idx > 0:
                close_to_previous = source.distance - sources[idx - 1].distance < delta_1_2
                close_to_first = source.distance - sources[0].distance < delta_1_n
                good_per_se = source.distance < absolute
                if not ((close_to_previous and close_to_first) or good_per_se):
                    break
            best_sources.append(source)
            # the criteria get stricter as more sources are kept
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    def _expand_block_with_specials(self, block: Block, query: str) -> Block:
        """
        Performs special treatments for blocks expanding the text in the block
        For example, it may add specific content extracted from a table based on elements of the query

        :param block: the block to expand; its content is modified in place
        :param query: the text searched for special elements such as country names
        :return: the expanded block
        """
        handlers = {
            'RemotenessRateTable': self._expand_with_remote_rate,
            'AccommodationMealTable': self._expand_with_accommodation_meal,
        }
        for special in block.specials:
            handler = handlers.get(special)
            # tags without a dedicated treatment leave the block untouched
            if handler is not None:
                block = handler(block, query)
        return block

    def _get_countries_names(self, df: pd.DataFrame) -> dict[str, list[str]]:
        """
        Maps each value of the 'pays' column to its known spellings: the 'pays' value, the 'country'
        value and the extensions configured in specials: ex. Etats-Unis = USA = Etats Unis, etc.
        """
        countries_names = {c_fr: [c_fr, c_en] for c_fr, c_en in zip(df['pays'], df['country'])}
        countries_extensions = self.specials['countries_extensions']
        for country in countries_names.keys() & countries_extensions.keys():
            countries_names[country].extend(countries_extensions[country])
        return countries_names

    def _countries_in_query(self, df: pd.DataFrame, query: str) -> list[str]:
        """
        Lists the 'pays' values of the table having at least one spelling contained in the query
        (plain, case-sensitive substring search)
        """
        countries_names = self._get_countries_names(df)
        return [country for country in df['pays']
                if any(name in query for name in countries_names[country])]

    def _expand_with_remote_rate(self, block: Block, query: str) -> Block:
        """
        Appends to the block the rate of each country found in the query, or the 'unknown' text
        when no country of the table is found
        """
        remote_rate_df = self.specials['remote_rate_df']
        remote_rate_known = self.specials['remote_rate_known']
        remote_rate_unknown = self.specials['remote_rate_unknown']
        countries_of_interest = self._countries_in_query(remote_rate_df, query)
        for country in countries_of_interest:
            rate = remote_rate_df[remote_rate_df['pays'] == country]['rate'].values[0]
            # formatted rather than concatenated so that a numeric column does not raise
            block.content += f"{remote_rate_known}{country} is {rate}\n"
        if not countries_of_interest:
            block.content += remote_rate_unknown
        return block

    def _expand_with_accommodation_meal(self, block: Block, query: str) -> Block:
        """
        Appends to the block the meal and accommodation values of each country found in the query,
        or the 'unknown' text when no country of the table is found
        """
        accommodation_meal_df = self.specials['accommodation_meal_df']
        accommodation_meal_known = self.specials['accommodation_meal_known']
        accommodation_meal_unknown = self.specials['accommodation_meal_unknown']
        countries_of_interest = self._countries_in_query(accommodation_meal_df, query)
        for country in countries_of_interest:
            rows = accommodation_meal_df[accommodation_meal_df['pays'] == country]
            meal, accommodation = rows[['meal', 'accommodation']].values[0]
            # formatted rather than concatenated so that numeric columns do not raise
            block.content += (f"{accommodation_meal_known}{country} is {meal} for meals and "
                              f"{accommodation} for accommodation\n")
        if not countries_of_interest:
            block.content += accommodation_meal_unknown
        return block
|
|