# ui/callbacks.py
# -*- coding: utf-8 -*-
#
# PROJECT: CognitiveEDA v5.5 - The QuantumLeap Intelligence Platform
#
# DESCRIPTION: This module contains the core logic for all Gradio event handlers.
#              It is designed to be completely decoupled from the UI definition.
#              Functions here return values in a specific order (often as tuples)
#              that correspond to a list of output components defined in app.py.
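#
# For orientation, the two-phase analysis chain described above might be wired
# up in app.py roughly as follows. This is a sketch only; the component names
# (file_input, analyzer_state, main_outputs) are assumptions, not the actual
# identifiers used in app.py:
#
#     file_input.upload(
#         fn=run_initial_analysis, inputs=[file_input], outputs=[analyzer_state]
#     ).then(
#         fn=generate_reports_and_visuals, inputs=[analyzer_state], outputs=main_outputs
#     )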
import logging
from threading import Thread

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from core.analyzer import DataAnalyzer
from core.config import settings
from core.exceptions import DataProcessingError
from core.llm import GeminiNarrativeGenerator
from modules.clustering import perform_clustering

# --- Primary Analysis Chain ---

def run_initial_analysis(file_obj, progress=gr.Progress(track_tqdm=True)):
    """
    Phase 1: Fast, synchronous tasks.

    Validates inputs, loads data, and creates the core DataAnalyzer object.
    The return value is stored in a gr.State component, and the chained event
    then triggers the next phase.

    Args:
        file_obj: The uploaded file object from Gradio.
        progress: The Gradio progress tracker.

    Returns:
        The instantiated DataAnalyzer object.

    Raises:
        gr.Error: If no file is uploaded, the API key is missing, or
            processing fails.
    """
    if file_obj is None:
        raise gr.Error("No file uploaded. Please upload a CSV or Excel file.")

    progress(0, desc="Validating configuration...")
    if not settings.GOOGLE_API_KEY:
        logging.error("Analysis attempted without GOOGLE_API_KEY set.")
        raise gr.Error("CRITICAL: GOOGLE_API_KEY is not configured. Please add it as a secret.")

    try:
        progress(0.2, desc="Loading and parsing data file...")
        df = pd.read_csv(file_obj.name) if file_obj.name.endswith('.csv') else pd.read_excel(file_obj.name)
        if len(df) > settings.MAX_UI_ROWS:
            df = df.sample(n=settings.MAX_UI_ROWS, random_state=42)
            logging.info(f"DataFrame sampled down to {settings.MAX_UI_ROWS} rows.")

        progress(0.7, desc="Instantiating analysis engine...")
        analyzer = DataAnalyzer(df)
        progress(1.0, desc="Initial analysis complete. Generating reports...")
        return analyzer
    except Exception as e:
        logging.error(f"A critical error occurred during initial analysis: {e}", exc_info=True)
        raise gr.Error(f"Analysis Failed! An unexpected error occurred: {str(e)}")


def generate_reports_and_visuals(analyzer, progress=gr.Progress(track_tqdm=True)):
    """
    Phase 2: Slower, multi-stage report and visual generation.

    This generator function yields tuples of UI updates. The order of each
    yielded tuple is CRITICAL and must exactly match the `main_outputs` list
    in `app.py` (a sketch of that list's shape appears after this function).

    Args:
        analyzer: The DataAnalyzer object from the gr.State.
        progress: The Gradio progress tracker.

    Yields:
        A tuple of gr.update() objects to populate the dashboard.
    """
    if not isinstance(analyzer, DataAnalyzer):
        logging.warning("generate_reports_and_visuals called without a valid analyzer. Clearing UI.")
        # Yield a tuple of Nones matching the output length to clear/reset the UI.
        # There are 14 components in the `main_outputs` list in app.py.
        yield (None,) * 14
        return

    # 1. Start AI narrative generation in a background thread.
    progress(0, desc="Spawning AI report thread...")
    ai_report_queue = [""]  # Use a mutable list to pass the string by reference.

    def generate_ai_report_threaded(analyzer_instance):
        narrative_generator = GeminiNarrativeGenerator(api_key=settings.GOOGLE_API_KEY)
        ai_report_queue[0] = narrative_generator.generate_narrative(analyzer_instance)

    thread = Thread(target=generate_ai_report_threaded, args=(analyzer,))
    thread.start()

    # 2. Generate standard reports and visuals.
    progress(0.4, desc="Generating data profiles and visuals...")
    meta = analyzer.metadata
    missing_df, num_df, cat_df = analyzer.get_profiling_reports()
    fig_types, fig_missing, fig_corr = analyzer.get_overview_visuals()

    # 3. Yield the first set of updates to populate the main dashboard immediately.
    #    The order of this tuple MUST match the `main_outputs` list in `app.py`.
    initial_updates = (
        gr.update(value="⏳ Generating AI-powered report in the background... The main dashboard is ready now."),  # 0: ai_report_output
        gr.update(value=missing_df),    # 1: profile_missing_df
        gr.update(value=num_df),        # 2: profile_numeric_df
        gr.update(value=cat_df),        # 3: profile_categorical_df
        gr.update(value=fig_types),     # 4: plot_types
        gr.update(value=fig_missing),   # 5: plot_missing
        gr.update(value=fig_corr),      # 6: plot_correlation
        gr.update(choices=meta['numeric_cols'], value=meta['numeric_cols'][0] if meta['numeric_cols'] else None),  # 7: dd_hist_col
        gr.update(choices=meta['numeric_cols'], value=meta['numeric_cols'][0] if meta['numeric_cols'] else None),  # 8: dd_scatter_x
        gr.update(choices=meta['numeric_cols'], value=meta['numeric_cols'][1] if len(meta['numeric_cols']) > 1 else None),  # 9: dd_scatter_y
        gr.update(choices=meta['columns']),               # 10: dd_scatter_color
        gr.update(visible=bool(meta['datetime_cols'])),   # 11: tab_timeseries
        gr.update(visible=bool(meta['text_cols'])),       # 12: tab_text
        gr.update(visible=len(meta['numeric_cols']) > 1)  # 13: tab_cluster
    )
    yield initial_updates

    # 4. Wait for the AI thread to complete.
    thread.join()
    progress(1.0, desc="AI Report complete!")

    # 5. Yield the final update: copy the initial tuple into a list, swap in
    #    the finished AI report, and yield it back as a tuple.
    final_updates_list = list(initial_updates)
    final_updates_list[0] = gr.update(value=ai_report_queue[0])
    yield tuple(final_updates_list)
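
# For reference, the `main_outputs` list in app.py must line up, index for
# index, with the 14-tuples yielded above. A sketch of its shape (component
# names are assumptions inferred from the index comments above, not the
# actual identifiers in app.py):
#
#     main_outputs = [
#         ai_report_output,                                                # 0
#         profile_missing_df, profile_numeric_df, profile_categorical_df,  # 1-3
#         plot_types, plot_missing, plot_correlation,                      # 4-6
#         dd_hist_col, dd_scatter_x, dd_scatter_y, dd_scatter_color,       # 7-10
#         tab_timeseries, tab_text, tab_cluster,                           # 11-13
#     ]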


# --- Interactive Explorer Callbacks ---

def create_histogram(analyzer, col):
    """Generates a histogram for a selected numeric column."""
    if not isinstance(analyzer, DataAnalyzer) or not col:
        return go.Figure().update_layout(title="Select a column to generate a histogram")
    return px.histogram(analyzer.df, x=col, title=f"<b>Distribution of {col}</b>", marginal="box", template="plotly_white")


def create_scatterplot(analyzer, x_col, y_col, color_col):
    """Generates a scatter plot for selected X, Y, and optional color columns."""
    if not isinstance(analyzer, DataAnalyzer) or not x_col or not y_col:
        return go.Figure().update_layout(title="Select X and Y axes to generate a scatter plot")

    # Use a subset for performance on large datasets.
    df_sample = analyzer.df
    if len(analyzer.df) > 10000:
        df_sample = analyzer.df.sample(n=10000, random_state=42)

    return px.scatter(
        df_sample, x=x_col, y=y_col, color=color_col if color_col else None,
        title=f"<b>Scatter Plot: {x_col} vs. {y_col}</b>", template="plotly_white"
    )


# --- Specialized Module Callbacks ---

def update_clustering(analyzer, k):
    """
    Callback for the clustering module.

    Returns a 3-tuple of (cluster plot, elbow plot, summary markdown). If no
    analysis has been run yet, returns no-op updates plus a prompt message.
    """
    if not isinstance(analyzer, DataAnalyzer):
        return gr.update(), gr.update(), gr.update(value="Run analysis first.")

    # Delegate the heavy lifting to the specialized module.
    fig_cluster, fig_elbow, summary = perform_clustering(analyzer.df, analyzer.metadata['numeric_cols'], k)
    return fig_cluster, fig_elbow, summary
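
# Illustrative wiring for the explorer and clustering callbacks in app.py.
# Again a sketch only; the component names (plot_histogram, btn_scatter,
# slider_k, etc.) are assumptions, not the actual identifiers in app.py:
#
#     dd_hist_col.change(create_histogram, [analyzer_state, dd_hist_col], [plot_histogram])
#     btn_scatter.click(create_scatterplot,
#                       [analyzer_state, dd_scatter_x, dd_scatter_y, dd_scatter_color],
#                       [plot_scatter])
#     slider_k.change(update_clustering, [analyzer_state, slider_k],
#                     [plot_cluster, plot_elbow, md_cluster_summary])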