File size: 12,112 Bytes
05a8b3a
 
 
 
 
 
 
 
 
52bc1cc
05a8b3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52bc1cc
 
 
 
05a8b3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52bc1cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05a8b3a
52bc1cc
 
 
 
 
 
 
 
05a8b3a
 
 
 
 
 
 
 
 
 
 
 
 
52bc1cc
 
05a8b3a
 
 
 
 
 
 
52bc1cc
05a8b3a
 
 
 
 
 
 
 
52bc1cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05a8b3a
52bc1cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05a8b3a
 
52bc1cc
 
05a8b3a
 
16522e2
 
52bc1cc
 
 
05a8b3a
 
 
 
 
 
 
52bc1cc
05a8b3a
 
 
 
52bc1cc
16522e2
 
52bc1cc
 
 
 
 
 
 
 
 
 
 
 
 
16522e2
52bc1cc
 
 
 
 
 
 
 
05a8b3a
 
 
 
52bc1cc
05a8b3a
 
 
 
 
 
 
 
52bc1cc
05a8b3a
 
 
 
 
 
 
 
 
52bc1cc
 
 
 
 
ca02ad4
05a8b3a
52bc1cc
 
 
05a8b3a
28684d8
 
05a8b3a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300


from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
from typing import Literal
from langchain.prompts import ChatPromptTemplate
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser

# OLD QUERY ANALYSIS
    # keywords: List[str] = Field(
    #     description="""
    #     Extract the keywords from the user query to feed a search engine as a list
    #     Maximum 3 keywords

    #     Examples:
    #     - "What is the impact of deep sea mining ?" -> deep sea mining
    #     - "How will El Nino be impacted by climate change" -> el nino;climate change
    #     - "Is climate change a hoax" -> climate change;hoax
    #     """
    # )

    # alternative_queries: List[str] = Field(
    #     description="""
    #     Generate alternative search questions from the user query to feed a search engine
    #     """
    # )

    # step_back_question: str = Field(
    #     description="""
    #     You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.
    #     This questions should help you get more context and information about the user query
    #     """
    # )
    # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific literature
    # 
    
    
    # topics: List[Literal[
    #     "Climate change",
    #     "Biodiversity",
    #     "Energy",
    #     "Decarbonization",
    #     "Climate science",
    #     "Nature",
    #     "Climate policy and justice",
    #     "Oceans",
    #     "Deep sea mining",
    #     "ESG and regulations",
    #     "CSRD",
    # ]] = Field(
    #     ...,
    #     description = """
    #         Choose the topics that are most relevant to the user query, ex: Climate change, Energy, Biodiversity, ...
    #     """,
    # )
    # date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
    # location:Location



# Retrieval index name -> names of the document sources routed to that index.
ROUTING_INDEX = dict(
    IPx=["IPCC", "IPBES", "IPOS"],
    POC=["AcclimaTerra", "PCAET", "Biodiv"],
    OpenAlex=["OpenAlex"],
)

# Every source name from ROUTING_INDEX, flattened into one list in index order.
POSSIBLE_SOURCES = sum(ROUTING_INDEX.values(), [])

# Prompt from the original paper https://arxiv.org/pdf/2305.14283
# Query Rewriting for Retrieval-Augmented Large Language Models
class QueryDecomposition(BaseModel):
    """
    Decompose the user query into smaller parts to think step by step to answer this question
    Act as a simple planning agent
    """

    # This model is converted with convert_to_openai_function in
    # make_query_decomposition_chain and bound to the LLM as a forced function call.
    # NOTE(review): the docstring above and the description below presumably end up in
    # the function schema sent to the model, i.e. they are prompt text — confirm before
    # rewording them.

    # Search-engine questions derived from the user query; the description asks the
    # model for either one reformulated question or a list of at most 2 to 3 questions.
    questions: List[str] = Field(
        description="""
        Think step by step to answer this question, and provide one or several search engine questions in the provided language for knowledge that you need. 
        Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
        - If it's already a standalone and explicit question, just return the reformulated question for the search engine
        - If you need to decompose the question, output a list of maximum 2 to 3 questions
    """
    )


class Location(BaseModel):
    # Geographic reference made of a country and a more specific place.
    # In the visible part of this file it is only referenced by a commented-out
    # `location:Location` field, so no active code here uses it.

    # Country, either stated directly or inferred from the more specific place (required).
    country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
    # Specific place such as a city, region or address (required).
    location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")

class QueryTranslation(BaseModel):
    """Translate the query into a given language"""
    
    # This model is converted with convert_to_openai_function in
    # make_query_translation_chain; the target language itself is supplied through
    # that chain's system prompt ("{language}"), not through this schema.
    # NOTE(review): the docstring and description are presumably sent to the model as
    # part of the function schema — confirm before rewording them.

    # The translated question (or the unchanged question when no translation is needed).
    question : str = Field(
        description="""
        Translate the questions into the given language
        If the question is alrealdy in the given language, just return the same question
        """,
    )
    
    
class QueryAnalysis(BaseModel):
    """
    Analyze the user query to extract the relevant sources
    
    Deprecated:
    Analyzing the user query to extract topics, sources and date
    Also do query expansion to get alternative search queries
    Also provide simple keywords to feed a search engine
    """

    # This model is converted with convert_to_openai_function in
    # make_query_analysis_chain and bound to the LLM as a forced function call.
    # NOTE(review): the docstring above and the description below are presumably sent
    # to the model as part of the function schema — confirm before rewording them.

    # Document sources judged relevant for the question (required). The allowed values
    # are the same names as ROUTING_INDEX["IPx"] followed by ROUTING_INDEX["POC"].
    sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET","Biodiv"]] = Field( # "OpenAlex" is currently left out of the allowed values
        ...,
        description="""
            Given a user question choose which documents would be most relevant for answering their question,
            - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
            - IPBES is for questions about biodiversity and nature
            - IPOS is for questions about the ocean and deep sea mining
            - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
            - PCAET is the Plan Climat Eneregie Territorial for the city of Paris
            - Biodiv is the Biodiversity plan for the city of Paris
        """,
    )



def make_query_decomposition_chain(llm):
    """Build the chain that decomposes a user query into search-engine questions.

    Args:
        llm: Model object exposing ``bind``; it is bound to the
            ``QueryDecomposition`` function schema and forced to call it.

    Returns:
        The composed runnable ``prompt | bound llm | JsonOutputFunctionsParser()``.
        The prompt expects a single ``input`` variable.
    """
    # Expose QueryDecomposition as the one function the model must call.
    decomposition_schema = convert_to_openai_function(QueryDecomposition)
    bound_llm = llm.bind(
        functions=[decomposition_schema],
        function_call={"name": "QueryDecomposition"},
    )

    messages = [
        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
        ("user", "input: {input}"),
    ]
    decomposition_prompt = ChatPromptTemplate.from_messages(messages)

    return decomposition_prompt | bound_llm | JsonOutputFunctionsParser()


def make_query_analysis_chain(llm):
    """Build the chain that picks the relevant document sources for a question.

    Args:
        llm: Model object exposing ``bind``; it is bound to the ``QueryAnalysis``
            function schema and forced to call it.

    Returns:
        The composed runnable ``prompt | bound llm | JsonOutputFunctionsParser()``.
        The prompt expects a single ``input`` variable.
    """
    # Expose QueryAnalysis as the one function the model must call.
    analysis_schema = convert_to_openai_function(QueryAnalysis)
    bound_llm = llm.bind(
        functions=[analysis_schema],
        function_call={"name": "QueryAnalysis"},
    )

    messages = [
        ("system", "You are a helpful assistant, you will analyze the user input message using the function provided"),
        ("user", "input: {input}"),
    ]
    analysis_prompt = ChatPromptTemplate.from_messages(messages)

    return analysis_prompt | bound_llm | JsonOutputFunctionsParser()


def make_query_translation_chain(llm):
    """Build the chain that translates a question into a target language.

    Args:
        llm: Model object exposing ``bind``; it is bound to the
            ``QueryTranslation`` function schema and forced to call it.

    Returns:
        The composed runnable ``prompt | bound llm | JsonOutputFunctionsParser()``.
        The prompt expects two variables: ``input`` (the question) and
        ``language`` (the target language).
    """
    # Expose QueryTranslation as the one function the model must call.
    translation_schema = convert_to_openai_function(QueryTranslation)
    bound_llm = llm.bind(
        functions=[translation_schema],
        function_call={"name": "QueryTranslation"},
    )

    messages = [
        ("system", "You are a helpful assistant, translate the question into {language}"),
        ("user", "input: {input}"),
    ]
    translation_prompt = ChatPromptTemplate.from_messages(messages)

    return translation_prompt | bound_llm | JsonOutputFunctionsParser()

def group_by_sources_types(sources):
    """Group source names by the type of index they belong to.

    Args:
        sources: Iterable of source names (e.g. ``"IPCC"``, ``"PCAET"``).
            Names that belong to neither group are ignored.

    Returns:
        dict: Maps ``"IPx"`` and/or ``"POC"`` to the matching sources. A group
        is only present when at least one of its sources was given. Inside a
        group, sources are de-duplicated and kept in the order in which they
        first appear in ``sources`` (the previous set-based implementation
        returned them in an arbitrary, run-dependent order).
    """
    groups = (
        ("IPx", ("IPCC", "IPBES", "IPOS")),
        ("POC", ("AcclimaTerra", "PCAET", "Biodiv")),
    )

    # dict.fromkeys drops duplicates while preserving first-seen order.
    unique_sources = list(dict.fromkeys(sources))

    sources_types = {}
    for source_type, members in groups:
        selected = [source for source in unique_sources if source in members]
        if selected:
            sources_types[source_type] = selected
    return sources_types


def make_query_transform_node(llm,k_final=15):
    """
    Create the query-transformation node.

    Args:
        llm: The language model used to build the decomposition, analysis and
            translation chains.
        k_final (int, optional): Currently unused by this function (the code that
            split it across questions is disabled); kept so existing callers keep
            working. Defaults to 15.

    Returns:
        function: ``transform_query(state)``, which reads ``state["query"]`` and the
        optional ``state["sources_auto"]`` / ``state["sources_input"]`` entries and
        returns a dict with:
            - ``"questions_list"``: one entry per (question, index) pair, each with
              ``"question"``, ``"sources"``, ``"index"`` and ``"source_type"``.
            - ``"n_questions"``: counts for ``"total"``, ``"IPx"`` and ``"POC"``.
            - ``"handled_questions_index"``: an empty list.

    The returned function performs the following steps:
        1. Decomposes the query into questions with the decomposition chain.
        2. Runs the analysis chain on each question to pick sources, falling back to
           the IPx sources when the model returns nothing or an unknown source.
        3. Splits each question per source type ("IPx" / "POC").
        4. Translates each question into the language of its source type
           (English for "IPx", French for "POC").
        5. Routes each question to the indexes of ROUTING_INDEX that share at least
           one source with it; when ``sources_auto`` is false, ``sources_input``
           replaces the sources chosen by the model.
    """

    decomposition_chain = make_query_decomposition_chain(llm)
    query_analysis_chain = make_query_analysis_chain(llm)
    query_translation_chain = make_query_translation_chain(llm)

    # Sources the analysis step may return. Built from ROUTING_INDEX so the names
    # cannot drift (the previous hard-coded list spelled "IPBES" as "IPBS", which
    # rejected every valid answer containing IPBES).
    valid_sources = set(ROUTING_INDEX["IPx"]) | set(ROUTING_INDEX["POC"])

    # Language in which the documents of each source type are queried.
    translation_languages = {"IPx": "English", "POC": "French"}

    def transform_query(state):
        print("---- Transform query ----")

        auto_mode = state.get("sources_auto", True)
        sources_input = state.get("sources_input", ROUTING_INDEX["IPx"])

        # Decomposition
        decomposition_output = decomposition_chain.invoke({"input":state["query"]})

        # Query analysis: one entry per (question, source type)
        questions = []
        for question in decomposition_output["questions"]:
            query_analysis_output = query_analysis_chain.invoke({"input":question})

            # The llm may return no sources, or sources outside the allowed set:
            # in both cases fall back to the IPx sources.
            sources = (query_analysis_output or {}).get("sources")
            if not sources or not all(source in valid_sources for source in sources):
                sources = list(ROUTING_INDEX["IPx"])

            for source_type, typed_sources in group_by_sources_types(sources).items():
                questions.append({
                    "question":question,
                    "sources":typed_sources,
                    "source_type":source_type,
                })

        # Translate each question into the document language of its source type
        for q in questions:
            language = translation_languages.get(q["source_type"])
            if language is not None:
                translation_output = query_translation_chain.invoke({"input":q["question"],"language":language})
                q["question"] = translation_output["question"]

        # Explode the questions into multiple questions with different sources
        new_questions = []
        for q in questions:
            # If not auto mode we take the configuration instead of the llm choice.
            # NOTE(review): in that case "source_type" (which decided the translation
            # language above) can differ from the "index" the question is routed to.
            wanted_sources = set(q["sources"] if auto_mode else sources_input)

            for index, index_sources in ROUTING_INDEX.items():
                # Keep the order of ROUTING_INDEX so the result is deterministic
                selected_sources = [source for source in index_sources if source in wanted_sources]
                if selected_sources:
                    new_questions.append({
                        "question":q["question"],
                        "sources":selected_sources,
                        "index":index,
                        "source_type":q["source_type"],
                    })

        n_questions = {
            "total":len(new_questions),
            "IPx":sum(1 for q in new_questions if q["index"] == "IPx"),
            "POC":sum(1 for q in new_questions if q["index"] == "POC"),
        }

        new_state = {
            "questions_list":new_questions,
            "n_questions":n_questions,
            "handled_questions_index":[],
        }
        print("New questions")
        print(new_questions)
        return new_state

    return transform_query