|
from climateqa.engine.talk_to_data.workflow.drias import drias_workflow |
|
from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow |
|
from climateqa.logging import log_drias_interaction_to_huggingface |
|
|
|
async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple: |
|
"""Main function to process a DRIAS query and return results. |
|
|
|
This function orchestrates the DRIAS workflow, processing a user query to generate |
|
SQL queries, dataframes, and visualizations. It handles multiple results and allows |
|
pagination through them. |
|
|
|
Args: |
|
query (str): The user's question about climate data |
|
index_state (int, optional): The index of the result to return. Defaults to 0. |
|
|
|
Returns: |
|
tuple: A tuple containing: |
|
- sql_query (str): The SQL query used |
|
- dataframe (pd.DataFrame): The resulting data |
|
- figure (Callable): Function to generate the visualization |
|
- sql_queries (list): All generated SQL queries |
|
- result_dataframes (list): All resulting dataframes |
|
- figures (list): All figure generation functions |
|
- index_state (int): Current result index |
|
- table_list (list): List of table names used |
|
- error (str): Error message if any |
|
""" |
|
final_state = await drias_workflow(query) |
|
sql_queries = [] |
|
result_dataframes = [] |
|
figures = [] |
|
plot_title_list = [] |
|
plot_informations = [] |
|
|
|
for output_title, output in final_state['outputs'].items(): |
|
if output['status'] == 'OK': |
|
if output['table'] is not None: |
|
plot_title_list.append(output_title) |
|
|
|
if output['plot_information'] is not None: |
|
plot_informations.append(output['plot_information']) |
|
|
|
if output['sql_query'] is not None: |
|
sql_queries.append(output['sql_query']) |
|
|
|
if output['dataframe'] is not None: |
|
result_dataframes.append(output['dataframe']) |
|
if output['figure'] is not None: |
|
figures.append(output['figure']) |
|
|
|
if "error" in final_state and final_state["error"] != "": |
|
|
|
return None, None, None, None, [], [], [], 0, [], final_state["error"] |
|
|
|
sql_query = sql_queries[index_state] |
|
dataframe = result_dataframes[index_state] |
|
figure = figures[index_state](dataframe) |
|
plot_information = plot_informations[index_state] |
|
|
|
|
|
log_drias_interaction_to_huggingface(query, sql_query, user_id) |
|
|
|
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, "" |
|
|
|
|
|
|
|
async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple: |
|
"""Main function to process a DRIAS query and return results. |
|
|
|
This function orchestrates the DRIAS workflow, processing a user query to generate |
|
SQL queries, dataframes, and visualizations. It handles multiple results and allows |
|
pagination through them. |
|
|
|
Args: |
|
query (str): The user's question about climate data |
|
index_state (int, optional): The index of the result to return. Defaults to 0. |
|
|
|
Returns: |
|
tuple: A tuple containing: |
|
- sql_query (str): The SQL query used |
|
- dataframe (pd.DataFrame): The resulting data |
|
- figure (Callable): Function to generate the visualization |
|
- sql_queries (list): All generated SQL queries |
|
- result_dataframes (list): All resulting dataframes |
|
- figures (list): All figure generation functions |
|
- index_state (int): Current result index |
|
- table_list (list): List of table names used |
|
- error (str): Error message if any |
|
""" |
|
final_state = await ipcc_workflow(query) |
|
sql_queries = [] |
|
result_dataframes = [] |
|
figures = [] |
|
plot_title_list = [] |
|
plot_informations = [] |
|
|
|
for output_title, output in final_state['outputs'].items(): |
|
if output['status'] == 'OK': |
|
if output['table'] is not None: |
|
plot_title_list.append(output_title) |
|
|
|
if output['plot_information'] is not None: |
|
plot_informations.append(output['plot_information']) |
|
|
|
if output['sql_query'] is not None: |
|
sql_queries.append(output['sql_query']) |
|
|
|
if output['dataframe'] is not None: |
|
result_dataframes.append(output['dataframe']) |
|
if output['figure'] is not None: |
|
figures.append(output['figure']) |
|
|
|
if "error" in final_state and final_state["error"] != "": |
|
|
|
return None, None, None, None, [], [], [], 0, [], final_state["error"] |
|
|
|
sql_query = sql_queries[index_state] |
|
dataframe = result_dataframes[index_state] |
|
figure = figures[index_state](dataframe) |
|
plot_information = plot_informations[index_state] |
|
|
|
log_drias_interaction_to_huggingface(query, sql_query, user_id) |
|
|
|
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, "" |
|
|