CognitiveEDA / ui /callbacks.py
mgbam's picture
Update ui/callbacks.py
fe02df7 verified
raw
history blame
5.73 kB
# ui/callbacks.py
# -*- coding: utf-8 -*-
#
# PROJECT: CognitiveEDA v5.6 - The QuantumLeap Intelligence Platform
#
# DESCRIPTION: This module contains the core logic for all Gradio event handlers.
# The main analysis pipeline now includes a strategic feature
# engineering step before analysis.
import gradio as gr
import pandas as pd
import logging
from threading import Thread
import plotly.graph_objects as go
import plotly.express as px
# --- MODIFIED IMPORT ---
# Import both the analyzer class and the new feature engineering function
from core.analyzer import DataAnalyzer, engineer_features
from core.llm import GeminiNarrativeGenerator
from core.config import settings
from core.exceptions import DataProcessingError
from modules.clustering import perform_clustering
# --- Primary Analysis Chain ---
def run_initial_analysis(file_obj, progress=gr.Progress(track_tqdm=True)):
    """
    Phase 1: Load the uploaded file, apply feature engineering, build the analyzer.

    Validates inputs, loads the raw data, optionally down-samples it to
    ``settings.MAX_UI_ROWS`` rows, runs ``engineer_features`` on it, and then
    creates the core ``DataAnalyzer`` object on the transformed data.

    Args:
        file_obj: Uploaded-file object; its ``name`` attribute is used as the
            path to read. Paths ending in ``.csv`` (any letter case) are read
            with ``pd.read_csv``; everything else goes to ``pd.read_excel``.
        progress: Gradio progress tracker used to report each pipeline stage.

    Returns:
        The ``DataAnalyzer`` instance built on the engineered DataFrame.

    Raises:
        gr.Error: If no file was uploaded, if ``settings.GOOGLE_API_KEY`` is
            falsy, or if any step of loading / engineering / analyzer
            construction raises.
    """
    if file_obj is None:
        raise gr.Error("No file uploaded. Please upload a CSV or Excel file.")

    progress(0, desc="Validating configuration...")
    if not settings.GOOGLE_API_KEY:
        logging.error("Analysis attempted without GOOGLE_API_KEY set.")
        raise gr.Error("CRITICAL: GOOGLE_API_KEY is not configured. Please add it as a secret.")

    try:
        progress(0.1, desc="Loading raw data...")
        file_path = file_obj.name
        # Compare the extension case-insensitively so that e.g. "DATA.CSV"
        # is not mistakenly handed to the Excel reader.
        if file_path.lower().endswith('.csv'):
            df_raw = pd.read_csv(file_path)
        else:
            df_raw = pd.read_excel(file_path)

        # Cap the row count; the fixed seed keeps the sample reproducible.
        if len(df_raw) > settings.MAX_UI_ROWS:
            df_raw = df_raw.sample(n=settings.MAX_UI_ROWS, random_state=42)
            logging.info(f"DataFrame sampled down to {settings.MAX_UI_ROWS} rows.")

        # Apply the feature engineering function immediately after loading.
        progress(0.5, desc="Applying strategic feature engineering...")
        df_engineered = engineer_features(df_raw)

        # The analyzer works with the transformed dataset, not the raw one.
        progress(0.8, desc="Instantiating analysis engine on engineered data...")
        analyzer = DataAnalyzer(df_engineered)

        progress(1.0, desc="Analysis complete. Generating reports...")
        return analyzer
    except Exception as e:
        # Top-level UI boundary: log the full traceback, then surface a
        # user-facing error while keeping the original exception as the cause.
        logging.error(f"A critical error occurred during initial analysis: {e}", exc_info=True)
        raise gr.Error(f"Analysis Failed! An unexpected error occurred: {str(e)}") from e
def generate_reports_and_visuals(analyzer, progress=gr.Progress(track_tqdm=True)):
    """
    Phase 2: Slower, multi-stage report and visual generation.

    This is a generator. It yields 14-element tuples of UI updates based on
    the data held by ``analyzer``:

    1. A first tuple with the profiling tables, overview figures and
       control updates, plus a placeholder in the AI-report slot.
    2. A second tuple, identical except that the AI-report slot holds the
       generated narrative (or a failure notice if generation raised).

    The AI narrative is produced on a background thread so the profiling
    reports and figures can be computed and shown while it is running.

    Args:
        analyzer: Expected to be a ``DataAnalyzer``. Any other value yields
            a single tuple of 14 ``None`` values and stops.
        progress: Gradio progress tracker used to report each stage.

    Yields:
        Tuples of 14 UI updates, in the order built below.
    """
    if not isinstance(analyzer, DataAnalyzer):
        logging.warning("generate_reports_and_visuals called without a valid analyzer. Clearing UI.")
        yield (None,) * 14
        return

    progress(0, desc="Spawning AI report thread...")
    # One-element list used as a mutable slot the worker thread writes into.
    ai_report_queue = [""]

    def generate_ai_report_threaded(analyzer_instance):
        # Catch failures here: an exception escaping a thread would otherwise
        # leave the report slot empty with no explanation shown in the UI.
        try:
            narrative_generator = GeminiNarrativeGenerator(api_key=settings.GOOGLE_API_KEY)
            ai_report_queue[0] = narrative_generator.generate_narrative(analyzer_instance)
        except Exception as exc:
            logging.error("AI report generation failed: %s", exc, exc_info=True)
            ai_report_queue[0] = f"⚠️ AI report generation failed: {exc}"

    thread = Thread(target=generate_ai_report_threaded, args=(analyzer,))
    thread.start()

    progress(0.4, desc="Generating reports and visuals...")
    meta = analyzer.metadata
    numeric_cols = meta['numeric_cols']
    missing_df, num_df, cat_df = analyzer.get_profiling_reports()
    fig_types, fig_missing, fig_corr = analyzer.get_overview_visuals()

    # Default selections: first numeric column, or None when there are none;
    # the third selector defaults to the second numeric column when present.
    first_numeric = numeric_cols[0] if numeric_cols else None
    second_numeric = numeric_cols[1] if len(numeric_cols) > 1 else None

    initial_updates = (
        gr.update(value="⏳ Generating AI report... Dashboard is ready."),
        gr.update(value=missing_df),
        gr.update(value=num_df),
        gr.update(value=cat_df),
        gr.update(value=fig_types),
        gr.update(value=fig_missing),
        gr.update(value=fig_corr),
        gr.update(choices=numeric_cols, value=first_numeric),
        gr.update(choices=numeric_cols, value=first_numeric),
        gr.update(choices=numeric_cols, value=second_numeric),
        gr.update(choices=meta['columns']),
        gr.update(visible=bool(meta['datetime_cols'])),
        gr.update(visible=bool(meta['text_cols'])),
        gr.update(visible=len(numeric_cols) > 1),
    )
    yield initial_updates

    # Block until the narrative (or its failure notice) is available.
    thread.join()
    progress(1.0, desc="AI Report complete!")

    # Re-emit the same updates with only the AI-report slot replaced.
    final_updates_list = list(initial_updates)
    final_updates_list[0] = gr.update(value=ai_report_queue[0])
    yield tuple(final_updates_list)
# --- Interactive Explorer & Module Callbacks ---
def create_histogram(analyzer, col):
    """
    Build a histogram (with a marginal box plot) for one column.

    Returns an empty ``go.Figure`` when ``analyzer`` is not a
    ``DataAnalyzer`` or when ``col`` is falsy.
    """
    if isinstance(analyzer, DataAnalyzer) and col:
        chart_title = f"<b>Distribution of {col}</b>"
        return px.histogram(analyzer.df, x=col, title=chart_title, marginal="box")
    return go.Figure()
def create_scatterplot(analyzer, x_col, y_col, color_col):
    """
    Build a scatter plot of ``y_col`` against ``x_col``.

    At most 10000 randomly sampled rows of the analyzer's DataFrame are
    plotted. ``color_col`` is passed through as the colour dimension when
    truthy, otherwise ``None`` is used. Returns an empty ``go.Figure`` when
    ``analyzer`` is not a ``DataAnalyzer`` or either axis column is falsy.
    """
    if isinstance(analyzer, DataAnalyzer) and x_col and y_col:
        row_limit = min(len(analyzer.df), 10000)
        plotted_rows = analyzer.df.sample(n=row_limit)
        return px.scatter(plotted_rows, x=x_col, y=y_col, color=color_col or None)
    return go.Figure()
def update_clustering(analyzer, k):
    """
    Re-run clustering for the requested ``k`` and return its three outputs.

    Delegates to ``perform_clustering`` with the analyzer's DataFrame and its
    numeric column list, returning the cluster figure, elbow figure and
    summary in that order. When ``analyzer`` is not a ``DataAnalyzer``, three
    no-op ``gr.update()`` objects are returned instead.
    """
    if isinstance(analyzer, DataAnalyzer):
        frame = analyzer.df
        numeric_columns = analyzer.metadata['numeric_cols']
        cluster_fig, elbow_fig, cluster_summary = perform_clustering(frame, numeric_columns, k)
        return cluster_fig, elbow_fig, cluster_summary
    return tuple(gr.update() for _ in range(3))