# commit f79bf0a: increase Nominatim timeout
from typing import Any, Literal, Optional, cast
import ast
from langchain_core.prompts import ChatPromptTemplate
from geopy.geocoders import Nominatim
from climateqa.engine.llm import get_llm
import duckdb
import os
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
from climateqa.engine.talk_to_data.objects.location import Location
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State
async def detect_location_with_openai(sentence: str) -> str:
    """Detect locations in a sentence using the configured LLM.

    Asks the model to return a Python list of location names and parses it
    with ``ast.literal_eval``.

    Args:
        sentence (str): The user sentence to scan for locations.

    Returns:
        str: The first detected location, or "" when none is mentioned.
    """
    llm = get_llm()
    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.
    Sentence: "{sentence}"
    """
    response = await llm.ainvoke(prompt)
    # BUG FIX: str.strip("```python\n") treats its argument as a *character
    # set* and strips any of those characters from both ends, which can eat
    # legitimate leading/trailing characters of the payload. Remove the
    # Markdown code-fence markers explicitly instead.
    raw = response.content.strip()
    raw = raw.removeprefix("```python").removeprefix("```")
    raw = raw.removesuffix("```").strip()
    location_list = ast.literal_eval(raw)
    return location_list[0] if location_list else ""
def loc_to_coords(location: str) -> tuple[float, float]:
    """Resolve a location name to geographic coordinates.

    Queries the Nominatim geocoding service to turn a place name
    (e.g. a city) into a (latitude, longitude) pair.

    Args:
        location (str): The name of the location to geocode

    Returns:
        tuple[float, float]: A tuple containing (latitude, longitude)

    Raises:
        AttributeError: If the location cannot be found
    """
    geocoder = Nominatim(user_agent="city_to_latlong", timeout=5)
    match = geocoder.geocode(location)
    latitude, longitude = match.latitude, match.longitude
    return (latitude, longitude)
def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
    """Converts geographic coordinates to a country.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to a country code and country name.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str, str]: A tuple containing (country_code, country_name)

    Raises:
        AttributeError: If the coordinates cannot be found
    """
    # timeout=5 matches loc_to_coords and avoids spurious geopy timeout
    # errors on slow Nominatim responses (default is only 1 second).
    geolocator = Nominatim(user_agent="latlong_to_country", timeout=5)
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Find a stored grid point near the given coordinates.

    Searches the dataset's parquet table for rows whose latitude/longitude
    fall inside a small box around ``location`` and returns the first match.

    Args:
        location (tuple): (latitude, longitude) of the target point
        mode: 'DRIAS' (±0.3° search box) or 'IPCC' (±0.5° box, includes admin1)

    Returns:
        tuple: (latitude, longitude, admin1) of the first matching row —
        note the coordinates come back as numeric values despite the declared
        str types — or ("", "", "") when no point lies inside the search box.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)
    # Interpolated values are floats computed locally, so no injection risk;
    # close the connection in all paths (it was previously leaked per call).
    conn = duckdb.connect()
    try:
        if mode == 'DRIAS':
            table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
            results = conn.sql(
                f"SELECT latitude, longitude FROM {table_path} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
            ).fetchdf()
        else:
            table_path = f"'{IPCC_COORDINATES_PATH}'"
            results = conn.sql(
                f"SELECT latitude, longitude, admin1 FROM {table_path} WHERE latitude BETWEEN {lat - 0.5} AND {lat + 0.5} AND longitude BETWEEN {long - 0.5} AND {long + 0.5}"
            ).fetchdf()
    finally:
        conn.close()

    if len(results) == 0:
        return "", "", ""
    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
async def detect_year_with_openai(sentence: str) -> str:
    """Detect years mentioned in a sentence via the configured LLM.

    Uses structured output (ArrayOutput) and returns the first year found,
    or "" when the sentence mentions none.
    """
    llm = get_llm()
    template = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.
    Sentence: "{sentence}"
    """
    chain = ChatPromptTemplate.from_template(template) | llm.with_structured_output(ArrayOutput)
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    years = ast.literal_eval(response['array'])
    return years[0] if years else ""
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis
        table_names_list (list[str]): Candidate table names to choose from

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )
    answer = (await llm.ainvoke(prompt)).content.strip()
    # BUG FIX: str.strip("```python\n") strips any of those *characters* from
    # both ends and can corrupt the list itself; remove the Markdown code
    # fences explicitly instead.
    answer = answer.removeprefix("```python").removeprefix("```")
    answer = answer.removesuffix("```").strip()
    table_names = ast.literal_eval(answer)
    return table_names
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Select the plots relevant to the user's question.

    Builds a textual catalogue of the available plots (name + description)
    and asks the LLM to rank ALL plots that could answer the question.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use for analysis
        plot_list (list[Plot]): Available plots with 'name' and 'description'

    Returns:
        list[str]: Plot names sorted from most to least relevant
    """
    plots_description = ""
    for plot in plot_list:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"
    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )
    answer = (await llm.ainvoke(prompt)).content.strip()
    # BUG FIX: str.strip("```python\n") strips any of those *characters* from
    # both ends and can corrupt the list itself; remove the Markdown code
    # fences explicitly instead.
    answer = answer.removeprefix("```python").removeprefix("```")
    answer = answer.removesuffix("```").strip()
    plot_names = ast.literal_eval(answer)
    return plot_names
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Resolve the location mentioned in the user's question.

    Detects a location name with the LLM, geocodes it, derives its country,
    and snaps the coordinates to the nearest grid point of the requested
    dataset. All fields stay None when no location is detected.
    """
    print(f"---- Find location in user input ----")
    location = await detect_location_with_openai(user_input)
    output: Location = {
        'location': location,
        'longitude': None,
        'latitude': None,
        'country_code': None,
        'country_name': None,
        'admin1': None,
    }
    # Guard clause: nothing more to resolve without a detected location.
    if not location:
        return cast(Location, output)

    coords = loc_to_coords(location)
    country_code, country_name = coords_to_country(coords)
    snapped_lat, snapped_long, admin1 = nearest_neighbour_sql(coords, mode)
    output['latitude'] = snapped_lat
    output['longitude'] = snapped_long
    output['country_code'] = country_code
    output['country_name'] = country_name
    output['admin1'] = admin1
    return cast(Location, output)
async def find_year(user_input: str) -> str | None:
    """Extracts year information from user input using LLM.

    This function uses an LLM to identify and extract year information from
    the user's query, which is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The extracted year, or None if no year was found
    """
    print(f"---- Find year ---")
    detected = await detect_year_with_openai(user_input)
    # detect_year_with_openai returns "" when no year is present.
    return detected or None
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Return the names of the plots relevant to the user's question."""
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Return the table names relevant for building the given plot."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)
async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
    """Dispatch to the extraction method for the desired parameter.

    Args:
        state (State): state of the workflow (reads state['user_input'])
        param_name (str): name of the desired parameter ('location' or 'year')
        mode: dataset used to snap coordinates ('DRIAS' or 'IPCC')

    Returns:
        Location when param_name is 'location', a {'year': ...} dict when
        param_name is 'year', or None for any other parameter name.
    """
    # Docstring previously documented a nonexistent 'table' argument.
    if param_name == 'location':
        return await find_location(state['user_input'], mode)
    if param_name == 'year':
        return {'year': await find_year(state['user_input'])}
    return None