|
from typing import Any, Literal, Optional, cast |
|
import ast |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from geopy.geocoders import Nominatim |
|
from climateqa.engine.llm import get_llm |
|
import duckdb |
|
import os |
|
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH |
|
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput |
|
from climateqa.engine.talk_to_data.objects.location import Location |
|
from climateqa.engine.talk_to_data.objects.plot import Plot |
|
from climateqa.engine.talk_to_data.objects.states import State |
|
|
|
async def detect_location_with_openai(sentence: str) -> str:
    """Detect the first location mentioned in a sentence using the LLM.

    Asks the model to list all locations (cities, countries, states or
    geographical areas) as a Python list literal, parses it, and keeps only
    the first entry.

    Args:
        sentence (str): The sentence to analyse.

    Returns:
        str: The first detected location, or "" when none is found.

    Raises:
        ValueError, SyntaxError: If the LLM answer is not a valid Python
            list literal (propagated from ast.literal_eval).
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    # Remove an optional Markdown code fence. The original used
    # .strip("```python\n"), which strips any of those *characters* from both
    # ends (str.strip takes a character set, not a prefix), so it could eat
    # legitimate leading/trailing characters of the payload.
    raw = response.content.strip()
    raw = raw.removeprefix("```python").removesuffix("```").strip()
    location_list = ast.literal_eval(raw)
    if location_list:
        return location_list[0]
    return ""
|
|
|
def loc_to_coords(location: str) -> tuple[float, float]:
    """Resolve a place name to geographic coordinates.

    Queries the Nominatim geocoding service (5 s timeout) and returns the
    coordinates of the best match.

    Args:
        location (str): Name of the place to geocode (e.g. a city name)

    Returns:
        tuple[float, float]: The (latitude, longitude) of the match

    Raises:
        AttributeError: If Nominatim returns no result for the name
    """
    geocoder = Nominatim(user_agent="city_to_latlong", timeout=5)
    match = geocoder.geocode(location)
    return match.latitude, match.longitude
|
|
|
def coords_to_country(coords: tuple[float, float]) -> tuple[str,str]:
    """Converts geographic coordinates to a country code and country name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to the country they fall in.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str,str]: A tuple containing (country_code, country_name), with
            country_code upper-cased (e.g. ("FR", "France"))

    Raises:
        AttributeError: If the coordinates cannot be reverse-geocoded
        KeyError: If the resolved address has no country information
    """
    # NOTE(review): unlike loc_to_coords, no timeout is set here — consider timeout=5.
    geolocator = Nominatim(user_agent="latlong_to_country")
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
|
|
|
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Find the dataset grid point nearest to the given coordinates.

    Selects rows of the DRIAS or IPCC coordinates table whose latitude and
    longitude fall inside a small bounding box around the requested point
    (±0.3° for DRIAS, ±0.5° for IPCC) and returns the first match.

    Args:
        location (tuple): (latitude, longitude) of the requested point
        mode (Literal['DRIAS', 'IPCC']): Which dataset grid to search

    Returns:
        tuple: (latitude, longitude, admin1) of the matched grid point;
            ("", "", "") when nothing falls inside the bounding box, and
            admin1 is None when the table has no admin1 column (DRIAS).
    """
    lat = round(location[0], 3)
    long = round(location[1], 3)

    # The two datasets differ only by table path, selected columns and the
    # bounding-box half-width.
    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        columns, delta = "latitude, longitude", 0.3
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        columns, delta = "latitude, longitude, admin1", 0.5

    conn = duckdb.connect()
    try:
        # lat/long are locally-built floats, so interpolating them into the
        # SQL text is injection-safe here.
        results = conn.sql(
            f"SELECT {columns} FROM {table_path} "
            f"WHERE latitude BETWEEN {lat - delta} AND {lat + delta} "
            f"AND longitude BETWEEN {long - delta} AND {long + delta}"
        ).fetchdf()
    finally:
        # Fix: the original never closed the connection (resource leak).
        conn.close()

    if len(results) == 0:
        return "", "", ""

    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
|
|
|
async def detect_year_with_openai(sentence: str) -> str:
    """Extract the first year mentioned in a sentence using the LLM.

    Builds a structured-output chain (ArrayOutput) asking the model for a
    Python list of years, parses the answer, and keeps only the first year.

    Args:
        sentence (str): The sentence to analyse.

    Returns:
        str: The first detected year, or "" when none is found.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    chain = ChatPromptTemplate.from_template(prompt) | llm.with_structured_output(ArrayOutput)
    answer: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    years = ast.literal_eval(answer['array'])
    return years[0] if years else ""
|
|
|
|
|
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis
        table_names_list (list[str]): Candidate table names to choose from

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Raises:
        ValueError, SyntaxError: If the LLM answer is not a valid Python
            list literal (propagated from ast.literal_eval).

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    response = await llm.ainvoke(prompt)
    # Remove an optional Markdown code fence. The original used
    # .strip("```python\n"), which treats the argument as a *character set*
    # (stripping any of those characters from both ends), not as a prefix —
    # removeprefix/removesuffix express the intent correctly.
    raw = response.content.strip()
    raw = raw.removeprefix("```python").removesuffix("```").strip()
    return ast.literal_eval(raw)
|
|
|
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Select all plots relevant to the user's question via the LLM.

    Concatenates the name and description of every candidate plot, asks the
    model to rank the relevant ones, and parses the answer as a Python list
    of plot names.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use for analysis
        plot_list (list[Plot]): Candidate plots (each with name/description)

    Returns:
        list[str]: Plot names sorted from most to least relevant

    Raises:
        ValueError, SyntaxError: If the LLM answer is not a valid Python
            list literal (propagated from ast.literal_eval).
    """
    plots_description = ""
    for plot in plot_list:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    response = await llm.ainvoke(prompt)
    # Remove an optional Markdown code fence. The original used
    # .strip("```python\n"), which strips a *character set* from both ends
    # rather than the literal prefix/suffix — removeprefix/removesuffix is
    # the correct idiom.
    raw = response.content.strip()
    raw = raw.removeprefix("```python").removesuffix("```").strip()
    return ast.literal_eval(raw)
|
|
|
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Build a Location object from the location mentioned in the user input.

    Detects a location name with the LLM, geocodes it, resolves its country,
    and snaps it to the nearest grid point of the selected dataset.

    Args:
        user_input (str): The user's question
        mode (Literal['DRIAS', 'IPCC']): Dataset grid used for the
            nearest-neighbour lookup. Defaults to 'DRIAS'.

    Returns:
        Location: The resolved location; the geographic fields stay None
            when no location is detected in the input.
    """
    print(f"---- Find location in user input ----")
    detected = await detect_location_with_openai(user_input)
    result: Location = {
        'location': detected,
        'longitude': None,
        'latitude': None,
        'country_code': None,
        'country_name': None,
        'admin1': None,
    }

    if detected:
        coords = loc_to_coords(detected)
        country_code, country_name = coords_to_country(coords)
        grid_lat, grid_long, admin1 = nearest_neighbour_sql(coords, mode)
        result.update({
            "latitude": grid_lat,
            "longitude": grid_long,
            "country_code": country_code,
            "country_name": country_name,
            "admin1": admin1,
        })
        result = cast(Location, result)
    return result
|
|
|
async def find_year(user_input: str) -> str| None:
    """Extract the year mentioned in the user input, if any.

    Delegates detection to the LLM helper and normalises its "no year found"
    sentinel ("") to None.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The extracted year, or None when no year was found
    """
    print(f"---- Find year ---")
    detected = await detect_year_with_openai(user_input)
    # detect_year_with_openai returns "" when nothing is found; map it to None.
    return detected or None
|
|
|
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Select the plots relevant to the user's question.

    Thin wrapper around detect_relevant_plots that reads the question from
    the workflow state.
    """
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)
|
|
|
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Select the tables relevant to one plot for the user's question.

    Thin wrapper around detect_relevant_tables that reads the question from
    the workflow state.
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)
|
|
|
async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
    """Dispatch to the extraction routine matching the requested parameter.

    Args:
        state (State): State of the workflow
        param_name (str): Name of the desired parameter ('location' or 'year')
        mode (Literal['DRIAS', 'IPCC']): Dataset grid used for location
            lookups. Defaults to 'DRIAS'.

    Returns:
        Location when param_name is 'location', a {'year': ...} dict when it
        is 'year', and None for any other parameter name.
    """
    if param_name == 'location':
        return await find_location(state['user_input'], mode)
    if param_name == 'year':
        return {'year': await find_year(state['user_input'])}
    return None
|
|