import asyncio
import os
from typing import Any

from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.input_processing import (
    find_param,
    find_relevant_plots,
    find_relevant_tables_per_plot,
)
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
from climateqa.engine.talk_to_data.drias.config import (
    DRIAS_TABLES,
    DRIAS_INDICATOR_COLUMNS_PER_TABLE,
    DRIAS_PLOT_PARAMETERS,
)
from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS

# Two directory levels above the current working directory; the value therefore
# depends on where the process is launched from.
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))


async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Processes a table for a given plot and parameters: builds the SQL query,
    executes it, and generates the corresponding figure.

    The ``params`` mapping passed by the caller is never modified: the function
    works on a private shallow copy, so the same dict can safely be shared
    between several concurrently running calls.

    Args:
        output_title (str): Title for the output (used as key in outputs dict).
        table (str): The name of the table to process.
        plot (Plot): The plot object containing SQL query and visualization function.
        params (dict[str, Any]): Parameters used for querying the table.

    Returns:
        tuple: (output_title, results dict, errors dict)
            - ``results['status']`` is ``'OK'`` when data was retrieved,
              ``'ERROR'`` when no SQL query could be built, and ``'NO_DATA'``
              when the query returned ``None`` or an empty result.
            - ``errors`` holds progress flags (``have_sql_query``,
              ``have_dataframe``) that are ``True`` when the step succeeded.
    """
    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None
    }
    errors = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Private copy: 'indicator_column' is table-specific, and this coroutine
    # awaits between writing it and reading it back for the figure. Mutating a
    # shared dict would let one table's column leak into another table's plot.
    params = dict(params)

    # Find the indicator column for this table
    indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
    if indicator_column:
        params['indicator_column'] = indicator_column

    # Build the SQL query
    sql_query = plot['sql_query'](table, params)
    if not sql_query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['plot_information'] = plot['plot_information'](table, params)
    results['sql_query'] = sql_query
    errors['have_sql_query'] = True

    # Execute the SQL query
    df = await execute_sql_query(sql_query)
    if df is not None and len(df) > 0:
        results['dataframe'] = df
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # Generate the figure (always, even if df is empty, for consistency)
    results['figure'] = plot['plot_function'](params)

    return output_title, results, errors


async def drias_workflow(user_input: str) -> State:
    """
    Orchestrates the DRIAS workflow: from user input to SQL queries, dataframes, and figures.

    Steps:
        1. Ask the LLM which plots are relevant (at most two are kept).
        2. For each plot, find the relevant tables and register one output per
           (plot, table) pair.
        3. Extract the query parameters listed in ``DRIAS_PLOT_PARAMETERS``.
        4. Run ``process_output`` for every output concurrently and merge the
           results back into the state.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state with all results and error messages if any.
            ``state['error']`` is an empty string on success.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
    if not plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    plots = plots[:2]  # limit to 2 types of plots
    state['plots'] = plots

    # Progress flags: each one turns True as soon as at least one output
    # reached the corresponding step.
    errors = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False
    }
    outputs = {}

    # Find relevant tables for each plot and prepare outputs
    for plot_name in plots:
        plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None)
        if plot is None:
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
        if relevant_tables:
            errors['have_relevant_table'] = True

        for table in relevant_tables:
            # e.g. "my_table_name" -> "My table name"
            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Without any relevant table there is nothing to query, so skip the
    # parameter extraction and report the error right away.
    if not errors['have_relevant_table']:
        state['outputs'] = outputs
        state['error'] = "There is no relevant table in our database to answer your question"
        return state

    # Gather all required parameters
    params = {}
    for param_name in DRIAS_PLOT_PARAMETERS:
        param = await find_param(state, param_name, mode='DRIAS')
        if param:
            params.update(param)

    # Process all outputs in parallel using process_output
    tasks = [
        process_output(output_title, output['table'], output['plot'], params.copy())
        for output_title, output in outputs.items()
    ]
    results = await asyncio.gather(*tasks)

    # Update outputs with results and error flags
    for output_title, task_results, task_errors in results:
        outputs[output_title].update({
            'sql_query': task_results['sql_query'],
            'dataframe': task_results['dataframe'],
            'figure': task_results['figure'],
            'plot_information': task_results['plot_information'],
            'status': task_results['status'],
        })
        errors['have_sql_query'] |= task_errors['have_sql_query']
        errors['have_dataframe'] |= task_errors['have_dataframe']

    state['outputs'] = outputs

    # Set error messages if needed
    if not errors['have_sql_query']:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not errors['have_dataframe']:
        state['error'] = "There is no data in our table that can answer to your question"

    return state