# ui/callbacks.py
# -*- coding: utf-8 -*-
#
# PROJECT: CognitiveEDA v5.7 - The QuantumLeap Intelligence Platform
#
# DESCRIPTION: This module contains the core logic for all Gradio event handlers.
# The clustering callback is now updated to include persona profiling.
import gradio as gr | |
import pandas as pd | |
import logging | |
from threading import Thread | |
import plotly.graph_objects as go | |
import plotly.express as px | |
from core.analyzer import DataAnalyzer, engineer_features | |
from core.llm import GeminiNarrativeGenerator | |
from core.config import settings | |
from core.exceptions import DataProcessingError | |
from modules.clustering import perform_clustering | |
# --- NEW IMPORT --- | |
from modules.profiling import profile_clusters | |
# --- Primary Analysis Chain (Unchanged) ---
def run_initial_analysis(file_obj, progress=gr.Progress(track_tqdm=True)):
    """Load an uploaded dataset, engineer features, and build the analyzer.

    Args:
        file_obj: Uploaded file handle; its ``name`` attribute is used as the
            path to read. Paths ending in ``.csv`` (any letter case) are read
            with ``pd.read_csv``; everything else goes to ``pd.read_excel``.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A ``DataAnalyzer`` built from the feature-engineered dataframe.

    Raises:
        gr.Error: If no file was uploaded, if ``settings.GOOGLE_API_KEY`` is
            not set, or if loading / feature engineering / analyzer
            construction fails (the original exception is chained).
    """
    if file_obj is None:
        raise gr.Error("No file uploaded.")

    progress(0, desc="Validating configuration...")
    if not settings.GOOGLE_API_KEY:
        raise gr.Error("CRITICAL: GOOGLE_API_KEY is not configured.")

    try:
        progress(0.1, desc="Loading raw data...")
        path = file_obj.name
        # Compare the extension case-insensitively so that e.g. "DATA.CSV"
        # is not mistakenly handed to the Excel reader.
        if path.lower().endswith('.csv'):
            df_raw = pd.read_csv(path)
        else:
            df_raw = pd.read_excel(path)

        # Cap the number of rows handled by the UI; the fixed seed keeps the
        # down-sampled subset reproducible between runs.
        if len(df_raw) > settings.MAX_UI_ROWS:
            df_raw = df_raw.sample(n=settings.MAX_UI_ROWS, random_state=42)

        progress(0.5, desc="Applying strategic feature engineering...")
        df_engineered = engineer_features(df_raw)

        progress(0.8, desc="Instantiating analysis engine...")
        analyzer = DataAnalyzer(df_engineered)

        progress(1.0, desc="Analysis complete. Generating reports...")
        return analyzer
    except Exception as e:
        # Top-level UI boundary: log the full traceback and surface a
        # user-visible error while preserving the original cause.
        logging.error("Error in initial analysis: %s", e, exc_info=True)
        raise gr.Error(f"Analysis Failed: {str(e)}") from e
def generate_reports_and_visuals(analyzer, progress=gr.Progress(track_tqdm=True)):
    """Yield UI updates for the profiling reports, visuals, and AI narrative.

    This is a generator that yields a 14-element tuple up to two times:

    1. Immediately, with the profiling tables, overview figures, dropdown
       choices and visibility flags, plus a placeholder for the AI report.
    2. After the background narrative thread finishes, the same tuple with
       the first element replaced by the generated narrative.

    If ``analyzer`` is not a ``DataAnalyzer``, a single tuple of 14 ``None``
    values is yielded instead.

    Args:
        analyzer: The ``DataAnalyzer`` produced by the initial analysis.
        progress: Gradio progress tracker used to report each stage.

    Yields:
        Tuples of 14 values, one per output component.
    """
    if not isinstance(analyzer, DataAnalyzer):
        yield (None,) * 14
        return

    progress(0, desc="Spawning AI report thread...")
    # Single-slot list used as a mutable cell the worker thread writes into.
    ai_report_queue = [""]

    def generate_ai_report_threaded(a):
        # An exception escaping a thread never reaches the caller, which
        # would leave the report silently empty; catch it here so the user
        # sees that generation failed.
        try:
            narrative_generator = GeminiNarrativeGenerator(settings.GOOGLE_API_KEY)
            ai_report_queue[0] = narrative_generator.generate_narrative(a)
        except Exception as exc:
            logging.error("AI report generation failed: %s", exc, exc_info=True)
            ai_report_queue[0] = f"❌ AI report generation failed: {exc}"

    thread = Thread(target=generate_ai_report_threaded, args=(analyzer,))
    thread.start()

    progress(0.4, desc="Generating reports and visuals...")
    meta = analyzer.metadata
    numeric_cols = meta['numeric_cols']
    missing_df, num_df, cat_df = analyzer.get_profiling_reports()
    fig_types, fig_missing, fig_corr = analyzer.get_overview_visuals()

    first_numeric = numeric_cols[0] if numeric_cols else None
    second_numeric = numeric_cols[1] if len(numeric_cols) > 1 else None

    initial_updates = (
        gr.update(value="⏳ Generating AI report..."),
        gr.update(value=missing_df),
        gr.update(value=num_df),
        gr.update(value=cat_df),
        gr.update(value=fig_types),
        gr.update(value=fig_missing),
        gr.update(value=fig_corr),
        gr.update(choices=numeric_cols, value=first_numeric),
        gr.update(choices=numeric_cols, value=first_numeric),
        gr.update(choices=numeric_cols, value=second_numeric),
        gr.update(choices=meta['columns']),
        gr.update(visible=bool(meta['datetime_cols'])),
        gr.update(visible=bool(meta['text_cols'])),
        gr.update(visible=len(numeric_cols) > 1),
    )
    yield initial_updates

    # Block until the narrative is ready, then swap it in for the placeholder.
    thread.join()
    progress(1.0, desc="AI Report complete!")
    final_updates_list = list(initial_updates)
    final_updates_list[0] = gr.update(value=ai_report_queue[0])
    yield tuple(final_updates_list)
# --- Interactive Explorer Callbacks (Unchanged) ---
def create_histogram(analyzer, col):
    """Build a histogram (with a box-plot marginal) for one column.

    Returns an empty figure when ``analyzer`` is not a ``DataAnalyzer`` or
    when no column has been selected.
    """
    if isinstance(analyzer, DataAnalyzer) and col:
        chart_title = f"<b>Distribution of {col}</b>"
        return px.histogram(analyzer.df, x=col, title=chart_title, marginal="box")
    return go.Figure()
def create_scatterplot(analyzer, x_col, y_col, color_col):
    """Build a scatter plot of ``y_col`` against ``x_col``.

    At most 10,000 rows are plotted. The sample is drawn with a fixed seed so
    that the same points are shown every time the callback fires (for
    example when only the colour column changes), matching the seeded
    down-sampling used when the data is first loaded.

    Args:
        analyzer: The ``DataAnalyzer`` holding the dataframe to plot.
        x_col: Column name for the x axis.
        y_col: Column name for the y axis.
        color_col: Optional column name used to colour the points; any falsy
            value disables colouring.

    Returns:
        The scatter figure, or an empty figure when ``analyzer`` is not a
        ``DataAnalyzer`` or either axis column is missing.
    """
    if not isinstance(analyzer, DataAnalyzer) or not x_col or not y_col:
        return go.Figure()

    df = analyzer.df
    # random_state pins the subset so the plot does not reshuffle on every
    # UI interaction.
    df_sample = df.sample(n=min(len(df), 10000), random_state=42)
    return px.scatter(df_sample, x=x_col, y=y_col, color=color_col if color_col else None)
# --- MODIFIED CLUSTERING CALLBACK ---
def update_clustering(analyzer, k):
    """Run clustering and persona profiling, returning five UI outputs.

    Workflow:
      1. Call ``perform_clustering`` on the analyzer's dataframe and numeric
         columns with ``k`` to obtain two figures, a summary and the labels.
      2. If the labels are empty, return early with a failure message.
      3. Otherwise call ``profile_clusters`` on a fixed set of candidate
         columns, restricted to those present in the dataframe.

    Returns:
        A 5-tuple ``(fig_cluster, fig_elbow, summary, md_personas,
        fig_profile)``. When ``analyzer`` is not a ``DataAnalyzer`` the tuple
        holds empty figures and empty strings.
    """
    # Without a valid analyzer, hand back blanks for all five outputs.
    if not isinstance(analyzer, DataAnalyzer):
        return go.Figure(), go.Figure(), "", "", go.Figure()

    # Clustering yields the visuals, the text summary, and per-row labels.
    clustering_result = perform_clustering(
        analyzer.df, analyzer.metadata['numeric_cols'], k
    )
    fig_cluster, fig_elbow, summary, cluster_labels = clustering_result

    # No labels means there is nothing to profile.
    if cluster_labels.empty:
        failure_note = "Clustering failed. No personas to profile."
        return fig_cluster, fig_elbow, summary, failure_note, go.Figure()

    # Keep only the candidate columns that exist in the engineered dataframe.
    frame = analyzer.df
    available = frame.columns
    numeric_candidates = ('Total_Revenue', 'Quantity_Ordered', 'Hour')
    categorical_candidates = ('City', 'Product', 'Day_of_Week')
    numeric_to_profile = [name for name in numeric_candidates if name in available]
    cats_to_profile = [name for name in categorical_candidates if name in available]

    md_personas, fig_profile = profile_clusters(
        frame, cluster_labels, numeric_to_profile, cats_to_profile
    )

    # Order matters: it must match the five output components in the UI.
    return fig_cluster, fig_elbow, summary, md_personas, fig_profile