|
import os |
|
|
|
from typing import Any |
|
import asyncio |
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot |
|
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column |
|
from climateqa.engine.talk_to_data.objects.plot import Plot |
|
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput |
|
from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS |
|
from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS |
|
|
|
# Directory two levels above the process's current working directory.
# NOTE: this is derived from os.getcwd(), so the value depends on where the
# process is launched from, not on where this file lives.
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
|
|
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Process a table for a given plot and parameters: builds the SQL query, executes it,
    and generates the corresponding figure.

    The caller's ``params`` dict is never modified; the indicator column (when one
    is found for ``table``) is added to a private copy.

    Args:
        output_title (str): Title for the output (used as key in outputs dict).
            It is returned unchanged so results can be matched back to their entry.
        table (str): The name of the table to process.
        plot (Plot): The plot object. Its 'sql_query', 'plot_information' and
            'plot_function' entries are called by this function.
        params (dict[str, Any]): Parameters used for querying the table.

    Returns:
        tuple: (output_title, results dict, errors dict)
            - results['status'] is 'OK', 'ERROR' (no SQL query could be built) or
              'NO_DATA' (the query returned nothing or an empty dataframe).
            - the errors dict holds success flags: 'have_sql_query' is True once a
              query was built, 'have_dataframe' is True once a non-empty dataframe
              was retrieved.
    """
    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None,
    }
    # Despite the name, these are success flags: they start False and are
    # switched to True as each step succeeds.
    errors = {
        'have_sql_query': False,
        'have_dataframe': False,
    }

    # Work on a private copy so that adding 'indicator_column' below does not
    # leak into the dict owned by the caller.
    params = dict(params)

    indicator_column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
    if indicator_column:
        params['indicator_column'] = indicator_column

    # Build the SQL query; without one nothing else can be produced.
    sql_query = plot['sql_query'](table, params)
    if not sql_query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['plot_information'] = plot['plot_information'](table, params)
    results['sql_query'] = sql_query
    errors['have_sql_query'] = True

    # Run the query; a missing or empty dataframe is reported as NO_DATA.
    df = await execute_sql_query(sql_query)
    if df is not None and not df.empty:
        results['dataframe'] = df
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # The figure is built from the parameters only, so it is produced whether
    # or not the query returned data.
    results['figure'] = plot['plot_function'](params)

    return output_title, results, errors
|
|
|
async def ipcc_workflow(user_input: str) -> State:
    """
    Performs the complete workflow of Talk To IPCC: from user input to SQL queries, dataframes, and figures.

    Steps:
        1. Select the plots relevant to the question.
        2. For each selected plot, select the relevant tables and register one
           output entry per (plot, table) pair.
        3. Extract the query parameters listed in IPCC_PLOT_PARAMETERS.
        4. Process all output entries concurrently with ``process_output``.
        5. Merge the results into the state and set an error message when no
           table, no SQL query, or no data could be obtained.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state with all the results and error messages if any.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': '',
    }

    llm = get_llm(provider="openai")
    plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
    state['plots'] = plots

    if not plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # Despite the name, these are success flags: each one becomes True as soon
    # as at least one output reaches the corresponding step.
    errors = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False,
    }
    outputs = {}

    # Register one output entry per (plot, table) pair.
    for plot_name in plots:
        plot = next((p for p in IPCC_PLOTS if p['name'] == plot_name), None)
        if plot is None:
            # Plot name not present in IPCC_PLOTS: nothing to process.
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
        if not relevant_tables:
            # Covers both an empty result and a None result.
            continue
        errors['have_relevant_table'] = True

        for table in relevant_tables:
            # e.g. table "some_table_name" -> "Some table name"
            readable_table = table.capitalize().replace('_', ' ')
            output_title = f"{plot['short_name']} - {readable_table}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK',
            }

    if not errors['have_relevant_table']:
        # Nothing to query: skip parameter extraction, whose results would be unused.
        state['outputs'] = outputs
        state['error'] = "There is no relevant table in our database to answer your question"
        return state

    # Extract the query parameters once; they are shared by every output.
    params = {}
    for param_name in IPCC_PLOT_PARAMETERS:
        param = await find_param(state, param_name, mode='IPCC')
        if param:
            params.update(param)

    # Process every output concurrently; each task receives its own copy of params.
    tasks = [
        process_output(output_title, output['table'], output['plot'], params.copy())
        for output_title, output in outputs.items()
    ]
    results = await asyncio.gather(*tasks)

    # Merge each task's results into its output entry and aggregate the flags.
    merged_keys = ('sql_query', 'dataframe', 'figure', 'plot_information', 'status')
    for output_title, task_results, task_errors in results:
        output = outputs[output_title]
        for key in merged_keys:
            output[key] = task_results[key]
        errors['have_sql_query'] |= task_errors['have_sql_query']
        errors['have_dataframe'] |= task_errors['have_dataframe']

    state['outputs'] = outputs

    # Report the earliest step that no output managed to complete.
    if not errors['have_sql_query']:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not errors['have_dataframe']:
        state['error'] = "There is no data in our table that can answer to your question"

    return state