File size: 9,988 Bytes
711bc31 f79bf0a 711bc31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
from typing import Any, Literal, Optional, cast
import ast
from langchain_core.prompts import ChatPromptTemplate
from geopy.geocoders import Nominatim
from climateqa.engine.llm import get_llm
import duckdb
import os
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
from climateqa.engine.talk_to_data.objects.location import Location
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State
async def detect_location_with_openai(sentence: str) -> str:
    """Detect the first location mentioned in a sentence using an LLM.

    Prompts the LLM to list every location (city, country, state or
    geographical area) found in the sentence and returns the first one.

    Args:
        sentence (str): The sentence to analyse

    Returns:
        str: The first detected location, or "" when none is found or the
            LLM answer cannot be parsed as a Python list.
    """
    llm = get_llm()
    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.
    Sentence: "{sentence}"
    """
    response = await llm.ainvoke(prompt)
    # The model may wrap its answer in a ```python fenced block; strip the
    # fence before parsing. literal_eval raises on malformed output, so guard
    # it instead of letting the whole workflow crash.
    try:
        location_list = ast.literal_eval(response.content.strip("```python\n").strip())
    except (ValueError, SyntaxError):
        return ""
    if location_list:
        return location_list[0]
    return ""
def loc_to_coords(location: str) -> tuple[float, float]:
    """Resolve a place name to geographic coordinates via Nominatim.

    Args:
        location (str): Name of the place to geocode (e.g., a city name)

    Returns:
        tuple[float, float]: The (latitude, longitude) of the place

    Raises:
        AttributeError: When Nominatim finds no match for the name
            (geocode returns None)
    """
    geocoder = Nominatim(user_agent="city_to_latlong", timeout=5)
    match = geocoder.geocode(location)
    return match.latitude, match.longitude
def coords_to_country(coords: tuple[float, float]) -> tuple[str,str]:
    """Converts geographic coordinates to a country code and country name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to country information.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str,str]: A tuple containing (country_code, country_name), where
            country_code is upper-cased (e.g. "FR")

    Raises:
        AttributeError: If the coordinates cannot be reverse-geocoded
            (reverse returns None)
        KeyError: If the reverse-geocoding result lacks country fields
    """
    geolocator = Nominatim(user_agent="latlong_to_country")
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Find the dataset grid point closest to a (latitude, longitude) pair.

    Searches the DRIAS or IPCC coordinate table inside a bounding box around
    the requested point (+/- 0.3 deg for DRIAS, +/- 0.5 deg for IPCC) and
    returns the row nearest to it.

    Args:
        location (tuple): (latitude, longitude) of the requested point
        mode (Literal['DRIAS', 'IPCC']): Which coordinate table to query

    Returns:
        tuple: (latitude, longitude, admin1) of the nearest grid point.
            admin1 is only present in the IPCC table (None for DRIAS).
            ("", "", "") is returned when no grid point falls in the box.
    """
    lat = round(location[0], 3)
    long = round(location[1], 3)
    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        columns = "latitude, longitude"
        delta = 0.3
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        columns = "latitude, longitude, admin1"
        delta = 0.5
    # Order by squared euclidean distance so the first row really is the
    # nearest neighbour (previously an arbitrary row of the bounding box
    # was returned).
    query = (
        f"SELECT {columns} FROM {table_path} "
        f"WHERE latitude BETWEEN {lat - delta} AND {lat + delta} "
        f"AND longitude BETWEEN {long - delta} AND {long + delta} "
        f"ORDER BY (latitude - {lat}) * (latitude - {lat}) + (longitude - {long}) * (longitude - {long}) "
        f"LIMIT 1"
    )
    conn = duckdb.connect()
    try:
        results = conn.sql(query).fetchdf()
    finally:
        # The original leaked the in-memory connection on every call.
        conn.close()
    if len(results) == 0:
        return "", "", ""
    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence using an LLM.

    Uses a structured-output chain so the model answers with an
    ArrayOutput whose 'array' field is a string-encoded Python list.

    Args:
        sentence (str): The sentence to analyse

    Returns:
        str: The first detected year, or "" when none is found or the
            model output cannot be parsed.
    """
    llm = get_llm()
    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.
    Sentence: "{sentence}"
    """
    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # response['array'] is a string representation of a Python list; parsing
    # can fail on malformed model output, so fall back to "no year found".
    try:
        years_list = ast.literal_eval(response['array'])
    except (ValueError, SyntaxError):
        return ""
    if len(years_list) > 0:
        return years_list[0]
    return ""
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis
        table_names_list (list[str]): Names of the candidate tables

    Returns:
        list[str]: A list of table names that are relevant for the plot
            (empty when the LLM answer cannot be parsed as a Python list)

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm,
        ...     table_names_list,
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )
    answer = (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
    # Guard against malformed model output instead of crashing the workflow.
    try:
        return ast.literal_eval(answer)
    except (ValueError, SyntaxError):
        return []
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Select all plots that could help answer the user's question.

    Builds a textual catalogue of the available plots (name + description)
    and asks the LLM to rank the relevant ones.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use for analysis
        plot_list (list[Plot]): The available plot configurations

    Returns:
        list[str]: Plot names sorted from most to least relevant
            (empty when the LLM answer cannot be parsed as a Python list)
    """
    plots_description = "".join(
        "Name: " + plot["name"] + " - Description: " + plot["description"] + "\n"
        for plot in plot_list
    )
    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )
    answer = (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
    # Guard against malformed model output instead of crashing the workflow.
    try:
        return ast.literal_eval(answer)
    except (ValueError, SyntaxError):
        return []
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Build a Location object from the place mentioned in the user input.

    Detects a location name with the LLM, geocodes it, resolves its country,
    and snaps it to the nearest grid point of the selected dataset.

    Args:
        user_input (str): The user's query text
        mode (Literal['DRIAS', 'IPCC']): Which dataset grid to snap to

    Returns:
        Location: The resolved location; coordinate and country fields stay
            None when no location is detected in the input.
    """
    print(f"---- Find location in user input ----")
    place_name = await detect_location_with_openai(user_input)
    output: Location = {
        'location': place_name,
        'longitude': None,
        'latitude': None,
        'country_code': None,
        'country_name': None,
        'admin1': None,
    }
    if not place_name:
        return cast(Location, output)
    coords = loc_to_coords(place_name)
    country_code, country_name = coords_to_country(coords)
    nearest_lat, nearest_long, admin1 = nearest_neighbour_sql(coords, mode)
    output.update({
        "latitude": nearest_lat,
        "longitude": nearest_long,
        "country_code": country_code,
        "country_name": country_name,
        "admin1": admin1,
    })
    return cast(Location, output)
async def find_year(user_input: str) -> str| None:
    """Extracts year information from user input using LLM.

    This function uses an LLM to identify and extract year information from
    the user's query, which is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The extracted year, or None when no year is found
    """
    print(f"---- Find year ---")
    detected = await detect_year_with_openai(user_input)
    # Empty string means "no year detected" — normalise it to None.
    return detected or None
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Return the names of the plots relevant to the user's question."""
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Return the tables relevant to one plot for the user's question."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)
async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
    """Dispatch to the right extraction method for the desired parameter.

    Args:
        state (State): State of the workflow (provides 'user_input')
        param_name (str): Name of the desired parameter ('location' or 'year')
        mode (Literal['DRIAS', 'IPCC']): Which dataset grid locations should
            be matched against (defaults to 'DRIAS')

    Returns:
        Location | dict[str, Optional[str]] | None: A Location for
            'location', a {'year': ...} dict for 'year', and None for any
            unknown parameter name.
    """
    if param_name == 'location':
        return await find_location(state['user_input'], mode)
    if param_name == 'year':
        return {'year': await find_year(state['user_input'])}
    return None
|