File size: 5,717 Bytes
711bc31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
from typing import Any
import asyncio
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
# Directory two levels above the current working directory.
# NOTE(review): this is relative to where the process was launched from,
# not to this file's location.
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Processes a table for a given plot and parameters: builds the SQL query, executes it,
    and generates the corresponding figure.

    ``params`` is never modified. The indicator column found for ``table`` is
    added to a private copy. The parameters are read again after the ``await``
    on the SQL execution, so mutating a shared dict would let concurrent calls
    overwrite each other's indicator column.

    Args:
        output_title (str): Title for the output (used as key in outputs dict).
        table (str): The name of the table to process.
        plot (Plot): The plot object containing SQL query and visualization function.
        params (dict[str, Any]): Parameters used for querying the table.

    Returns:
        tuple: (output_title, results dict, errors dict).
            results['status'] is one of:
            - 'OK';
            - 'ERROR' when no SQL query could be built (in that case
              sql_query, dataframe, figure and plot_information stay None);
            - 'NO_DATA' when the query returned no rows.
            errors['have_sql_query'] and errors['have_dataframe'] flag which
            stages succeeded.
    """
    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None
    }
    errors: dict[str, bool] = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Private copy so the caller's dict is left untouched (see docstring).
    query_params = dict(params)

    # Find the indicator column for this table
    indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
    if indicator_column:
        query_params['indicator_column'] = indicator_column

    # Build the SQL query; a falsy query is reported as an error for this output.
    sql_query = plot['sql_query'](table, query_params)
    if not sql_query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['plot_information'] = plot['plot_information'](table, query_params)
    results['sql_query'] = sql_query
    errors['have_sql_query'] = True

    # Execute the SQL query
    df = await execute_sql_query(sql_query)
    if df is not None and len(df) > 0:
        results['dataframe'] = df
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # Generate the figure (always, even if df is empty, for consistency)
    results['figure'] = plot['plot_function'](query_params)
    return output_title, results, errors
async def drias_workflow(user_input: str) -> State:
    """
    Run the DRIAS workflow for one user question, from question to figures.

    Stages, in order:
      1. ``find_relevant_plots`` picks plot names. Only the first two are kept,
         and names not present in DRIAS_PLOTS are skipped.
      2. ``find_relevant_tables_per_plot`` picks the DRIAS tables for each plot.
         Every (plot, table) pair becomes one entry of ``state['outputs']``.
      3. ``find_param`` is called once per name in DRIAS_PLOT_PARAMETERS, and
         the non-empty results are merged into one parameter dict.
      4. ``process_output`` runs concurrently for every pair, each call
         receiving its own copy of the parameters.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state. ``state['error']`` names the first stage that
        produced nothing (no plot, no table, no SQL query, no data). It stays
        an empty string otherwise.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    def lookup_plot(plot_name):
        # First DRIAS plot whose 'name' matches, or None when there is none.
        return next(
            (candidate for candidate in DRIAS_PLOTS if candidate['name'] == plot_name),
            None,
        )

    # Stage 1: choose plot types.
    selected = await find_relevant_plots(state, llm, DRIAS_PLOTS)
    if not selected:
        state['error'] = 'There is no plot to answer to the question'
        return state
    selected = selected[:2]  # keep at most two plot types
    state['plots'] = selected

    # Stage 2: one output entry per (plot, table) pair.
    any_table = False
    outputs = {}
    for plot_name in selected:
        plot = lookup_plot(plot_name)
        if plot is None:
            continue
        tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
        if tables:
            any_table = True
        for table in tables:
            title = f"{plot['short_name']} - {table.capitalize().replace('_', ' ')}"
            outputs[title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Stage 3: parameters extracted from the question, shared by every output.
    shared_params = {}
    for param_name in DRIAS_PLOT_PARAMETERS:
        found = await find_param(state, param_name, mode='DRIAS')
        if found:
            shared_params.update(found)

    # Stage 4: process all pairs concurrently; gather returns results in input order.
    processed = await asyncio.gather(*(
        process_output(title, entry['table'], entry['plot'], shared_params.copy())
        for title, entry in outputs.items()
    ))

    got_sql_query = False
    got_dataframe = False
    copied_keys = ('sql_query', 'dataframe', 'figure', 'plot_information', 'status')
    for title, task_results, task_errors in processed:
        entry = outputs[title]
        for key in copied_keys:
            entry[key] = task_results[key]
        got_sql_query |= task_errors['have_sql_query']
        got_dataframe |= task_errors['have_dataframe']
    state['outputs'] = outputs

    # Report the first stage that came back empty, if any.
    diagnostics = (
        (any_table, "There is no relevant table in our database to answer your question"),
        (got_sql_query, "There is no relevant sql query on our database that can help to answer your question"),
        (got_dataframe, "There is no data in our table that can answer to your question"),
    )
    for succeeded, message in diagnostics:
        if not succeeded:
            state['error'] = message
            break
    return state