File size: 5,712 Bytes
7fea1f4
de09bee
7fea1f4
 
 
de09bee
 
7fea1f4
de09bee
7fea1f4
 
 
 
 
 
de09bee
7fea1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de09bee
7fea1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de09bee
7fea1f4
 
 
 
 
de09bee
7fea1f4
 
 
 
 
de09bee
7fea1f4
 
 
 
 
 
 
 
 
 
 
de09bee
7fea1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pandas as pd

from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block


class Controller:
    """
    Answers a query from retrieved blocks: searches the retriever, keeps the best-scoring blocks,
    optionally expands them with table data held in `specials`, and asks the LLM agent to write the answer.
    """

    def __init__(self, retriever: "Retriever", llm: "LlmAgent", plan_language: str, content_language: str,
                 specials: dict):
        """
        :param retriever: object exposing `similarity_search(query=...)`
        :param llm: object exposing `translate`, `generate_paragraph` and `generate_answer`
        :param plan_language: when equal to 'en', the incoming query is translated before the search
        :param content_language: language passed to the LLM; when 'en', a final answer is generated
                                 from the English paragraph and the `content_fr` of the best blocks
        :param specials: resources used by `_expand_block_with_specials` (dataframes, sentence prefixes,
                         country name extensions)
        """
        self.plan_language = plan_language
        self.content_language = content_language
        self.retriever = retriever
        self.specials = specials
        self.llm = llm

    def get_response(self, query_fr: str) -> "tuple[str, list[Block]]":
        """
        Builds the answer to `query_fr` and returns it together with the blocks used as sources.

        :param query_fr: the user query (presumably in French, as the name suggests - confirm with callers)
        :return: (answer, selected source blocks); the blocks have been expanded in place
        """
        query = self.llm.translate(text=query_fr) if self.plan_language == 'en' else query_fr
        block_sources = self._select_best_sources(self.retriever.similarity_search(query=query))
        for block in block_sources:
            # the expansion mutates block.content in place; it is matched against the untranslated query
            self._expand_block_with_specials(block, query_fr)
        context = '\n'.join(s.content for s in block_sources)
        answer = self.llm.generate_paragraph(query=query, context=context, language=self.content_language)
        if self.content_language == 'en':
            # only the two best blocks feed the final answer; built here because it is unused otherwise
            context_fr = '\n'.join(s.content_fr for s in block_sources[:2])
            answer = self.llm.generate_answer(answer_en=answer, query=query_fr, context_fr=context_fr)
        # str.strip treats its argument as a set of characters: this removes surrounding whitespace,
        # then surrounding single quotes, then surrounding backticks
        answer = answer.strip().strip("'''").strip("```")
        return answer, block_sources

    @staticmethod
    def _select_best_sources(sources: "list[Block]", delta_1_2=0.1, delta_1_n=0.25, absolute=1.1,
                             alpha=0.85) -> "list[Block]":
        """
        Select the leading run of sources that are good enough; stops at the first rejected source.

        The first source is always kept. A following source is kept when its distance is close both to
        the previous source (gap < delta_1_2) and to the first one (gap < delta_1_n), or when its
        distance is below `absolute`. After each kept source the three thresholds are multiplied by
        `alpha`, so the criteria tighten as the selection grows.

        NOTE(review): the absolute threshold is combined with `or`, i.e. it is an alternative way in,
        not an additional requirement - confirm this is the intended rule.

        :param sources: candidate blocks exposing a `distance` attribute (assumed sorted by increasing
                        distance - TODO confirm against the retriever)
        :return: the kept sources, in their original order
        """
        best_sources = []
        for idx, s in enumerate(sources):
            if idx > 0:
                # sources[idx - 1] is always the last kept source, since the loop stops at the first rejection
                close_to_previous = s.distance - sources[idx - 1].distance < delta_1_2
                close_to_best = s.distance - sources[0].distance < delta_1_n
                if not ((close_to_previous and close_to_best) or s.distance < absolute):
                    break
            best_sources.append(s)
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    def _expand_block_with_specials(self, block: "Block", query: str) -> "Block":
        """
        Performs special treatments for blocks expanding the text in the block
        For example, it may add specific content extracted from a table based on elements of the query

        Each entry of `block.specials` that names a known treatment appends text to `block.content`;
        unknown entries are ignored.

        :param block: block to expand; modified in place
        :param query: text in which country names are searched (case-sensitive substring match)
        :return: the same block
        """

        def mentions_any(names, text: str) -> bool:
            """
            checks if any of the names occurs in the text; missing (non-string) and empty names are
            skipped, so a hole in a table can neither crash the lookup nor match every query
            """
            return any(isinstance(name, str) and name != '' and name in text for name in names)

        def get_countries_names(df: pd.DataFrame) -> dict:
            """
            maps each country of the 'pays' column to all its known spellings: the 'pays' value, the
            'country' value, and the extensions configured for it (ex. Etats-Unis = USA = Etats Unis, etc.)
            """
            countries_names = {c_fr: [c_fr, c_en] for c_fr, c_en in zip(df['pays'], df['country'])}
            countries_extensions = self.specials['countries_extensions']
            for c in countries_names.keys() & countries_extensions.keys():
                countries_names[c].extend(countries_extensions[c])
            return countries_names

        def countries_in_query(df: pd.DataFrame) -> list:
            """
            lists the countries of the table (in table order) for which a spelling appears in the query
            """
            countries_names = get_countries_names(df)
            return [c for c in df['pays'] if mentions_any(countries_names[c], query)]

        def remote_rate_fn() -> None:
            df = self.specials['remote_rate_df']
            known = self.specials['remote_rate_known']
            unknown = self.specials['remote_rate_unknown']
            countries_of_interest = countries_in_query(df)
            for c in countries_of_interest:
                rate = df[df['pays'] == c]['rate'].values[0]
                # f-string rather than '+': the cell may hold a number instead of a string
                block.content += f"{known}{c} is {rate}\n"
            if not countries_of_interest:
                block.content += unknown

        def accommodation_meal_fn() -> None:
            df = self.specials['accommodation_meal_df']
            known = self.specials['accommodation_meal_known']
            unknown = self.specials['accommodation_meal_unknown']
            countries_of_interest = countries_in_query(df)
            for c in countries_of_interest:
                meal, accommodation = df[df['pays'] == c][['meal', 'accommodation']].values[0]
                # f-string rather than '+': the cells may hold numbers instead of strings
                block.content += f"{known}{c} is {meal} for meals and {accommodation} for accommodation\n"
            if not countries_of_interest:
                block.content += unknown

        routing_table = {'RemotenessRateTable': remote_rate_fn,
                         'AccommodationMealTable': accommodation_meal_fn, }
        for special in block.specials:
            expand = routing_table.get(special)
            if expand is not None:
                expand()
        return block