File size: 6,731 Bytes
7fea1f4 de09bee 7fea1f4 de09bee 7fea1f4 de09bee 7fea1f4 de09bee cfd8c2d 1e64f5f 7fea1f4 1e64f5f 7fea1f4 1e64f5f 7fea1f4 f53f6c3 7fea1f4 de09bee 7fea1f4 f53f6c3 1e64f5f f53f6c3 f2681b8 f53f6c3 cfd8c2d 7fea1f4 de09bee 7fea1f4 de09bee 7fea1f4 de09bee 7fea1f4 de09bee 7fea1f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import pandas as pd
from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block
class Controller:
    """
    Answers a user query from retrieved document blocks and an LLM.

    The controller ties together a retriever (``similarity_search``), an LLM agent
    (``translate``, ``generate_paragraph``, ``generate_answer``) and a dictionary of
    "specials": extra resources (tables and sentence templates) used to enrich the
    retrieved blocks before they are handed to the LLM as context.
    """

    def __init__(self, retriever: "Retriever", llm: "LlmAgent", plan_language: str, content_language: str,
                 specials: dict):
        """
        :param retriever: object exposing ``similarity_search(query=...)``
        :param llm: object exposing ``translate``, ``generate_paragraph`` and ``generate_answer``
        :param plan_language: when equal to 'en', the queries are translated before retrieval
        :param content_language: language passed to ``generate_paragraph``; when equal to 'en',
            the generated answer goes through ``generate_answer`` afterwards
        :param specials: resources read by the special block treatments (see
            ``_expand_block_with_specials``)
        """
        self.plan_language = plan_language
        self.content_language = content_language
        self.retriever = retriever
        self.specials = specials
        self.llm = llm

    def get_response(self, query_fr: str, histo_fr: "list[tuple[str, str]]") -> "tuple[str, list[Block]]":
        """
        Builds the answer to a conversation and returns it with the blocks used as sources.

        :param query_fr: the current user query; only used by the final ``generate_answer`` step
        :param histo_fr: (query, answer) pairs of the conversation; the retrieval query is built
            from the queries of the last five pairs
            NOTE(review): query_fr is not added to the retrieval query here, so callers presumably
            include the current query in histo_fr - to be confirmed against the caller
        :return: the cleaned answer and the selected source blocks (their content is expanded
            in place by the special treatments)
        """
        histo_conversation, histo_queries = self._get_histo(histo_fr)
        queries = self.llm.translate(text=histo_queries) if self.plan_language == 'en' else histo_queries
        block_sources = self._select_best_sources(self.retriever.similarity_search(query=queries))
        for block in block_sources:
            # the untranslated queries are used to detect the elements of interest (e.g. countries)
            self._expand_block_with_specials(block, histo_queries)
        context = '\n'.join(s.content for s in block_sources)
        answer = self.llm.generate_paragraph(query=queries, histo=histo_conversation, context=context,
                                             language=self.content_language)
        if self.content_language == 'en':
            # only the two first sources feed the context of this second generation step
            context_fr = '\n'.join(s.content_fr for s in block_sources[:2])
            answer = self.llm.generate_answer(answer_en=answer, query=query_fr,
                                              histo_fr=histo_conversation, context_fr=context_fr)
        return self._clean_answer(answer), block_sources

    @staticmethod
    def _get_histo(histo: "list[tuple[str, str]]") -> "tuple[str, str]":
        """
        Formats the last five exchanges of the conversation.

        :param histo: (query, answer) pairs, oldest first
        :return: the transcript ('user: ... \\n bot: ...' per exchange, joined by newlines, without
            trailing newline) and the queries alone, each one followed by a newline
        """
        recent = histo[-5:]
        histo_conversation = '\n'.join(f'user: {query} \n bot: {answer}' for query, answer in recent)
        histo_queries = ''.join(f'{query}\n' for query, _ in recent)
        return histo_conversation, histo_queries

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """
        Removes the wrapping produced by the LLM around its answer and ensures a final period.

        Leading 'bot:' tags and surrounding quotes, backticks and spaces are removed; a '.' is
        appended to a non-empty answer that does not already end with one.
        """
        wrappers = "'\"` "
        prefix = 'bot:'
        answer = answer.strip(wrappers)
        # str.strip('bot:') would remove any of the characters b, o, t and ':' from both ends
        # (e.g. "It is hot" -> "It is h"), hence the explicit prefix removal
        while answer.startswith(prefix):
            answer = answer[len(prefix):].strip(wrappers)
        if answer and not answer.endswith('.'):
            answer += '.'
        return answer

    @staticmethod
    def _select_best_sources(sources: "list[Block]", delta_1_2=0.15, delta_1_n=0.3, absolute=1.2,
                             alpha=0.9) -> "list[Block]":
        """
        Selects the leading sources that are good enough to be used as context.

        The first source is always kept. A following source is kept when it is close both to the
        previous source (gap below delta_1_2) and to the first one (gap below delta_1_n), or when
        its distance is below absolute. The three thresholds are multiplied by alpha after each
        kept source, and the selection stops at the first rejected source.

        Assumes the sources are sorted by increasing distance - TODO confirm against the retriever.
        NOTE(review): the absolute criterion is an alternative ('or') to the two relative ones; the
        previous docstring suggested the three criteria were cumulative - to be confirmed.
        """
        best_sources = []
        for idx, s in enumerate(sources):
            if idx > 0:
                close_to_previous = s.distance - sources[idx - 1].distance < delta_1_2
                close_to_best = s.distance - sources[0].distance < delta_1_n
                good_per_se = s.distance < absolute
                if not ((close_to_previous and close_to_best) or good_per_se):
                    break
            best_sources.append(s)
            # the criteria get stricter as more sources are selected
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    def _expand_block_with_specials(self, block: "Block", query: str) -> "Block":
        """
        Performs special treatments for blocks expanding the text in the block
        For example, it may add specific content extracted from a table based on elements of the query

        :param block: the block to expand; its content is modified in place for each name of
            ``block.specials`` that has a treatment, other names are ignored
        :param query: the text in which the elements of interest are searched
        :return: the expanded block
        """
        treatments = {
            'RemotenessRateTable': self._add_remote_rates,
            'AccommodationMealTable': self._add_accommodation_meal_rates,
        }
        for special in block.specials:
            treatment = treatments.get(special)
            if treatment is not None:
                block = treatment(block, query)
        return block

    def _get_countries_names(self, df: "pd.DataFrame") -> "dict[str, list[str]]":
        """
        Maps each country of the 'pays' column to its known spellings: the 'pays' value, the
        'country' value and the extensions found in specials['countries_extensions']
        (ex. Etats-Unis = USA = Etats Unis, etc.)
        """
        countries_names = {c_fr: [c_fr, c_en] for c_fr, c_en in zip(df['pays'], df['country'])}
        countries_extensions = self.specials['countries_extensions']
        for c_fr, names in countries_names.items():
            names.extend(countries_extensions.get(c_fr, ()))
        return countries_names

    def _get_countries_of_interest(self, df: "pd.DataFrame", query: str) -> "list[str]":
        """
        Returns the 'pays' values of df having at least one spelling contained in the query
        (case-sensitive substring search)
        """
        countries_names = self._get_countries_names(df)
        return [c for c in df['pays'] if any(name in query for name in countries_names[c])]

    def _add_remote_rates(self, block: "Block", query: str) -> "Block":
        """
        Appends to the block the rate of each country of the query found in
        specials['remote_rate_df'], or the specials['remote_rate_unknown'] text when none is found
        """
        remote_rate_df = self.specials['remote_rate_df']
        remote_rate_known = self.specials['remote_rate_known']
        remote_rate_unknown = self.specials['remote_rate_unknown']
        countries_of_interest = self._get_countries_of_interest(remote_rate_df, query)
        for c in countries_of_interest:
            rate = remote_rate_df[remote_rate_df['pays'] == c]['rate'].values[0]
            # f-string formatting also accepts numeric rates, unlike str concatenation
            block.content += f'{remote_rate_known}{c} is {rate}\n'
        if not countries_of_interest:
            block.content += remote_rate_unknown
        return block

    def _add_accommodation_meal_rates(self, block: "Block", query: str) -> "Block":
        """
        Appends to the block the meal and accommodation rates of each country of the query found in
        specials['accommodation_meal_df'], or the specials['accommodation_meal_unknown'] text when
        none is found
        """
        accommodation_meal_df = self.specials['accommodation_meal_df']
        accommodation_meal_known = self.specials['accommodation_meal_known']
        accommodation_meal_unknown = self.specials['accommodation_meal_unknown']
        countries_of_interest = self._get_countries_of_interest(accommodation_meal_df, query)
        for c in countries_of_interest:
            rows = accommodation_meal_df[accommodation_meal_df['pays'] == c]
            meal, accommodation = rows[['meal', 'accommodation']].values[0]
            # f-string formatting also accepts numeric rates, unlike str concatenation
            block.content += (f'{accommodation_meal_known}{c} is {meal} for meals and '
                              f'{accommodation} for accommodation\n')
        if not countries_of_interest:
            block.content += accommodation_meal_unknown
        return block
|