QnA / src /control /control.py
YvesP's picture
added file management
7fea1f4
raw
history blame
5.71 kB
import pandas as pd
from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block
class Controller:
    """
    Answer a French query from retrieved document blocks with the help of an LLM.

    ``get_response`` runs the pipeline: optional translation of the query, similarity
    search, pruning of the hits (``_select_best_sources``), expansion of "special" blocks
    with table lookups (``_expand_block_with_specials``), then answer generation.
    """

    def __init__(self, retriever: "Retriever", llm: "LlmAgent", plan_language: str, content_language: str,
                 specials: dict):
        """
        :param retriever: provides ``similarity_search(query=...)`` over document blocks
        :param llm: provides ``translate``, ``generate_paragraph`` and ``generate_answer``
        :param plan_language: when 'en', the query is translated with ``llm.translate`` before
            retrieval and paragraph generation
        :param content_language: language passed to ``llm.generate_paragraph``; when 'en', the
            paragraph is then rewritten by ``llm.generate_answer`` using the French query and the
            French content of the two best sources
        :param specials: tables and message fragments used to expand special blocks. Keys read:
            'countries_extensions', 'remote_rate_df', 'remote_rate_known', 'remote_rate_unknown',
            'accommodation_meal_df', 'accommodation_meal_known', 'accommodation_meal_unknown'
        """
        self.plan_language = plan_language
        self.content_language = content_language
        self.retriever = retriever
        self.specials = specials
        self.llm = llm

    def get_response(self, query_fr: str) -> "tuple[str, list[Block]]":
        """
        Answer a query written in French.

        :param query_fr: the user's question, in French
        :return: ``(answer, block_sources)`` -- the cleaned answer text and the selected source
            blocks (special blocks among them have been expanded in place)
        """
        query = self.llm.translate(text=query_fr) if self.plan_language == 'en' else query_fr
        block_sources = self._select_best_sources(self.retriever.similarity_search(query=query))
        for block in block_sources:
            # Country names are looked up in the original (untranslated) query.
            self._expand_block_with_specials(block, query_fr)
        context = '\n'.join(block.content for block in block_sources)
        answer = self.llm.generate_paragraph(query=query, context=context, language=self.content_language)
        if self.content_language == 'en':
            # Only the French content of the two best-ranked sources is handed to the rewrite step.
            context_fr = '\n'.join(block.content_fr for block in block_sources[:2])
            answer = self.llm.generate_answer(answer_en=answer, query=query_fr, context_fr=context_fr)
        return self._clean_answer(answer), block_sources

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """Strip surrounding whitespace and any quote/backtick fence wrapped around the LLM answer."""
        # str.strip takes a *set* of characters: the middle call removes any run of ' and `
        # at both ends; the final strip drops whitespace that sat inside the fence
        # (e.g. "```\nanswer\n```" -> "answer").
        return answer.strip().strip("'`").strip()

    @staticmethod
    def _select_best_sources(sources: "list[Block]", delta_1_2=0.1, delta_1_n=0.25, absolute=1.1,
                             alpha=0.85) -> "list[Block]":
        """
        Keep the leading sources that are close to the best one and to the previously kept
        one, or good enough per se.

        ``sources`` is expected to be sorted by increasing ``distance`` (presumably the order
        returned by the retriever -- confirm). The first source is always kept. Each following
        source is kept if (its gap to the previous source < ``delta_1_2`` and its gap to the
        first source < ``delta_1_n``) or its distance < ``absolute``. After every kept source
        (including the first) all three thresholds are multiplied by ``alpha``, so each further
        source must be closer to be kept. Selection stops at the first rejected source.

        :return: the kept prefix of ``sources`` (empty when ``sources`` is empty)
        """
        best_sources = []
        for idx, source in enumerate(sources):
            keep = (
                idx == 0
                or (source.distance - sources[idx - 1].distance < delta_1_2
                    and source.distance - sources[0].distance < delta_1_n)
                or source.distance < absolute
            )
            if not keep:
                break
            best_sources.append(source)
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    def _expand_block_with_specials(self, block: "Block", query: str) -> "Block":
        """
        Append extra text to ``block.content`` for each tag listed in ``block.specials``.

        Handled tags: 'RemotenessRateTable' and 'AccommodationMealTable'; other tags are ignored.
        For example, a rate extracted from a table for the countries mentioned in ``query``.
        The block is modified in place and also returned.
        """
        # NOTE(review): blocks are mutated in place; if the retriever returns cached Block
        # instances, expansions would accumulate across queries -- confirm it returns copies.
        handlers = {
            'RemotenessRateTable': self._expand_remote_rate,
            'AccommodationMealTable': self._expand_accommodation_meal,
        }
        for special in block.specials:
            handler = handlers.get(special)
            if handler is not None:
                block = handler(block, query)
        return block

    def _get_countries_names(self, df: pd.DataFrame) -> "dict[str, list[str]]":
        """
        Map each French name of ``df['pays']`` to its known spellings: itself, the matching
        ``df['country']`` value, and any variants listed in ``self.specials['countries_extensions']``
        (ex. Etats-Unis = USA = Etats Unis).
        """
        countries_names = {c_fr: [c_fr, c_en] for c_fr, c_en in zip(df['pays'], df['country'])}
        extensions = self.specials['countries_extensions']
        for c in set(extensions).intersection(countries_names):
            countries_names[c].extend(extensions[c])
        return countries_names

    def _countries_in_query(self, df: pd.DataFrame, query: str) -> "list[str]":
        """
        Return the names of ``df['pays']`` (table order) for which at least one known spelling
        occurs as a substring of ``query``. Matching is case-sensitive.

        Spellings that are not non-empty strings are ignored: ``'' in query`` is always true
        (it would match every query) and a NaN cell would make ``in`` raise TypeError.
        """
        countries_names = self._get_countries_names(df)
        return [c for c in df['pays']
                if any(isinstance(name, str) and name and name in query for name in countries_names[c])]

    def _expand_remote_rate(self, block: "Block", query: str) -> "Block":
        """
        Append the remoteness rate of every country found in ``query`` to ``block.content``,
        or the 'remote_rate_unknown' text when no country is found.
        """
        df = self.specials['remote_rate_df']
        known = self.specials['remote_rate_known']
        unknown = self.specials['remote_rate_unknown']
        countries = self._countries_in_query(df, query)
        for c in countries:
            rate = df.loc[df['pays'] == c, 'rate'].iloc[0]
            # str(): a non-string cell (e.g. a numeric rate) would make the concatenation raise.
            block.content += known + c + " is " + str(rate) + '\n'
        if not countries:
            block.content += unknown
        return block

    def _expand_accommodation_meal(self, block: "Block", query: str) -> "Block":
        """
        Append the meal and accommodation rates of every country found in ``query`` to
        ``block.content``, or the 'accommodation_meal_unknown' text when no country is found.
        """
        df = self.specials['accommodation_meal_df']
        known = self.specials['accommodation_meal_known']
        unknown = self.specials['accommodation_meal_unknown']
        countries = self._countries_in_query(df, query)
        for c in countries:
            meal, accommodation = df.loc[df['pays'] == c, ['meal', 'accommodation']].values[0]
            # str(): the table cells may be numeric.
            block.content += (known + c + " is " + str(meal) + ' for meals and '
                              + str(accommodation) + ' for accommodation\n')
        if not countries:
            block.content += unknown
        return block