import ast
import asyncio
import os
import re
from typing import Any, Literal, Optional, cast

import duckdb
from geopy.geocoders import Nominatim
from langchain_core.prompts import ChatPromptTemplate

from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
from climateqa.engine.talk_to_data.objects.location import Location
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State


# Opening Markdown code fence with an optional language tag, e.g. "```" or "```python".
_FENCE_OPEN_RE = re.compile(r"^```[\w+-]*")


def _parse_llm_list(text: str) -> list:
    """Parse a Python list literal out of an LLM answer.

    LLMs frequently wrap their answer in a Markdown code fence
    (```python ... ```). The fence is removed as a whole prefix/suffix; the
    previous ``str.strip("```python\\n")`` treated its argument as a set of
    characters and could eat letters belonging to the payload itself.

    Args:
        text (str): Raw text content returned by the LLM.

    Returns:
        list: The parsed list (a tuple literal is converted to a list).

    Raises:
        ValueError: If the text is not a valid literal, or is not a list/tuple.
        SyntaxError: If the text cannot be parsed as a Python expression.
    """
    cleaned = text.strip()
    cleaned = _FENCE_OPEN_RE.sub("", cleaned, count=1)
    cleaned = cleaned.removesuffix("```").strip()
    # NOTE: literal_eval only evaluates literals, so untrusted LLM output cannot execute code here.
    parsed = ast.literal_eval(cleaned)
    if not isinstance(parsed, (list, tuple)):
        raise ValueError(f"Expected a Python list from the LLM, got: {cleaned!r}")
    return list(parsed)


async def detect_location_with_openai(sentence: str) -> str:
    """Detect the first location mentioned in a sentence using the LLM via LangChain.

    Args:
        sentence (str): Free-text user input.

    Returns:
        str: The first location returned by the LLM, or "" if none was found.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    location_list = _parse_llm_list(response.content)
    if location_list:
        return location_list[0]
    return ""


def loc_to_coords(location: str) -> tuple[float, float]:
    """Convert a location name to geographic coordinates.

    Uses the Nominatim geocoding service to convert a location name
    (e.g. a city name) to its latitude and longitude.

    Args:
        location (str): The name of the location to geocode.

    Returns:
        tuple[float, float]: A tuple containing (latitude, longitude).

    Raises:
        AttributeError: If the location cannot be found.
    """
    geolocator = Nominatim(user_agent="city_to_latlong", timeout=5)
    coords = geolocator.geocode(location)
    if coords is None:
        # geocode() returns None when nothing matches; keep raising AttributeError
        # (the documented contract) but with an explicit message.
        raise AttributeError(f"Could not geocode location: {location!r}")
    return (coords.latitude, coords.longitude)


def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
    """Convert geographic coordinates to a country.

    Uses the Nominatim reverse geocoding service to convert latitude and
    longitude coordinates to a country.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude).

    Returns:
        tuple[str, str]: A tuple containing (country_code, country_name), with
        the country code upper-cased.

    Raises:
        AttributeError: If the coordinates cannot be reverse-geocoded.
    """
    # Same timeout as loc_to_coords so both geocoding calls behave consistently.
    geolocator = Nominatim(user_agent="latlong_to_country", timeout=5)
    location = geolocator.reverse(coords)
    if location is None:
        raise AttributeError(f"Could not reverse-geocode coordinates: {coords!r}")
    # NOTE(review): points outside any country (e.g. open sea) may lack these keys
    # and raise KeyError — confirm whether callers need to handle that.
    address = location.raw['address']
    return address['country_code'].upper(), address['country']


def nearest_neighbour_sql(
    location: tuple, mode: Literal['DRIAS', 'IPCC']
) -> tuple[float | str, float | str, Optional[str]]:
    """Find the data grid point closest to ``location``.

    Searches a square window around the point (±0.3° for DRIAS, ±0.5° for
    IPCC) and returns the row with the smallest squared lat/long distance.

    Args:
        location (tuple): (latitude, longitude) of the point of interest.
        mode (Literal['DRIAS', 'IPCC']): Which dataset to search.

    Returns:
        tuple: (latitude, longitude, admin1) of the nearest grid point. admin1
        is only available in IPCC mode and is None otherwise. When no grid
        point lies in the window, ("", "", "") is returned.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)

    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        columns = "latitude, longitude"
        radius = 0.3
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        columns = "latitude, longitude, admin1"
        radius = 0.5

    # The table path comes from config; numeric values are bound as parameters.
    # ORDER BY distance makes this a real nearest-neighbour lookup instead of
    # returning an arbitrary row from the search window.
    query = (
        f"SELECT {columns} FROM {table_path} "
        "WHERE latitude BETWEEN ? AND ? AND longitude BETWEEN ? AND ? "
        "ORDER BY (latitude - ?) * (latitude - ?) + (longitude - ?) * (longitude - ?) "
        "LIMIT 1"
    )
    params = [lat - radius, lat + radius, long - radius, long + radius, lat, lat, long, long]

    conn = duckdb.connect()
    try:
        results = conn.execute(query, params).fetchdf()
    finally:
        conn.close()

    if len(results) == 0:
        return "", "", ""

    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1


async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence using the LLM via LangChain.

    Args:
        sentence (str): Free-text user input.

    Returns:
        str: The first year found, as a string, or "" if none was found.
    """
    llm = get_llm()

    prompt_template = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt_template)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # NOTE(review): response['array'] is presumably a string holding a Python list literal.
    years_list = _parse_llm_list(response['array'])
    if years_list:
        # The LLM may emit [2050] or ['2050']; normalise to str per the return type.
        return str(years_list[0])
    return ""


async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identify relevant tables for a plot based on user input.

    Uses an LLM to analyze the user's question and the plot description to
    determine which tables would be most relevant for generating the
    requested visualization.

    Args:
        user_question (str): The user's question about climate data.
        plot (Plot): The plot configuration object (its 'description' is used).
        llm: The language model instance to use for analysis.
        table_names_list (list[str]): Candidate table names to choose from.

    Returns:
        list[str]: Table names the LLM considers relevant for the plot.

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm,
        ...     ["mean_annual_temperature", "mean_summer_temperature", "total_annual_precipitation"],
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    # Each instruction goes on its own line so sentences and sections don't run together.
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}.\n"
        "You are given a list of tables and a user question.\n"
        "Based on the description of the plot, which table are appropriate for that kind of plot.\n"
        "Write the 3 most relevant tables to use. Answer only a python list of table name.\n"
        f"### List of tables : {table_names_list}\n"
        f"### User question : {user_question}\n"
        "### List of table name : "
    )

    response = await llm.ainvoke(prompt)
    return _parse_llm_list(response.content)


async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Ask the LLM which plots could help answer the user's question.

    Args:
        user_question (str): The user's question.
        llm: The language model instance to use.
        plot_list (list[Plot]): Candidate plots (their 'name' and 'description' are used).

    Returns:
        list[str]: Plot names, ordered by the LLM from most to least relevant.
    """
    plots_description = "".join(
        f"Name: {plot['name']} - Description: {plot['description']}\n" for plot in plot_list
    )

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    response = await llm.ainvoke(prompt)
    return _parse_llm_list(response.content)


async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Extract a location from user input and resolve it to dataset coordinates.

    Args:
        user_input (str): The user's query text.
        mode (Literal['DRIAS', 'IPCC']): Dataset used for the nearest grid point lookup.

    Returns:
        Location: The detected location name plus the nearest grid point
        coordinates, country code/name and admin1. All fields except
        'location' stay None when no location is detected.

    Raises:
        AttributeError: If the detected location cannot be geocoded.
    """
    print("---- Find location in user input ----")
    location = await detect_location_with_openai(user_input)
    output: Location = {
        'location': location,
        'longitude': None,
        'latitude': None,
        'country_code': None,
        'country_name': None,
        'admin1': None,
    }

    if location:
        # Geocoding and DuckDB calls are blocking; run them off the event loop.
        coords = await asyncio.to_thread(loc_to_coords, location)
        # Both lookups only depend on coords, so they can run concurrently.
        (country_code, country_name), neighbour = await asyncio.gather(
            asyncio.to_thread(coords_to_country, coords),
            asyncio.to_thread(nearest_neighbour_sql, coords, mode),
        )
        output.update({
            "latitude": neighbour[0],
            "longitude": neighbour[1],
            "country_code": country_code,
            "country_name": country_name,
            "admin1": neighbour[2],
        })
    return output


async def find_year(user_input: str) -> str | None:
    """Extract year information from user input using the LLM.

    The extracted year is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text.

    Returns:
        str | None: The extracted year, or None if no year was found.
    """
    print(f"---- Find year ---")
    year = await detect_year_with_openai(user_input)
    if year == "":
        return None
    return year


async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Return the plot names relevant to ``state['user_input']``, most relevant first."""
    print("---- Find relevant plots ----")
    relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
    return relevant_plots


async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Return the tables (chosen among ``tables``) relevant to ``plot`` for ``state['user_input']``."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
    return relevant_tables


async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
    """Dispatch to the extractor for the requested parameter.

    Args:
        state (State): State of the workflow (its 'user_input' is used).
        param_name (str): Name of the desired parameter ('location' or 'year').
        mode (Literal['DRIAS', 'IPCC']): Dataset used when resolving a location.

    Returns:
        Location for 'location', {'year': str | None} for 'year',
        None for any other parameter name.
    """
    if param_name == 'location':
        location = await find_location(state['user_input'], mode)
        return location
    if param_name == 'year':
        year = await find_year(state['user_input'])
        return {'year': year}
    return None