File size: 7,210 Bytes
d1943e0
12fa967
 
 
57f807e
12fa967
57f807e
 
 
12fa967
d1943e0
 
 
 
 
7d40c30
 
 
fe02df7
d1943e0
 
 
1dae368
7d40c30
57f807e
7d40c30
 
1dae368
fcc261b
1dae368
fcc261b
fe02df7
 
 
 
 
 
1dae368
fe02df7
 
7d40c30
fcc261b
1dae368
 
fcc261b
 
57f807e
 
 
7d40c30
57f807e
a00699a
57f807e
7d40c30
fe02df7
57f807e
fcc261b
 
57f807e
fe02df7
fcc261b
 
 
57f807e
 
 
 
 
 
 
 
 
a00699a
1dae368
 
 
57f807e
 
 
 
 
 
a00699a
fcc261b
57f807e
fcc261b
 
a00699a
 
 
fcc261b
57f807e
7d40c30
57f807e
 
 
 
 
 
 
 
 
 
1dae368
57f807e
1dae368
57f807e
 
1dae368
7d40c30
1dae368
 
57f807e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dae368
57f807e
1dae368
 
 
57f807e
1dae368
57f807e
 
 
1dae368
 
57f807e
1dae368
57f807e
 
 
 
 
 
1dae368
57f807e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# ui/callbacks.py

# -*- coding: utf-8 -*-
#
# PROJECT:      CognitiveEDA v5.9 - The QuantumLeap Intelligence Platform
#
# DESCRIPTION:  This module is updated with a generic, data-agnostic
#               stratification engine. It dynamically identifies candidate
#               features for filtering and updates the UI accordingly.

import gradio as gr
import pandas as pd
import logging
from threading import Thread

import plotly.graph_objects as go
import plotly.express as px

from core.analyzer import DataAnalyzer, engineer_features
from core.llm import GeminiNarrativeGenerator
from core.config import settings
from modules.clustering import perform_clustering
from modules.profiling import profile_clusters

# --- Primary Analysis Chain ---

def run_initial_analysis(file_obj, progress=gr.Progress(track_tqdm=True)):
    """
    Phase 1: Loads the uploaded file, applies feature engineering and
    instantiates the analysis engine.

    Args:
        file_obj: Uploaded file handle; only its ``name`` attribute (a path) is read.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A DataAnalyzer built from the engineered DataFrame.

    Raises:
        gr.Error: If no file was uploaded, GOOGLE_API_KEY is not configured,
            or any loading / engineering / instantiation step fails.
    """
    if file_obj is None:
        raise gr.Error("No file uploaded.")
    progress(0, desc="Validating configuration...")
    if not settings.GOOGLE_API_KEY:
        raise gr.Error("CRITICAL: GOOGLE_API_KEY is not configured.")
    try:
        progress(0.1, desc="Loading raw data...")
        path = file_obj.name
        # Compare the extension case-insensitively so e.g. 'DATA.CSV' is not
        # mistakenly routed to the Excel reader.
        if path.lower().endswith('.csv'):
            df_raw = pd.read_csv(path)
        else:
            df_raw = pd.read_excel(path)
        # Cap the number of rows; the fixed seed keeps the sample reproducible.
        if len(df_raw) > settings.MAX_UI_ROWS:
            df_raw = df_raw.sample(n=settings.MAX_UI_ROWS, random_state=42)
        progress(0.5, desc="Applying strategic feature engineering...")
        df_engineered = engineer_features(df_raw)
        progress(0.8, desc="Instantiating analysis engine...")
        analyzer = DataAnalyzer(df_engineered)
        progress(1.0, desc="Analysis complete. Generating reports...")
        return analyzer
    except Exception as e:
        # Top-level UI boundary: log the full traceback, then surface a
        # user-facing error while keeping the original cause chained.
        logging.error(f"Error in initial analysis: {e}", exc_info=True)
        raise gr.Error(f"Analysis Failed: {str(e)}") from e

def generate_reports_and_visuals(analyzer, progress=gr.Progress(track_tqdm=True)):
    """
    Phase 2: Generates profiling reports and overview visuals, and populates
    the generic 'Stratify By' dropdown with candidate columns.

    This is a generator. It yields a 15-element tuple of UI updates twice:
    first with a placeholder for the AI report (so the tables and figures
    appear immediately), then again once the background AI thread finishes,
    with only the first element replaced by the report text.

    Args:
        analyzer: The DataAnalyzer produced by the initial analysis. Any other
            value yields a tuple of 15 ``None`` and stops.
        progress: Gradio progress tracker used to report each stage.
    """
    if not isinstance(analyzer, DataAnalyzer):
        yield (None,) * 15
        return

    progress(0, desc="Spawning AI report thread...")
    # One-element list used as a mutable slot the worker thread writes into.
    ai_report_queue = [""]

    def generate_ai_report_threaded(a):
        # An exception escaping a thread target never reaches the caller, so
        # without this handler a failed report would silently show as blank.
        try:
            ai_report_queue[0] = GeminiNarrativeGenerator(settings.GOOGLE_API_KEY).generate_narrative(a)
        except Exception as e:
            logging.error(f"AI report generation failed: {e}", exc_info=True)
            ai_report_queue[0] = f"AI report generation failed: {e}"

    thread = Thread(target=generate_ai_report_threaded, args=(analyzer,))
    thread.start()

    progress(0.4, desc="Generating reports and visuals...")
    meta = analyzer.metadata
    missing_df, num_df, cat_df = analyzer.get_profiling_reports()
    fig_types, fig_missing, fig_corr = analyzer.get_overview_visuals()

    # --- Dynamically identify candidate columns for stratification ---
    # Heuristic: a good candidate has more than 1 but fewer than 50 unique values.
    # NOTE(review): columns whose dtype is not 'object' are accepted regardless
    # of cardinality -- confirm this is intended.
    candidate_cols = ["(Do not stratify)"]
    for col in meta.get('categorical_cols', []):
        series = analyzer.df[col]
        if series.dtype.name != 'object' or (1 < series.nunique() < 50):
            candidate_cols.append(col)

    numeric_cols = meta.get('numeric_cols', [])
    initial_updates = (
        gr.update(value="⏳ Generating AI report..."), gr.update(value=missing_df),
        gr.update(value=num_df), gr.update(value=cat_df), gr.update(value=fig_types),
        gr.update(value=fig_missing), gr.update(value=fig_corr),
        gr.update(choices=numeric_cols),
        gr.update(choices=numeric_cols),
        gr.update(choices=numeric_cols),
        gr.update(choices=meta.get('columns', [])), gr.update(visible=bool(meta.get('datetime_cols'))),
        gr.update(visible=bool(meta.get('text_cols'))), gr.update(visible=len(numeric_cols) > 1),
        gr.update(choices=candidate_cols, value="(Do not stratify)")  # dd_stratify_by_col
    )
    yield initial_updates

    # Block until the AI narrative (or its failure message) is available.
    thread.join()
    progress(1.0, desc="AI Report complete!")
    final_updates_list = list(initial_updates)
    final_updates_list[0] = gr.update(value=ai_report_queue[0])
    yield tuple(final_updates_list)

# --- Stratification Callbacks ---

def update_filter_dropdown(analyzer, stratify_col):
    """
    When the user selects a feature to stratify by, this function populates
    the second dropdown with the unique values of that feature.

    Args:
        analyzer: The current DataAnalyzer (anything else disables the dropdown).
        stratify_col: Name of the column chosen for stratification, or the
            "(Do not stratify)" sentinel.

    Returns:
        A Gradio update. The dropdown is cleared and disabled when there is
        nothing to stratify by; otherwise it lists "(Global Analysis)"
        followed by the column's distinct non-missing values, sorted.
    """
    disabled = gr.update(choices=[], value=None, interactive=False)
    if not isinstance(analyzer, DataAnalyzer) or not stratify_col or stratify_col == "(Do not stratify)":
        return disabled
    # A selection left over from a previous dataset may name a column that
    # no longer exists; treat it like "no stratification" instead of raising.
    if stratify_col not in analyzer.df.columns:
        return disabled

    # Missing values are dropped: an equality filter on NaN never matches any
    # row, and NaN (a float) cannot be ordered against strings by sorted().
    unique_values = analyzer.df[stratify_col].dropna().unique().tolist()
    try:
        ordered = sorted(unique_values)
    except TypeError:
        # Mixed, mutually unorderable types: order by string representation.
        ordered = sorted(unique_values, key=str)

    values = ["(Global Analysis)"] + ordered
    return gr.update(choices=values, value="(Global Analysis)", interactive=True)

def update_stratified_clustering(analyzer, stratify_col, filter_value, k):
    """
    Orchestrates the full clustering workflow on a dataset that is generically
    filtered based on user selections.

    Args:
        analyzer: The current DataAnalyzer; any other value returns empty outputs.
        stratify_col: Column to filter on, or "(Do not stratify)" / empty for none.
        filter_value: Value of ``stratify_col`` to keep, or "(Global Analysis)" /
            ``None`` / empty string for no filtering.
        k: Requested number of clusters.

    Returns:
        A 5-tuple ``(fig_cluster, fig_elbow, summary, md_personas, fig_profile)``.
        When there are fewer rows than ``k``, both text slots carry an
        explanatory message and the figures are empty.
    """
    if not isinstance(analyzer, DataAnalyzer):
        return go.Figure(), go.Figure(), "", "", go.Figure()

    logging.info(f"Updating clustering. Stratify by: '{stratify_col}', Filter: '{filter_value}', K={k}")

    # Step 1: Stratify the DataFrame based on user selection
    analysis_df = analyzer.df
    report_title_prefix = "Global Analysis: "

    has_stratify_col = bool(stratify_col) and stratify_col != "(Do not stratify)"
    # Compare against explicit "no selection" sentinels instead of relying on
    # truthiness: legitimate category values such as 0 or False are falsy and
    # would otherwise silently fall back to the global analysis.
    has_filter_value = (
        filter_value is not None
        and filter_value != ""
        and filter_value != "(Global Analysis)"
    )
    if has_stratify_col and has_filter_value:
        analysis_df = analyzer.df[analyzer.df[stratify_col] == filter_value]
        report_title_prefix = f"Analysis for '{stratify_col}' = '{filter_value}': "

    if len(analysis_df) < k:
        error_msg = f"Not enough data ({len(analysis_df)} rows) to form {k} clusters for the selected filter."
        return go.Figure(), go.Figure(), error_msg, error_msg, go.Figure()

    # Step 2: Perform Clustering
    numeric_cols = [c for c in analyzer.metadata['numeric_cols'] if c in analysis_df.columns]
    fig_cluster, fig_elbow, summary, cluster_labels = perform_clustering(
        analysis_df, numeric_cols, k
    )

    if cluster_labels.empty:
        return fig_cluster, fig_elbow, summary, "Clustering failed.", go.Figure()

    # Step 3: Profile the resulting clusters
    cats_to_profile = [c for c in analyzer.metadata['categorical_cols'] if c in analysis_df.columns]
    # These column names are excluded from the numeric profile.
    # NOTE(review): presumably engineered calendar features -- confirm against engineer_features.
    numeric_to_profile = [c for c in numeric_cols if c not in ['Month', 'Day_of_Week', 'Is_Weekend', 'Hour']]

    md_personas, fig_profile = profile_clusters(
        analysis_df, cluster_labels, numeric_to_profile, cats_to_profile
    )

    summary = f"**{report_title_prefix}**" + summary
    md_personas = f"**{report_title_prefix}**" + md_personas

    # Step 4: Return all results
    return fig_cluster, fig_elbow, summary, md_personas, fig_profile

# --- Other Callbacks ---
def create_histogram(analyzer, col):
    """
    Builds a histogram (with a box-plot marginal) of one column of the
    analyzer's DataFrame; returns an empty figure when the analyzer is not
    a DataAnalyzer or no column is selected.
    """
    has_inputs = isinstance(analyzer, DataAnalyzer) and col
    if not has_inputs:
        return go.Figure()

    chart_title = f"<b>Distribution of {col}</b>"
    return px.histogram(analyzer.df, x=col, title=chart_title, marginal="box")

def create_scatterplot(analyzer, x_col, y_col, color_col):
    """
    Builds a scatter plot of two columns from a random sample of at most
    10000 rows, optionally colored by a third column; returns an empty
    figure when the analyzer is not a DataAnalyzer or either axis is unset.
    """
    if not (isinstance(analyzer, DataAnalyzer) and x_col and y_col):
        return go.Figure()

    frame = analyzer.df
    sample_size = min(len(frame), 10000)
    df_sample = frame.sample(n=sample_size)
    # An empty/unset color selection is passed to plotly as None.
    return px.scatter(df_sample, x=x_col, y=y_col, color=color_col or None)