File size: 5,717 Bytes
711bc31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os

from typing import Any
import asyncio
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS

# Directory two levels above the current working directory. The value is computed
# from os.getcwd() when this module is imported, not from this file's location, so
# it depends on where the process was started.
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))

async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Process one (plot, table) pair: build the SQL query, execute it, and prepare
    the figure entry for the output.

    The caller's ``params`` dict is left untouched; the indicator column is added
    to a private shallow copy that is then handed to the plot callables.

    Args:
        output_title (str): Title for the output (used as key in outputs dict).
            It is returned unchanged so the caller can match results to outputs.
        table (str): The name of the table to process.
        plot (Plot): The plot object; its 'sql_query', 'plot_information' and
            'plot_function' entries are called here.
        params (dict[str, Any]): Parameters used for querying the table.

    Returns:
        tuple: (output_title, results dict, flags dict)
            - results['status'] is 'OK', 'ERROR' (the plot produced no SQL query)
              or 'NO_DATA' (the query returned None or an empty result).
            - the flags dict holds 'have_sql_query' and 'have_dataframe'; each is
              True when that step succeeded.
    """
    # Shallow copy so that adding 'indicator_column' below does not leak into the
    # dict owned by the caller (which may be shared between several outputs).
    query_params = dict(params)

    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None
    }
    errors = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Find the indicator column for this table
    indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
    if indicator_column:
        query_params['indicator_column'] = indicator_column

    # Build the SQL query; without one nothing else can be produced, so the
    # remaining result fields stay None.
    sql_query = plot['sql_query'](table, query_params)
    if not sql_query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['plot_information'] = plot['plot_information'](table, query_params)

    results['sql_query'] = sql_query
    errors['have_sql_query'] = True

    # Execute the SQL query
    df = await execute_sql_query(sql_query)
    if df is not None and len(df) > 0:
        results['dataframe'] = df
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # Generate the figure entry (always, even if df is empty, for consistency).
    # Note that plot_function only receives the parameters, not the dataframe.
    results['figure'] = plot['plot_function'](query_params)

    return output_title, results, errors

async def drias_workflow(user_input: str) -> State:
    """
    Orchestrates the DRIAS workflow: from user input to SQL queries, dataframes, and figures.

    Steps:
        1. Ask the LLM which plots from DRIAS_PLOTS are relevant (at most two are kept).
        2. For each kept plot, find the relevant tables and register one output per
           (plot, table) pair.
        3. Collect the query parameters listed in DRIAS_PLOT_PARAMETERS.
        4. Run ``process_output`` for every registered output concurrently.
        5. Merge the results into the state and set an error message if a stage
           produced nothing.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state with all results and error messages if any. ``state['error']``
        stays an empty string when at least one output yielded a dataframe.

    Note:
        ``asyncio.gather`` is used with its default settings, so an exception raised
        by any ``process_output`` call propagates out of this function.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")
    plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)

    if not plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    plots = plots[:2]  # limit to 2 types of plots
    state['plots'] = plots

    # Success flags for each stage; they select the error message at the end.
    errors = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False
    }
    outputs = {}

    # Find relevant tables for each plot and prepare outputs
    for plot_name in plots:
        plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None)
        if plot is None:
            # The selected name does not match any known plot definition.
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
        if not relevant_tables:
            # Nothing to register for this plot; also avoids iterating over None.
            continue
        errors['have_relevant_table'] = True

        for table in relevant_tables:
            # e.g. a table named "some_table_name" is displayed as "Some table name"
            readable_table = table.capitalize().replace('_', ' ')
            output_title = f"{plot['short_name']} - {readable_table}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Gather all required parameters; later lookups overwrite earlier keys on clash
    params = {}
    for param_name in DRIAS_PLOT_PARAMETERS:
        param = await find_param(state, param_name, mode='DRIAS')
        if param:
            params.update(param)

    # Process all outputs in parallel using process_output; each task gets its own
    # copy of the parameters so tasks cannot interfere with one another.
    tasks = [
        process_output(output_title, output['table'], output['plot'], params.copy())
        for output_title, output in outputs.items()
    ]
    results = await asyncio.gather(*tasks)

    # Update outputs with results and error flags
    for output_title, task_results, task_errors in results:
        output = outputs[output_title]
        output['sql_query'] = task_results['sql_query']
        output['dataframe'] = task_results['dataframe']
        output['figure'] = task_results['figure']
        output['plot_information'] = task_results['plot_information']
        output['status'] = task_results['status']
        # A stage counts as successful as soon as one output succeeded at it.
        errors['have_sql_query'] |= task_errors['have_sql_query']
        errors['have_dataframe'] |= task_errors['have_dataframe']

    state['outputs'] = outputs

    # Set error messages if needed (the earliest failing stage wins)
    if not errors['have_relevant_table']:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not errors['have_sql_query']:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not errors['have_dataframe']:
        state['error'] = "There is no data in our table that can answer to your question"

    return state