from typing import Any, Dict, List, Tuple

import pandas as pd

from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block


class Controller:
    """
    Orchestrates one chatbot turn: retrieves source blocks, enriches them with
    "special" table content, and asks the LLM to write the answer.
    """

    def __init__(self, retriever: Retriever, llm: LlmAgent, plan_language: str, content_language: str,
                 specials: Dict[str, Any]):
        """
        :param retriever: object exposing similarity_search(query=...)
        :param llm: object exposing translate, generate_paragraph and generate_answer
        :param plan_language: when 'en', the queries are translated before retrieval
        :param content_language: when 'en', a second LLM pass produces the final answer
        :param specials: resources read by the special block expansions
                         (keys such as 'remote_rate_df', 'countries_extensions', ...)
        """
        self.plan_language = plan_language
        self.content_language = content_language
        self.retriever = retriever
        self.specials = specials
        self.llm = llm

    def get_response(self, query_fr: str, histo_fr: List[Tuple[str, str]]) -> Tuple[str, List[Block]]:
        """
        Builds the answer to the current conversation.

        :param query_fr: the current user query; only forwarded to llm.generate_answer
        :param histo_fr: list of (query, answer) pairs; only the last 5 are used
        :return: the cleaned answer and the blocks selected as sources
        """
        histo_conversation, histo_queries = self._get_histo(histo_fr)
        # NOTE(review): retrieval and generation rely on the history queries only; query_fr is not
        # added here, so presumably the caller already includes it in histo_fr -- confirm.
        if self.plan_language == 'en':
            queries = self.llm.translate(text=histo_queries)
        else:
            queries = histo_queries

        block_sources = self.retriever.similarity_search(query=queries)
        block_sources = self._select_best_sources(block_sources)
        for block in block_sources:
            # The expansion mutates block.content in place; it is matched against the
            # untranslated queries.
            self._expand_block_with_specials(block, histo_queries)

        context = '\n'.join(s.content for s in block_sources)
        answer = self.llm.generate_paragraph(query=queries, histo=histo_conversation, context=context,
                                             language=self.content_language)

        if self.content_language == 'en':
            # Second pass: only the two best sources are given as context.
            context_fr = '\n'.join(s.content_fr for s in block_sources[:2])
            answer = self.llm.generate_answer(answer_en=answer, query=query_fr,
                                              histo_fr=histo_conversation, context_fr=context_fr)

        answer = self._clean_answer(answer)
        return answer, block_sources

    @staticmethod
    def _get_histo(histo: List[Tuple[str, str]]) -> Tuple[str, str]:
        """
        Formats the last 5 exchanges of the history.

        :return: (conversation transcript without trailing newline,
                  user queries, each one followed by a newline)
        """
        recent = histo[-5:]
        histo_conversation = '\n'.join(f'user: {query} \n bot: {answer}' for query, answer in recent)
        histo_queries = ''.join(query + '\n' for query, _ in recent)
        return histo_conversation, histo_queries

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """
        Removes surrounding quotes, backticks and whitespace, drops a leading 'bot:' marker
        and makes sure a non-empty answer ends with a period.

        The marker is removed as a prefix: str.strip('bot:') would instead remove any of the
        characters b, o, t, ':' at both ends and eat into the answer ("Hello" -> "Hell").
        """
        wrappers = "'\"` \t\r\n"
        prefix = 'bot:'
        answer = answer.strip(wrappers)
        # The marker may sit inside or outside the quotes, hence the re-strip after each removal.
        while answer.startswith(prefix):
            answer = answer[len(prefix):].strip(wrappers)
        if answer and not answer.endswith('.'):
            answer += '.'
        return answer

    @staticmethod
    def _select_best_sources(sources: List[Block], delta_1_2=0.15, delta_1_n=0.3, absolute=1.2,
                             alpha=0.9) -> List[Block]:
        """
        Selects the leading sources of the list and stops at the first rejected one.

        The first source is always kept. A following source is kept when it is either
        - closer than delta_1_2 to the previous source AND closer than delta_1_n to the first one, or
        - below the absolute distance threshold.
        After each kept source the three thresholds are multiplied by alpha.

        NOTE(review): the previous description read as if the absolute threshold were an additional
        requirement ("and not too bad per se"); the code treats it as an alternative -- confirm intent.
        Presumably sources are sorted by increasing distance -- verify against the retriever.
        """
        best_sources = []
        for idx, source in enumerate(sources):
            if idx > 0:
                near_previous = source.distance - sources[idx - 1].distance < delta_1_2
                near_best = source.distance - sources[0].distance < delta_1_n
                if not ((near_previous and near_best) or source.distance < absolute):
                    break
            best_sources.append(source)
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    def _expand_block_with_specials(self, block: Block, query: str) -> Block:
        """
        Performs the special treatments listed in block.specials, expanding the text of the block.
        For example, it may add specific content extracted from a table based on elements of the query.
        Specials without a registered handler are ignored.

        :return: the same block, with its content extended in place
        """
        routing_table = {
            'RemotenessRateTable': self._append_remote_rates,
            'AccommodationMealTable': self._append_accommodation_meal_rates,
        }
        for special in block.specials:
            handler = routing_table.get(special)
            if handler is not None:
                block = handler(block, query)
        return block

    def _get_countries_names(self, df: pd.DataFrame) -> Dict[str, List[str]]:
        """
        Maps each country of df['pays'] to its accepted spellings: the 'pays' value, the 'country'
        value and the extra spellings from specials['countries_extensions']
        (ex. Etats-Unis = USA = Etats Unis, etc.)
        """
        countries_names = {c_fr: [c_fr, c_en] for c_fr, c_en in zip(df['pays'], df['country'])}
        countries_extensions = self.specials['countries_extensions']
        for country in set(countries_extensions.keys()).intersection(countries_names):
            countries_names[country] = countries_names[country] + list(countries_extensions[country])
        return countries_names

    def _get_countries_of_interest(self, df: pd.DataFrame, query: str) -> List[str]:
        """
        Lists the countries of df['pays'] for which one spelling appears in the query.
        The match is a case-sensitive substring test, so a short name also matches inside
        a longer word.
        """
        countries_names = self._get_countries_names(df)
        return [c for c in df['pays'] if any(name in query for name in countries_names[c])]

    def _append_remote_rates(self, block: Block, query: str) -> Block:
        """
        Appends to the block the 'rate' value of every country found in the query,
        or the 'remote_rate_unknown' text when no country is found.
        """
        remote_rate_df = self.specials['remote_rate_df']
        remote_rate_known = self.specials['remote_rate_known']
        remote_rate_unknown = self.specials['remote_rate_unknown']
        countries_of_interest = self._get_countries_of_interest(remote_rate_df, query)
        for c in countries_of_interest:
            rate = remote_rate_df.loc[remote_rate_df['pays'] == c, 'rate'].values[0]
            # f-string: the rate may be stored as a number, which cannot be concatenated to a str
            block.content += f"{remote_rate_known}{c} is {rate}\n"
        if not countries_of_interest:
            block.content += remote_rate_unknown
        return block

    def _append_accommodation_meal_rates(self, block: Block, query: str) -> Block:
        """
        Appends to the block the 'meal' and 'accommodation' values of every country found in the
        query, or the 'accommodation_meal_unknown' text when no country is found.
        """
        accommodation_meal_df = self.specials['accommodation_meal_df']
        accommodation_meal_known = self.specials['accommodation_meal_known']
        accommodation_meal_unknown = self.specials['accommodation_meal_unknown']
        countries_of_interest = self._get_countries_of_interest(accommodation_meal_df, query)
        for c in countries_of_interest:
            rows = accommodation_meal_df.loc[accommodation_meal_df['pays'] == c, ['meal', 'accommodation']]
            meal, accommodation = rows.values[0]
            # f-string: the values may be stored as numbers, which cannot be concatenated to a str
            block.content += (f"{accommodation_meal_known}{c} is {meal} for meals and "
                              f"{accommodation} for accommodation\n")
        if not countries_of_interest:
            block.content += accommodation_meal_unknown
        return block