# Provenance (repository page header captured with this file):
# author "timeki", commit 711bc31 "talk_to_ipcc (#29)", marked as verified.
import os
from typing import Any
import asyncio
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
# Directory two levels above the process's current working directory,
# computed once when this module is imported.
# NOTE(review): this depends on where the process is launched from, not on
# this file's location -- confirm that is the intended anchor.
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Run one (plot, table) pair end to end: build the SQL query, execute it,
    and generate the corresponding figure.

    The caller's ``params`` dict is never modified. The function works on a
    private shallow copy, so several calls may safely be scheduled
    concurrently from the same parameter dict.

    Args:
        output_title (str): Title for the output; returned unchanged so the
            caller can match the result to its entry in the outputs dict.
        table (str): The name of the table to process.
        plot (Plot): The plot object; its ``'sql_query'``,
            ``'plot_information'`` and ``'plot_function'`` entries are called.
        params (dict[str, Any]): Parameters used for querying the table.

    Returns:
        tuple: ``(output_title, results, errors)``.
            ``results['status']`` is ``'OK'`` when the query returned rows,
            ``'NO_DATA'`` when it returned nothing, and ``'ERROR'`` when no
            SQL query could be built (the other result fields then stay
            ``None``).
            ``errors`` holds the progress flags ``'have_sql_query'`` and
            ``'have_dataframe'``; each is ``True`` once that step succeeded.
    """
    # Work on a private copy: 'indicator_column' is written below and the
    # dict is read again after the await, so sharing the caller's dict
    # between concurrent calls would let one table's column leak into
    # another table's figure.
    params = dict(params)

    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None
    }
    errors = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Find the indicator column for this table
    indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
    if indicator_column:
        params['indicator_column'] = indicator_column

    # Build the SQL query; without one there is nothing to execute or plot
    sql_query = plot['sql_query'](table, params)
    if not sql_query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['plot_information'] = plot['plot_information'](table, params)
    results['sql_query'] = sql_query
    errors['have_sql_query'] = True

    # Execute the SQL query
    df = await execute_sql_query(sql_query)
    if df is not None and len(df) > 0:
        results['dataframe'] = df
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # Generate the figure (always, even if df is empty, for consistency).
    # The plot function is called with the parameters only, not the dataframe.
    results['figure'] = plot['plot_function'](params)

    return output_title, results, errors
async def drias_workflow(user_input: str) -> State:
    """
    Run the complete DRIAS pipeline for a single user question.

    Steps, in order: select the relevant plots, select the relevant tables
    for each plot, collect the query parameters, then process every
    (plot, table) pair concurrently through ``process_output``.

    Args:
        user_input (str): The user's question.

    Returns:
        State: The final state. ``'plots'`` holds the (at most two) selected
        plot names, ``'outputs'`` maps each output title to its table, plot,
        status, SQL query, dataframe, figure and plot information, and
        ``'error'`` carries a message when a stage produced nothing
        (empty string otherwise).
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }
    llm = get_llm(provider="openai")

    selected_plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
    if not selected_plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # Keep at most two plot types
    selected_plots = selected_plots[:2]
    state['plots'] = selected_plots

    # Progress flags: each one records that at least one output got that far
    found_table = False
    built_query = False
    got_data = False

    # One output entry per (plot, relevant table) pair
    outputs = {}
    for plot_name in selected_plots:
        plot = next(
            (candidate for candidate in DRIAS_PLOTS if candidate['name'] == plot_name),
            None,
        )
        if plot is None:
            continue

        tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
        if not tables:
            continue
        found_table = True

        for table in tables:
            title = f"{plot['short_name']} - {table.capitalize().replace('_', ' ')}"
            outputs[title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Collect the query parameters one at a time, in configuration order
    params = {}
    for param_name in DRIAS_PLOT_PARAMETERS:
        found = await find_param(state, param_name, mode='DRIAS')
        if found:
            params.update(found)

    # Process every output concurrently; each task gets its own params copy
    gathered = await asyncio.gather(*(
        process_output(title, entry['table'], entry['plot'], params.copy())
        for title, entry in outputs.items()
    ))

    # Fold each task's results and progress flags back into the outputs
    copied_fields = ('sql_query', 'dataframe', 'figure', 'plot_information', 'status')
    for title, task_results, task_flags in gathered:
        entry = outputs[title]
        for field in copied_fields:
            entry[field] = task_results[field]
        built_query |= task_flags['have_sql_query']
        got_data |= task_flags['have_dataframe']

    state['outputs'] = outputs

    # Report the earliest stage that produced nothing, if any
    stage_checks = (
        (found_table, "There is no relevant table in our database to answer your question"),
        (built_query, "There is no relevant sql query on our database that can help to answer your question"),
        (got_data, "There is no data in our table that can answer to your question"),
    )
    for succeeded, message in stage_checks:
        if not succeeded:
            state['error'] = message
            break

    return state