import gradio as gr
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import io
from scipy import stats
import warnings
import google.generativeai as genai
import os
from dotenv import load_dotenv
import logging
import json
from contextlib import redirect_stdout

# --- Configuration ---
warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Helper Functions ---

def safe_exec(code_string: str, local_vars: dict) -> tuple:
    """Safely execute a string of Python code and capture its output."""
    output_buffer = io.StringIO()
    try:
        with redirect_stdout(output_buffer):
            exec(code_string, globals(), local_vars)
        
        stdout_output = output_buffer.getvalue()
        fig = local_vars.get('fig', None)
        return stdout_output, fig, None
    except Exception as e:
        error_message = f"Execution Error: {str(e)}"
        logging.error(f"Error executing AI-generated code: {error_message}")
        return None, None, error_message

# --- Core Data Processing ---

def load_and_process_file(file_obj, state_dict):
    """Loads a CSV file and performs initial processing, updating the global state."""
    if file_obj is None:
        # Must return one value per output component wired to file_input.upload (9 in total).
        return (
            state_dict, "Please upload a file.",
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), gr.update(visible=False)
        )

    try:
        df = pd.read_csv(file_obj.name)
        
        # Attempt to convert object columns to datetime
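        # errors='raise' makes non-date text columns fail immediately, so they are left untouched.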
        for col in df.select_dtypes(include=['object']).columns:
            try:
                df[col] = pd.to_datetime(df[col], errors='raise')
                logging.info(f"Successfully converted column '{col}' to datetime.")
            except (ValueError, TypeError):
                continue

        metadata = extract_dataset_metadata(df)
        
        state_dict = {
            'df': df,
            'metadata': metadata,
            'filename': os.path.basename(file_obj.name)
        }
        
        # Update UI elements dynamically
        update_args = {
            'choices': metadata['columns'],
            'value': None,
            'interactive': True
        }
        
        # Check for time series tab visibility
        time_series_visible = len(metadata['datetime_cols']) > 0
        
        return (
            state_dict, 
            f"βœ… Loaded `{state_dict['filename']}` ({metadata['shape'][0]} rows, {metadata['shape'][1]} cols)",
            gr.update(**update_args), gr.update(**update_args), gr.update(**update_args),
            gr.update(choices=metadata['numeric_cols'], value=None, interactive=True),
            gr.update(choices=metadata['datetime_cols'], value=None, interactive=True),
            gr.update(visible=time_series_visible), # Show/hide Time Series tab
            gr.update(visible=True) # Show Chatbot tab
        )
    except Exception as e:
        logging.error(f"Error loading file: {e}")
        return (
            state_dict, f"❌ Error: {e}",
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), gr.update(visible=False)
        )

def extract_dataset_metadata(df: pd.DataFrame) -> dict:
    """Extracts comprehensive metadata from a DataFrame."""
    rows, cols = df.shape
    columns = df.columns.tolist()
    
    numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
    datetime_cols = df.select_dtypes(include=['datetime64', 'datetime64[ns]']).columns.tolist()
    
    missing_data = df.isnull().sum()
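    # Data quality = percentage of non-missing cells across the whole DataFrame.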
    data_quality = round((df.notna().sum().sum() / (rows * cols)) * 100, 1) if rows * cols > 0 else 0

    return {
        'shape': (rows, cols),
        'columns': columns,
        'numeric_cols': numeric_cols,
        'categorical_cols': categorical_cols,
        'datetime_cols': datetime_cols,
        'dtypes': df.dtypes.to_string(),
        'missing_data': missing_data.to_dict(),
        'data_quality': data_quality,
        'head': df.head().to_string()
    }

# --- Tab 1: AI Overview ---

def analyze_dataset_overview(state_dict, api_key: str):
    """Generates an AI-powered narrative overview of the dataset."""
    if not state_dict:
        return "❌ Please upload a dataset first.", "", 0
    if not api_key:
        return "❌ Please enter your Gemini API key.", "", 0
    
    metadata = state_dict['metadata']
    
    # Create prompt for Gemini
    prompt = f"""
    You are an expert data analyst and storyteller. Your task is to provide a high-level, engaging overview of a dataset based on its metadata.

    **Dataset Metadata:**
    - **Shape:** {metadata['shape'][0]} rows, {metadata['shape'][1]} columns
    - **Column Names:** {', '.join(metadata['columns'])}
    - **Numeric Columns:** {', '.join(metadata['numeric_cols'])}
    - **Categorical Columns:** {', '.join(metadata['categorical_cols'])}
    - **Datetime Columns:** {', '.join(metadata['datetime_cols'])}
    - **Data Quality (Non-missing values):** {metadata['data_quality']}%
    - **First 5 rows:**
    {metadata['head']}

    **Your Task:**
    Based on the metadata, generate a report in Markdown format. Use emojis to make it visually appealing. The report should have the following sections:

    # 🚀 AI-Powered Dataset Overview

    ## 🤔 What is this dataset likely about?
    (Predict the domain and purpose of the dataset, e.g., "This appears to be customer transaction data for an e-commerce platform.")

    ## 💡 Potential Key Questions to Explore
    - (Suggest 3-4 interesting business or research questions the data could answer.)
    - (Example: "Which products are most frequently purchased together?")

    ## 📊 Potential Analyses & Visualizations
    - (List 3-4 types of analyses that would be valuable.)
    - (Example: "Time series analysis of sales to identify seasonality.")

    ## ⚠️ Data Quality & Potential Issues
    - (Briefly comment on the data quality score and mention if the presence of datetime columns is a good sign for certain analyses.)
    """

    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)
        story = response.text
    except Exception as e:
        story = f"## ⚠️ AI Generation Failed\n**Error:** {str(e)}\n\nPlease check your API key and network connection. A fallback analysis is provided below.\n\n" \
                f"### Fallback Analysis\nThis dataset contains **{metadata['shape'][0]}** records and **{metadata['shape'][1]}** features. " \
                f"It includes **{len(metadata['numeric_cols'])}** numeric, **{len(metadata['categorical_cols'])}** categorical, " \
                f"and **{len(metadata['datetime_cols'])}** time-based columns. The overall data quality is **{metadata['data_quality']}%**, " \
                f"which is a good starting point for analysis."

    # Basic Info Summary
    basic_info = f"""
    📋 **File:** `{state_dict.get('filename', 'N/A')}`
    📊 **Size:** {metadata['shape'][0]:,} rows × {metadata['shape'][1]} columns
    🔢 **Features:**
      • **Numeric:** {len(metadata['numeric_cols'])}
      • **Categorical:** {len(metadata['categorical_cols'])}
      • **DateTime:** {len(metadata['datetime_cols'])}
    🎯 **Data Quality:** {metadata['data_quality']}%
    """
    
    return story, basic_info, metadata['data_quality']

# --- Tab 2: Univariate Analysis ---

def generate_univariate_plot(column_name, state_dict):
    """Generates plots for a single selected variable."""
    if not column_name or not state_dict:
        return None, "Select a column to analyze."
    
    df = state_dict['df']
    metadata = state_dict['metadata']
    
    fig = None
    summary = ""

    if column_name in metadata['numeric_cols']:
        fig = make_subplots(rows=1, cols=2, subplot_titles=("Histogram", "Box Plot"))
        fig.add_trace(go.Histogram(x=df[column_name], name="Histogram"), row=1, col=1)
        fig.add_trace(go.Box(y=df[column_name], name="Box Plot"), row=1, col=2)
        fig.update_layout(title_text=f"Distribution of '{column_name}'", showlegend=False)
        summary = df[column_name].describe().to_frame().to_markdown()
    
    elif column_name in metadata['categorical_cols']:
        top_n = 20
        counts = df[column_name].value_counts()
        title = f"Top {min(top_n, len(counts))} Categories for '{column_name}'"
        top_counts = counts.nlargest(top_n)
        # Pass explicit x/y arrays so the axis labels stay correct across pandas versions
        # (value_counts() names its output differently in pandas 2.x).
        fig = px.bar(x=top_counts.index.astype(str), y=top_counts.values, title=title,
                     labels={'x': column_name, 'y': 'Count'})
        fig.update_layout(showlegend=False)
        summary = counts.to_frame().to_markdown()
        
    elif column_name in metadata['datetime_cols']:
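        # Bucket records by calendar month to show how record volume evolves over time.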
        counts = df[column_name].dt.to_period("M").value_counts().sort_index()
        fig = px.line(x=counts.index.to_timestamp(), y=counts.values, title=f"Records over Time for '{column_name}'")
        fig.update_layout(xaxis_title="Time", yaxis_title="Record Count")
        # datetime_is_numeric was removed in pandas 2.0; describe() now summarizes datetimes by default.
        summary = df[column_name].describe().to_frame().to_markdown()

    return fig, summary

# --- Tab 3: Bivariate Analysis ---

def generate_bivariate_plot(x_col, y_col, state_dict):
    """Generates plots to explore the relationship between two variables."""
    if not x_col or not y_col or not state_dict:
        return None, "Select two columns to analyze."
    if x_col == y_col:
        return None, "Please select two different columns."
        
    df = state_dict['df']
    metadata = state_dict['metadata']
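    # Simple heuristic: anything that is not numeric (including datetimes) is treated as categorical here.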
    
    x_type = 'numeric' if x_col in metadata['numeric_cols'] else 'categorical'
    y_type = 'numeric' if y_col in metadata['numeric_cols'] else 'categorical'

    fig = None
    title = f"{x_col} vs. {y_col}"

    if x_type == 'numeric' and y_type == 'numeric':
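        # Note: the OLS trendline requires the statsmodels package to be installed.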
        fig = px.scatter(df, x=x_col, y=y_col, title=f"Scatter Plot: {title}", trendline="ols", trendline_color_override="red")
        summary = df[[x_col, y_col]].corr().to_markdown()
    elif x_type == 'numeric' and y_type == 'categorical':
        fig = px.box(df, x=x_col, y=y_col, title=f"Box Plot: {title}")
        summary = df.groupby(y_col)[x_col].describe().to_markdown()
    elif x_type == 'categorical' and y_type == 'numeric':
        fig = px.box(df, x=y_col, y=x_col, title=f"Box Plot: {title}")
        summary = df.groupby(x_col)[y_col].describe().to_markdown()
    else: # Both categorical
        crosstab = pd.crosstab(df[x_col], df[y_col])
        fig = px.imshow(crosstab, title=f"Heatmap of Counts: {title}", text_auto=True)
        summary = crosstab.to_markdown()
        
    return fig, f"### Analysis Summary\n{summary}"

# --- Tab 4: Time Series Analysis ---

def generate_time_series_plot(time_col, value_col, resample_freq, state_dict):
    """Generates a time series plot with resampling."""
    if not time_col or not value_col or not state_dict:
        return None, "Select Time and Value columns."

    df = state_dict['df'].copy()
    
    try:
        df[time_col] = pd.to_datetime(df[time_col])
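        # Aggregate the value column at the chosen frequency, taking the mean within each period.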
        df_resampled = df.set_index(time_col)[value_col].resample(resample_freq).mean().reset_index()
        
        fig = px.line(df_resampled, x=time_col, y=value_col, 
                      title=f"Time Series of {value_col} (Resampled to '{resample_freq}')")
        fig.update_layout(xaxis_title="Date", yaxis_title=f"Mean of {value_col}")
        return fig, f"Showing mean of '{value_col}' aggregated by '{resample_freq}'."
    except Exception as e:
        return None, f"Error: {e}"

# --- Tab 5: AI Analyst Chat ---

def respond_to_chat(user_message, history, state_dict, api_key):
    """Handles the chat interaction with the AI Analyst."""
    if not api_key:
        history.append((user_message, "I can't answer without a Gemini API key. Please enter it in the 'AI Overview' tab."))
        return history, None, ""

    if not state_dict:
        history.append((user_message, "Please upload a dataset before asking questions."))
        return history, None, ""
    
    history.append((user_message, None))

    df_metadata = state_dict['metadata']
    
    # Construct a robust prompt for the AI
    prompt = f"""
    You are an AI Data Analyst assistant. Your name is 'Gemini Analyst'.
    You are given a pandas DataFrame named `df`.
    Your goal is to answer the user's question about this DataFrame by writing and executing Python code.

    **Instructions:**
    1.  Analyze the user's question.
    2.  Write Python code to answer it.
    3.  You can use pandas, numpy, and plotly.express.
    4.  If you create a plot, you **MUST** assign it to a variable named `fig`. The plot will be displayed to the user.
    5.  If you are just calculating something or printing text, the `print()` output will be shown.
    6.  **DO NOT** write any code that modifies the DataFrame (e.g., `df.dropna(inplace=True)`). Use `df.copy()` if you need to modify data.
    7.  Respond **ONLY** with a JSON object containing two keys: "thought" and "code".
        - "thought": A short, one-sentence explanation of your plan.
        - "code": A string containing the Python code to execute.

    **DataFrame Metadata:**
    - **Filename:** {state_dict['filename']}
    - **Shape:** {df_metadata['shape'][0]} rows, {df_metadata['shape'][1]} columns
    - **Columns and Data Types:**
    {df_metadata['dtypes']}

    ---
    **User Question:** "{user_message}"
    ---

    **Your JSON Response:**
    """
    
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)
        
        # Clean and parse the JSON response
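        # Gemini often wraps its JSON in markdown code fences; strip them before parsing.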
        response_text = response.text.strip().replace("```json", "").replace("```", "")
        response_json = json.loads(response_text)
        
        thought = response_json.get("thought", "Thinking...")
        code_to_run = response_json.get("code", "")
        
        bot_message = f"🧠 **Thought:** {thought}\n\n"
        
        # Execute the code
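        # The generated code can reference df, px, pd, and np directly via these locals.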
        local_vars = {'df': state_dict['df'], 'px': px, 'pd': pd, 'np': np}
        stdout, fig_result, error = safe_exec(code_to_run, local_vars)
        
        if error:
            bot_message += f"πŸ’₯ **Error:**\n```\n{error}\n```"
            history[-1] = (user_message, bot_message)
            return history, None, ""
        
        if stdout:
            bot_message += f"πŸ“‹ **Output:**\n```\n{stdout}\n```"
        
        if not fig_result and not stdout:
            bot_message += "βœ… Code executed successfully, but it produced no visible output."

        history[-1] = (user_message, bot_message)
        
        return history, fig_result, ""
        
    except Exception as e:
        error_msg = f"An unexpected error occurred: {e}. The AI might have returned an invalid response. Please try rephrasing your question."
        logging.error(f"Chatbot error: {error_msg}")
        history[-1] = (user_message, error_msg)
        return history, None, ""

# --- Gradio Interface ---

def create_gradio_interface():
    """Builds and returns the full Gradio application interface."""
    with gr.Blocks(title="πŸš€ AI Data Explorer", theme=gr.themes.Soft()) as demo:
        # Global state to hold data
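        # gr.State keeps the loaded DataFrame and its metadata per user session.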
        global_state = gr.State({})

        # Header
        gr.Markdown("# πŸš€ AI Data Explorer: Your Advanced Analytic Tool")
        gr.Markdown("Upload a CSV, then explore your data with interactive tabs and a powerful AI Analyst.")

        # --- Top Row: File Upload and API Key ---
        with gr.Row():
            with gr.Column(scale=2):
                file_input = gr.File(label="πŸ“ Upload CSV File", file_types=[".csv"])
                status_output = gr.Markdown("Status: Waiting for file...")
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="πŸ”‘ Gemini API Key",
                    placeholder="Enter your key here...",
                    type="password",
                    info="Get your free key from Google AI Studio"
                )

        # --- Main Tabs ---
        with gr.Tabs() as tabs:
            # Tab 1: AI Overview
            with gr.Tab("πŸ€– AI Overview", id=0):
                overview_btn = gr.Button("🧠 Generate AI Overview", variant="primary")
                with gr.Row():
                    story_output = gr.Markdown(label="πŸ“– AI-Generated Story")
                    with gr.Column():
                        basic_info_output = gr.Markdown(label="πŸ“‹ Basic Information")
                        quality_score = gr.Number(label="🎯 Data Quality Score (%)", interactive=False)

            # Tab 2: Univariate Analysis
            with gr.Tab("πŸ“Š Univariate Analysis", id=1):
                uni_col_select = gr.Dropdown(label="Select a Column to Analyze", interactive=False)
                with gr.Row():
                    uni_plot_output = gr.Plot(label="Distribution Plot")
                    uni_summary_output = gr.Markdown(label="Summary Statistics")
            
            # Tab 3: Bivariate Analysis
            with gr.Tab("πŸ“ˆ Bivariate Analysis", id=2):
                with gr.Row():
                    bi_x_select = gr.Dropdown(label="Select X-Axis Column", interactive=False)
                    bi_y_select = gr.Dropdown(label="Select Y-Axis Column", interactive=False)
                bi_btn = gr.Button("🎨 Generate Bivariate Plot", variant="secondary")
                with gr.Row():
                    bi_plot_output = gr.Plot(label="Relationship Plot")
                    bi_summary_output = gr.Markdown(label="Analysis Summary")

            # Tab 4: Time Series (conditionally visible)
            with gr.Tab("⏳ Time Series Analysis", id=3, visible=False) as ts_tab:
                with gr.Row():
                    ts_time_col = gr.Dropdown(label="Select Time Column", interactive=False)
                    ts_value_col = gr.Dropdown(label="Select Value Column", interactive=False)
                    ts_resample = gr.Radio(['D', 'W', 'M', 'Q', 'Y'], label="Resample Frequency", value='M')
                ts_btn = gr.Button("πŸ“ˆ Plot Time Series", variant="secondary")
                ts_plot_output = gr.Plot(label="Time Series Plot")
                ts_status_output = gr.Markdown()

            # Tab 5: AI Analyst Chat (conditionally visible)
            with gr.Tab("πŸ’¬ AI Analyst Chat", id=4, visible=False) as chat_tab:
                chatbot = gr.Chatbot(label="Chat with Gemini Analyst", height=500)
                chat_plot_output = gr.Plot(label="AI Generated Plot")
                with gr.Row():
                    chat_input = gr.Textbox(label="Your Question", placeholder="e.g., 'Show me the distribution of age'", scale=4)
                    chat_submit_btn = gr.Button("Submit", variant="primary", scale=1)
                chat_clear_btn = gr.Button("Clear Chat")

        # --- Event Handlers ---
        
        # File upload triggers data loading and UI updates
        file_input.upload(
            fn=load_and_process_file,
            inputs=[file_input, global_state],
            outputs=[global_state, status_output, uni_col_select, bi_x_select, bi_y_select, ts_value_col, ts_time_col, ts_tab, chat_tab]
        )

        # Tab 1: Overview
        overview_btn.click(
            fn=analyze_dataset_overview,
            inputs=[global_state, api_key_input],
            outputs=[story_output, basic_info_output, quality_score]
        )
        
        # Tab 2: Univariate
        uni_col_select.change(
            fn=generate_univariate_plot,
            inputs=[uni_col_select, global_state],
            outputs=[uni_plot_output, uni_summary_output]
        )

        # Tab 3: Bivariate
        bi_btn.click(
            fn=generate_bivariate_plot,
            inputs=[bi_x_select, bi_y_select, global_state],
            outputs=[bi_plot_output, bi_summary_output]
        )
        
        # Tab 4: Time Series
        ts_btn.click(
            fn=generate_time_series_plot,
            inputs=[ts_time_col, ts_value_col, ts_resample, global_state],
            outputs=[ts_plot_output, ts_status_output]
        )
        
        # Tab 5: AI Chat
        chat_submit_btn.click(
            fn=respond_to_chat,
            inputs=[chat_input, chatbot, global_state, api_key_input],
            outputs=[chatbot, chat_plot_output, chat_input]
        )
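        # Pressing Enter in the textbox triggers the same handler as the Submit button.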
        chat_input.submit(
            fn=respond_to_chat,
            inputs=[chat_input, chatbot, global_state, api_key_input],
            outputs=[chatbot, chat_plot_output, chat_input]
        )
        chat_clear_btn.click(lambda: ([], None, ""), None, [chatbot, chat_plot_output, chat_input])


    return demo

# --- Main Application Execution ---
if __name__ == "__main__":
    # For local development, call load_dotenv() to read environment variables from a .env file.
    # load_dotenv()
    app = create_gradio_interface()
    app.launch(debug=True)