File size: 9,988 Bytes
711bc31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f79bf0a
711bc31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
from typing import Any, Literal, Optional, cast
import ast
from langchain_core.prompts import ChatPromptTemplate
from geopy.geocoders import Nominatim
from climateqa.engine.llm import get_llm
import duckdb
import os 
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
from climateqa.engine.talk_to_data.objects.location import Location
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State

async def detect_location_with_openai(sentence: str) -> str:
    """Detect the first location mentioned in a sentence using the LLM.

    The LLM is asked to answer with a Python list of location names; the
    first entry is returned.

    Args:
        sentence (str): The user sentence to scan for locations.

    Returns:
        str: The first detected location, or "" when no location is found
        or the LLM reply cannot be parsed as a Python literal.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.
    
    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    # LLM output is untrusted text: guard against replies that are not a
    # valid Python literal instead of crashing the whole workflow.
    try:
        location_list = ast.literal_eval(response.content.strip("```python\n").strip())
    except (ValueError, SyntaxError):
        return ""
    if location_list:
        return location_list[0]
    return ""

def loc_to_coords(location: str) -> tuple[float, float]:
    """Geocode a place name into geographic coordinates.

    Queries the Nominatim geocoding service (5 s timeout) for the given
    place name and returns its position.

    Args:
        location (str): The name of the location to geocode

    Returns:
        tuple[float, float]: (latitude, longitude) of the best match

    Raises:
        AttributeError: If Nominatim finds no match (geocode returns None)
    """
    match = Nominatim(user_agent="city_to_latlong", timeout=5).geocode(location)
    return match.latitude, match.longitude

def coords_to_country(coords: tuple[float, float]) -> tuple[str,str]:
    """Converts geographic coordinates to a country code and country name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to the containing country.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str,str]: A tuple containing (country_code, country_name), where
        country_code is the upper-cased ISO alpha-2 code (e.g. "FR")

    Raises:
        AttributeError: If the coordinates cannot be reverse-geocoded
            (Nominatim returns None)
        KeyError: If the resolved address lacks 'country_code' or 'country'
    """
    geolocator = Nominatim(user_agent="latlong_to_country")
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']

def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Find the dataset grid point closest to the given coordinates.

    Searches a bounding box around the requested point (±0.3° for DRIAS,
    ±0.5° for IPCC) and returns the single nearest row by squared
    Euclidean distance in degrees.

    Args:
        location (tuple): (latitude, longitude) of the requested point
        mode (Literal['DRIAS', 'IPCC']): Which dataset to search

    Returns:
        tuple: (latitude, longitude, admin1). ``admin1`` is only present in
        the IPCC dataset (None for DRIAS). Returns ("", "", "") when no
        grid point lies within the bounding box.
    """
    lat = round(location[0], 3)
    long = round(location[1], 3)

    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        columns = "latitude, longitude"
        delta = 0.3
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        columns = "latitude, longitude, admin1"
        delta = 0.5

    conn = duckdb.connect()
    try:
        # ORDER BY squared distance so we actually return the *nearest*
        # neighbour, not an arbitrary row inside the bounding box.
        # Interpolated values are locally-computed floats, not user text.
        results = conn.sql(
            f"SELECT {columns} FROM {table_path} "
            f"WHERE latitude BETWEEN {lat - delta} AND {lat + delta} "
            f"AND longitude BETWEEN {long - delta} AND {long + delta} "
            f"ORDER BY (latitude - {lat}) * (latitude - {lat}) + (longitude - {long}) * (longitude - {long}) "
            f"LIMIT 1"
        ).fetchdf()
    finally:
        # Close the in-memory connection instead of leaking it per call.
        conn.close()

    if len(results) == 0:
        return "", "", ""

    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1

async def detect_year_with_openai(sentence: str) -> str:
    """Extract the first year mentioned in a sentence using the LLM.

    The LLM is asked (via structured output) for a Python-list string of
    years; the first one is returned.

    Args:
        sentence (str): The sentence to scan for years.

    Returns:
        str: The first detected year, or "" when none is mentioned.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.
    
    Sentence: "{sentence}"
    """

    chain = ChatPromptTemplate.from_template(prompt) | llm.with_structured_output(ArrayOutput)
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # The structured output carries the list as a string literal.
    years_found = ast.literal_eval(response['array'])
    return years_found[0] if len(years_found) > 0 else ""
    

async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Pick the tables most relevant to a plot for a given user question.

    Asks the LLM to choose, from the available table names, the three best
    suited to the plot's description and the user's question.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis
        table_names_list (list[str]): Candidate table names to choose from

    Returns:
        list[str]: The table names the LLM deemed relevant for the plot

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    # Strip a possible markdown code fence before parsing the list literal.
    raw_answer = (await llm.ainvoke(prompt)).content
    return ast.literal_eval(raw_answer.strip("```python\n").strip())

async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Select every plot that could help answer the user's question.

    Builds a textual catalogue of the candidate plots and asks the LLM to
    return the relevant plot names, most relevant first.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use for analysis
        plot_list (list[Plot]): Candidate plots with 'name' and 'description'

    Returns:
        list[str]: Plot names sorted from most to least relevant
    """
    plots_description = "".join(
        "Name: " + candidate["name"] + " - Description: " + candidate["description"] + "\n"
        for candidate in plot_list
    )

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    # Strip a possible markdown code fence before parsing the list literal.
    raw_answer = (await llm.ainvoke(prompt)).content
    return ast.literal_eval(raw_answer.strip("```python\n").strip())

async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Resolve the location mentioned in the user input into a Location dict.

    Detects a place name with the LLM, geocodes it, reverse-geocodes the
    country, and snaps the coordinates to the nearest dataset grid point.

    Args:
        user_input (str): The user's query text
        mode (Literal['DRIAS', 'IPCC']): Which dataset grid to snap to

    Returns:
        Location: All fields populated when a location was detected;
        otherwise 'location' is "" and the remaining fields are None.
    """
    print(f"---- Find location in user input ----")
    place_name = await detect_location_with_openai(user_input)
    result: Location = {
        'location' : place_name,
        'longitude' : None,
        'latitude' : None,
        'country_code' : None,
        'country_name' : None,
        'admin1' : None
        }

    if not place_name:
        return cast(Location, result)

    coords = loc_to_coords(place_name)
    country_code, country_name = coords_to_country(coords)
    grid_lat, grid_lon, admin1 = nearest_neighbour_sql(coords, mode)
    result.update({
        "latitude": grid_lat,
        "longitude": grid_lon,
        "country_code": country_code,
        "country_name": country_name,
        "admin1": admin1
    })
    return cast(Location, result)

async def find_year(user_input: str) -> str| None:
    """Extract year information from user input using the LLM.

    Delegates detection to detect_year_with_openai and normalizes its
    empty-string "not found" sentinel to None.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The extracted year, or None when no year is found
    """
    print(f"---- Find year ---")
    detected_year = await detect_year_with_openai(user_input)
    # detect_year_with_openai signals "not found" with "" — map it to None.
    return detected_year if detected_year != "" else None

async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Return the names of plots relevant to the workflow's user question."""
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)

async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Return the table names relevant to one plot for the user's question."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)

async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
    """Dispatch to the extraction routine matching the requested parameter.

    Args:
        state (State): State of the workflow (provides 'user_input')
        param_name (str): Name of the desired parameter ('location' or 'year')
        mode (Literal['DRIAS', 'IPCC']): Dataset mode forwarded to find_location

    Returns:
        dict[str, Optional[str]] | Location | None: A Location for
        'location', a {'year': ...} dict for 'year', None for anything else.
    """
    if param_name == 'location':
        return await find_location(state['user_input'], mode)
    if param_name == 'year':
        return {'year': await find_year(state['user_input'])}
    return None