# File size: 5,672 Bytes
# 711bc31
# (Line-number gutter "1 ... 161" from the web file viewer removed; it is not part of the module.)
import os
from typing import Any
import asyncio
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS
from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS
# Directory two levels above the current working directory.
# NOTE: evaluated once at import time from os.getcwd(), so the value depends on
# where the process was started, not on where this file lives.
_cwd_parent = os.path.dirname(os.getcwd())
ROOT_PATH = os.path.dirname(_cwd_parent)
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Process a table for a given plot and parameters: builds the SQL query, executes it,
    and generates the corresponding figure.

    Args:
        output_title (str): Title for the output (used as key in outputs dict).
            It is returned unchanged so results can be matched back to their entry.
        table (str): The name of the table to process.
        plot (Plot): The plot object; its 'sql_query', 'plot_information' and
            'plot_function' entries are called here.
        params (dict[str, Any]): Parameters used for querying the table.
            The dict is NOT modified: a shallow copy is used internally.

    Returns:
        tuple: (output_title, results dict, errors dict)
            - results['status'] is 'OK', 'ERROR' (no SQL query could be built)
              or 'NO_DATA' (the query returned None or an empty dataframe).
            - errors['have_sql_query'] / errors['have_dataframe'] are True when
              the corresponding step succeeded.
    """
    # Work on a private copy so that adding 'indicator_column' below never leaks
    # into the caller's dict (which may be shared between concurrent tasks).
    query_params = dict(params)

    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None,
    }
    errors = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Find the indicator column for this table
    indicator_column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
    if indicator_column:
        query_params['indicator_column'] = indicator_column

    # Build the SQL query; without one nothing else can be produced
    sql_query = plot['sql_query'](table, query_params)
    if not sql_query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['plot_information'] = plot['plot_information'](table, query_params)
    results['sql_query'] = sql_query
    errors['have_sql_query'] = True

    # Execute the SQL query
    df = await execute_sql_query(sql_query)
    if df is not None and not df.empty:
        results['dataframe'] = df
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # Generate the figure (always, even if df is empty, for consistency)
    results['figure'] = plot['plot_function'](query_params)

    return output_title, results, errors
async def ipcc_workflow(user_input: str) -> State:
    """
    Performs the complete workflow of Talk To IPCC: from user input to SQL queries, dataframes, and figures.

    Steps:
        1. Ask the LLM which plots in IPCC_PLOTS are relevant to the question.
        2. For each relevant plot, find the relevant tables and register one
           output entry per (plot, table) pair.
        3. Collect the query parameters listed in IPCC_PLOT_PARAMETERS.
        4. Run process_output for every entry concurrently and merge the results.
        5. Set state['error'] to the first failing stage, if any.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state with all the results and error messages if any.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")
    plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
    state['plots'] = plots

    if not plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    errors = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False
    }
    outputs = {}

    # Find relevant tables for each plot and prepare outputs
    for plot_name in plots:
        plot = next((p for p in IPCC_PLOTS if p['name'] == plot_name), None)
        if plot is None:
            # Name returned by the plot selection does not match any known plot
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
        if not relevant_tables:
            # Nothing to iterate (empty or None): skip instead of looping over it
            continue
        errors['have_relevant_table'] = True

        for table in relevant_tables:
            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Gather all required parameters
    params = {}
    for param_name in IPCC_PLOT_PARAMETERS:
        param = await find_param(state, param_name, mode='IPCC')
        if param:
            params.update(param)

    # Process all outputs in parallel using process_output; each task gets its
    # own copy of the parameters
    tasks = [
        process_output(output_title, output['table'], output['plot'], params.copy())
        for output_title, output in outputs.items()
    ]
    results = await asyncio.gather(*tasks)

    # Update outputs with results and error flags
    copied_keys = ('sql_query', 'dataframe', 'figure', 'plot_information', 'status')
    for output_title, task_results, task_errors in results:
        output = outputs[output_title]
        for key in copied_keys:
            output[key] = task_results[key]

        # A stage counts as successful if at least one task got through it
        errors['have_sql_query'] |= task_errors['have_sql_query']
        errors['have_dataframe'] |= task_errors['have_dataframe']

    state['outputs'] = outputs

    # Set error messages if needed (earliest failing stage wins)
    if not errors['have_relevant_table']:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not errors['have_sql_query']:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not errors['have_dataframe']:
        state['error'] = "There is no data in our table that can answer to your question"

    return state