import io
import json
import logging
import os
import tempfile
import warnings
from contextlib import redirect_stdout

import google.generativeai as genai
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import accuracy_score, confusion_matrix, r2_score, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
# Module-wide setup: every warning is suppressed and INFO-level logging is
# enabled with a timestamped format.
warnings.filterwarnings('ignore')
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Dark, translucent variant of the Glass theme used by the whole interface.
THEME = gr.themes.Glass(secondary_hue="cyan", primary_hue="blue").set(
    border_color_primary="rgba(255,255,255,0.1)",
    block_border_width="1px",
    block_background_fill="rgba(0,0,0,0.6)",
    body_background_fill="rgba(0,0,0,0.8)",
)

# Estimator classes selectable in the modeling tab, keyed by problem type and
# then by the display name shown in the model dropdown.
MODEL_REGISTRY = dict(
    Classification={
        "Random Forest": RandomForestClassifier,
        "Logistic Regression": LogisticRegression,
    },
    Regression={
        "Random Forest": RandomForestRegressor,
        "Linear Regression": LinearRegression,
    },
)
|
|
|
|
|
|
|
def safe_exec(code_string: str, local_vars: dict) -> tuple:
    """Execute a string of Python code and capture what it produced.

    The code runs in ONE namespace: a copy of this module's globals overlaid
    with ``local_vars``. Using a single dict (instead of separate globals and
    locals) matters: with split scopes, functions defined by the executed code
    cannot see names such as ``df`` or their own top-level helpers.

    NOTE(security): ``exec`` is not a sandbox. The executed code has full
    access to builtins and imports, so only run code you are prepared to
    trust with the privileges of this process.

    Args:
        code_string: Python source to execute.
        local_vars: Names made available to the code. Names the code creates
            or rebinds are written back into this dict on success.

    Returns:
        A 4-tuple ``(stdout, fig, df_result, error)``. On success ``stdout``
        is the captured standard output, ``fig`` and ``df_result`` are the
        values bound to those names (or ``None``), and ``error`` is ``None``.
        On failure the first three items are ``None`` and ``error`` is a
        message starting with ``"Execution Error: "``.
    """
    module_globals = globals()
    namespace = {**module_globals, **local_vars}
    output_buffer = io.StringIO()

    try:
        with redirect_stdout(output_buffer):
            exec(code_string, namespace)
    except Exception as e:
        return None, None, None, f"Execution Error: {str(e)}"

    # Mirror results back to the caller's dict, skipping untouched module
    # globals so the caller's namespace is not flooded with this module's names.
    unset = object()
    local_vars.update(
        (name, value)
        for name, value in namespace.items()
        if name != '__builtins__' and module_globals.get(name, unset) is not value
    )

    return (
        output_buffer.getvalue(),
        local_vars.get('fig'),
        local_vars.get('df_result'),
        None,
    )
|
|
|
def _coerce_object_columns(df):
    """Convert text columns in place to datetime, or to category when repetitive."""
    row_count = len(df)
    for col in df.select_dtypes(include=['object']).columns:
        try:
            df[col] = pd.to_datetime(df[col], errors='raise')
        except (ValueError, TypeError):
            # Not a date column: treat it as categorical when fewer than half
            # of the rows are distinct. The row_count guard avoids dividing by
            # zero on a header-only file.
            if row_count and df[col].nunique() / row_count < 0.5:
                df[col] = df[col].astype('category')


def _top_correlated_pairs(df, numeric_cols, limit=5):
    """Return the strongest absolute correlations, one entry per column pair."""
    corr = df[numeric_cols].corr().abs()
    # Keep only the strictly-upper triangle: this removes the diagonal
    # (self-correlation) and the mirrored (B, A) duplicate of every (A, B).
    upper = np.triu(np.ones(corr.shape, dtype=bool), k=1)
    pairs = corr.where(upper).stack().dropna()
    return pairs.sort_values(ascending=False).head(limit)


def _count_iqr_outliers(df, numeric_cols):
    """Count values outside 1.5 * IQR of the quartiles, per numeric column."""
    counts = {}
    for col in numeric_cols:
        q1, q3 = df[col].quantile(0.25), df[col].quantile(0.75)
        iqr = q3 - q1
        flagged = (df[col] < (q1 - 1.5 * iqr)) | (df[col] > (q3 + 1.5 * iqr))
        total = int(flagged.sum())
        if total > 0:
            counts[col] = total
    return counts


def _build_proactive_insights(df, metadata):
    """Compute the insight dict shown on the overview tab."""
    numeric_cols = metadata['numeric_cols']
    categorical_cols = metadata['categorical_cols']
    cardinality = {c: df[c].nunique() for c in categorical_cols}

    missing = df.isnull().sum()
    insights = {
        'missing': missing[missing > 0].sort_values(ascending=False),
        'high_cardinality': {c: n for c, n in cardinality.items() if n > 50},
        'outliers': _count_iqr_outliers(df, numeric_cols),
        'ml_suggestions': (
            [f"{c} (Binary Classification)" for c in categorical_cols if cardinality[c] == 2]
            + [f"{c} (Regression)" for c in numeric_cols if df[c].nunique() > 20]
        ),
    }
    # Correlations need at least two numeric columns; the key is simply
    # absent otherwise and consumers check for it.
    if len(numeric_cols) > 1:
        insights['high_correlations'] = _top_correlated_pairs(df, numeric_cols)
    return insights


def prime_data(file_obj):
    """Load an uploaded CSV, analyse it, and prime the application state.

    Args:
        file_obj: The uploaded file object; its ``name`` attribute is read as
            the CSV path. A falsy value means no file is present.

    Returns:
        A dict mapping UI components to their new values/updates. On success
        it fills the shared state, reveals the tabs, renders the overview
        markdown and populates the column selectors. With no file it only
        hides the tabs. On failure it only writes the error message into the
        overview markdown component.
    """
    if not file_obj:
        # Dict-style returns must be keyed by the component being updated.
        return {phoenix_tabs: gr.update(visible=False)}

    try:
        df = pd.read_csv(file_obj.name)
        _coerce_object_columns(df)

        metadata = extract_dataset_metadata(df)
        insights = _build_proactive_insights(df, metadata)

        state = {
            'df_original': df,
            'df_modified': df.copy(),
            'filename': os.path.basename(file_obj.name),
            'metadata': metadata,
            'proactive_insights': insights,
        }

        all_cols = metadata['columns']
        return {
            global_state: state,
            phoenix_tabs: gr.update(visible=True),
            phoenix_eye_output: generate_phoenix_eye_markdown(state),
            medic_col_select: gr.update(choices=insights['missing'].index.tolist(), interactive=True),
            oracle_target_select: gr.update(choices=all_cols, interactive=True),
            oracle_feature_select: gr.update(choices=all_cols, interactive=True),
        }

    except Exception as e:
        # Top-level UI boundary: log with traceback and surface the message
        # to the user instead of crashing the event handler.
        logging.exception("Priming Error: %s", e)
        return {phoenix_eye_output: gr.update(value=f"❌ **Error:** {e}")}
|
|
|
def extract_dataset_metadata(df):
    """Extract typed metadata from a DataFrame.

    Args:
        df: The DataFrame to describe.

    Returns:
        A dict with:
          - ``'shape'``: ``(rows, cols)`` tuple,
          - ``'columns'``: all column names in order,
          - ``'numeric_cols'``: columns with a numeric dtype,
          - ``'categorical_cols'``: columns with ``object`` or ``category`` dtype,
          - ``'datetime_cols'``: datetime columns, both naive and
            timezone-aware,
          - ``'dtypes'``: mapping of column name to dtype name.
    """
    rows, cols = df.shape

    def names_of(**selector):
        """Return the column names matched by a ``select_dtypes`` selector."""
        return df.select_dtypes(**selector).columns.tolist()

    return {
        'shape': (rows, cols),
        'columns': df.columns.tolist(),
        'numeric_cols': names_of(include=np.number),
        'categorical_cols': names_of(include=['object', 'category']),
        # 'datetime64' matches naive datetimes only; 'datetimetz' is required
        # to also pick up timezone-aware columns.
        'datetime_cols': names_of(include=['datetime64', 'datetimetz']),
        'dtypes': {col: dtype.name for col, dtype in df.dtypes.items()},
    }
|
|
|
def generate_phoenix_eye_markdown(state):
    """Create the markdown for the proactive insights dashboard.

    Args:
        state: Application state dict. Reads ``state['filename']``,
            ``state['metadata']['shape']`` and ``state['proactive_insights']``
            (keys ``'ml_suggestions'``, ``'missing'``, ``'outliers'``,
            ``'high_cardinality'`` and, optionally, ``'high_correlations'``).

    Returns:
        A markdown string with one section per insight category.
    """
    insights = state['proactive_insights']
    rows, cols = state['metadata']['shape']

    # Collect fragments and join once instead of repeated string +=.
    parts = [
        f"## 🦅 Phoenix Eye: Proactive Insights for `{state['filename']}`\n",
        f"Dataset has **{rows} rows** and **{cols} columns**.\n\n",
    ]

    parts.append("### 🔮 Potential ML Targets\n")
    if insights['ml_suggestions']:
        parts.extend(f"- `{s}`\n" for s in insights['ml_suggestions'])
    else:
        parts.append("No obvious ML target columns found.\n")
    parts.append("\n")

    parts.append("### 💧 Missing Data\n")
    if not insights['missing'].empty:
        parts.append("Found missing values in these columns. Use the **Data Medic** tab to fix.\n")
        parts.append(insights['missing'].to_frame('Missing Count').to_markdown() + "\n")
    else:
        parts.append("✅ No missing data found!\n")
    parts.append("\n")

    parts.append("### 🔗 Top Correlations\n")
    # The key is only present when correlations were computed.
    correlations = insights.get('high_correlations')
    if correlations is not None and not correlations.empty:
        parts.append(correlations.to_frame('Correlation').to_markdown() + "\n")
    else:
        parts.append("No strong correlations found between numeric features.\n")
    parts.append("\n")

    parts.append("### 📉 Outlier Alert\n")
    if insights['outliers']:
        parts.extend(
            f"- `{col}` has **{count}** potential outliers.\n"
            for col, count in insights['outliers'].items()
        )
    else:
        parts.append("✅ No significant outliers detected.\n")
    parts.append("\n")

    parts.append("### 🃏 High Cardinality Warning\n")
    if insights['high_cardinality']:
        parts.extend(
            f"- `{col}` has **{count}** unique values, which may be problematic for some models.\n"
            for col, count in insights['high_cardinality'].items()
        )
    else:
        parts.append("✅ No high-cardinality categorical columns found.\n")
    parts.append("\n")

    return "".join(parts)
|
|
|
|
|
|
|
def medic_preview_imputation(state, col, method):
    """Build a before-and-after histogram for imputing one column.

    The preview is computed from ``state['df_original']`` and does not modify
    the state.

    Args:
        state: Application state dict; reads ``state['df_original']``.
        col: Name of the column to preview. Falsy means nothing is selected.
        method: ``'mean'``, ``'median'``, or anything else for the mode.

    Returns:
        ``None`` when no column is selected; otherwise a plotly figure. If the
        fill value cannot be computed (mean/median of a non-numeric column, or
        a column with no observed values) the figure is empty and its title
        explains why.
    """
    if not col:
        return None

    before = state['df_original'][col]
    fig = go.Figure()

    # mean/median are undefined for text/category/datetime-like columns and
    # would raise; report that instead of crashing the handler.
    if method in ('mean', 'median') and not pd.api.types.is_numeric_dtype(before):
        fig.update_layout(title=f"'{method}' imputation needs a numeric column; use 'mode' for '{col}'")
        return fig

    if method == 'mean':
        value = before.mean()
    elif method == 'median':
        value = before.median()
    else:
        modes = before.mode()
        # mode() is empty when every value is missing.
        value = modes.iloc[0] if not modes.empty else None

    if pd.isna(value):
        fig.update_layout(title=f"'{col}' has no observed values to impute from")
        return fig

    after = before.fillna(value)

    fig.add_trace(go.Histogram(x=before, name='Before', opacity=0.7))
    fig.add_trace(go.Histogram(x=after, name='After', opacity=0.7))
    fig.update_layout(
        barmode='overlay',
        title=f"'{col}' Distribution: Before vs. After Imputation",
        legend_title_text='Dataset',
    )
    return fig
|
|
|
def medic_apply_imputation(state, col, method):
    """Fill missing values in one column and update the shared state.

    Args:
        state: Application state dict. Reads ``state['df_modified']`` and, on
            success, replaces it and ``state['proactive_insights']['missing']``.
        col: Name of the column to impute. Falsy means nothing is selected.
        method: ``'mean'``, ``'median'``, or anything else for the mode.

    Returns:
        A 3-tuple ``(state, status_message, dropdown_update)`` on every path,
        matching the three outputs this handler is wired to. When nothing can
        be applied the state is returned unchanged and the dropdown update is
        a no-op.
    """
    if not col:
        return state, "No column selected.", gr.update()

    df_mod = state['df_modified'].copy()
    series = df_mod[col]

    # mean/median raise on text/category/datetime-like columns.
    if method in ('mean', 'median') and not pd.api.types.is_numeric_dtype(series):
        return (
            state,
            f"Cannot apply '{method}' imputation to non-numeric column '{col}'. Use 'mode' instead.",
            gr.update(),
        )

    if method == 'mean':
        value = series.mean()
    elif method == 'median':
        value = series.median()
    else:
        modes = series.mode()
        # mode() is empty when every value is missing.
        value = modes.iloc[0] if not modes.empty else None

    if pd.isna(value):
        return state, f"Column '{col}' has no observed values to impute from.", gr.update()

    df_mod[col] = series.fillna(value)
    state['df_modified'] = df_mod

    # Recompute the remaining gaps, ordered like the initial analysis
    # (largest count first).
    missing = df_mod.isnull().sum()
    missing = missing[missing > 0].sort_values(ascending=False)
    state['proactive_insights']['missing'] = missing

    return (
        state,
        f"✅ Applied '{method}' imputation to '{col}'.",
        gr.update(choices=missing.index.tolist()),
    )
|
|
|
def download_cleaned_data(state):
    """Save the modified DataFrame to a CSV file and offer it for download.

    Args:
        state: Application state dict; reads ``state['df_modified']`` and,
            when present, ``state['filename']`` to name the export.

    Returns:
        A ``gr.File`` update. With data available it points at a newly written
        temporary CSV file and makes the component visible; otherwise it hides
        the component.
    """
    if not state or 'df_modified' not in state:
        return gr.File.update(visible=False)

    stem = os.path.splitext(state.get('filename') or 'data')[0]

    # The File component needs a path on disk, not the CSV text itself.
    # delete=False keeps the file around after the handle closes so it can
    # still be served.
    with tempfile.NamedTemporaryFile(
        mode='w',
        suffix='.csv',
        prefix=f"{stem}_cleaned_",
        newline='',
        encoding='utf-8',
        delete=False,
    ) as tmp:
        state['df_modified'].to_csv(tmp, index=False)

    return gr.File.update(value=tmp.name, visible=True)
|
|
|
def oracle_run_model(state, target, features, model_name):
    """Train a simple ML model and return plots plus a short report.

    The problem type is inferred from the target: 10 or fewer distinct values
    is treated as classification, anything else as regression.

    Args:
        state: Application state dict; reads ``state['df_modified']``.
        target: Name of the column to predict.
        features: List of column names used as predictors. The target is
            removed from this list if present.
        model_name: Display name looked up in ``MODEL_REGISTRY``.

    Returns:
        A 3-tuple ``(feature_importance_fig, result_fig, report_markdown)``.
        The first item is ``None`` for models without feature importances.
        When the model cannot be trained both figures are ``None`` and the
        third item explains why.
    """
    if not target or not features:
        return None, None, "Please select a target and at least one feature."

    # A target listed among its own predictors leaks the answer to the model.
    features = [name for name in features if name != target]
    if not features:
        return None, None, "Please select at least one feature other than the target."

    used_cols = features + [target]
    df = state['df_modified'].dropna(subset=used_cols).copy()
    if df.empty:
        return None, None, "Not enough data after dropping NA values."

    # Encode text/categorical columns as integer codes, one encoder per column.
    for col in used_cols:
        if df[col].dtype == 'object' or df[col].dtype.name == 'category':
            df[col] = LabelEncoder().fit_transform(df[col])

    X = df[features]
    y = df[target]

    problem_type = "Classification" if y.nunique() <= 10 else "Regression"

    model_cls = MODEL_REGISTRY[problem_type].get(model_name)
    if model_cls is None:
        return None, None, f"Model {model_name} not suitable for {problem_type}."

    # Not every estimator accepts random_state (LinearRegression does not),
    # so seed only those that expose the parameter.
    model = model_cls()
    if 'random_state' in model.get_params():
        model.set_params(random_state=42)

    try:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
        model.fit(X_train, y_train)
        preds = model.predict(X_test)
    except (ValueError, TypeError) as e:
        # e.g. too few rows to split, or a feature dtype the model rejects.
        return None, None, f"Training failed: {e}"

    if problem_type == "Classification":
        acc = accuracy_score(y_test, preds)
        cm = confusion_matrix(y_test, preds)
        result_fig = px.imshow(cm, text_auto=True, title=f"Confusion Matrix (Accuracy: {acc:.2f})")
        report = f"**Classification Report:**\n- Accuracy: {acc:.2f}"
    else:
        r2 = r2_score(y_test, preds)
        rmse = np.sqrt(mean_squared_error(y_test, preds))
        # NOTE: plotly's trendline='ols' requires the statsmodels package.
        result_fig = px.scatter(
            x=y_test,
            y=preds,
            labels={'x': 'Actual Values', 'y': 'Predicted Values'},
            title=f"Predictions vs. Actuals (R²: {r2:.2f})",
            trendline='ols',
        )
        report = f"**Regression Report:**\n- R² Score: {r2:.2f}\n- RMSE: {rmse:.2f}"

    fi_fig = None
    if hasattr(model, 'feature_importances_'):
        fi = pd.Series(model.feature_importances_, index=features).sort_values(ascending=False)
        fi_fig = px.bar(fi, title="Feature Importance")

    return fi_fig, result_fig, report
|
|
|
def copilot_respond(user_message, history, state, api_key):
    """Handle one turn of the AI Co-pilot chat.

    This is a generator: it yields intermediate and final UI updates. Each
    yielded value is a 4-tuple ``(chat_history, figure, dataframe, code)``.

    Args:
        user_message: The user's question.
        history: Existing chat history as a list of ``(user, bot)`` tuples,
            or ``None``. It is copied, not mutated.
        state: Application state dict; reads ``state['metadata']['dtypes']``
            and ``state['df_modified']``.
        api_key: Gemini API key. Falsy means the co-pilot cannot run.

    Yields:
        First the model's "thought" with the generated code, then the
        execution result. If the API key or dataset is missing, or anything
        fails, a single tuple carrying an explanatory chat message.
    """
    history = list(history or [])

    # This function is a generator, so messages must be yielded; a value
    # passed to `return` would be discarded by the caller.
    if not api_key:
        yield history + [(user_message, "I need a Gemini API key to function.")], None, None, ""
        return

    if not state or 'metadata' not in state or 'df_modified' not in state:
        yield history + [(user_message, "Please upload a dataset first.")], None, None, ""
        return

    history.append((user_message, None))

    prompt = f"""
    You are 'Phoenix Co-pilot', a world-class AI data analyst. Your goal is to help the user by writing and executing Python code.
    You have access to a pandas DataFrame named `df`. This is the user's LATEST data, including any cleaning they've performed.

    **DataFrame Info:**
    - Columns and dtypes: {json.dumps(state['metadata']['dtypes'])}

    **Instructions:**
    1. Analyze the user's request: '{user_message}'.
    2. Formulate a plan (thought).
    3. Write Python code to execute the plan.
    4. Use `pandas`, `numpy`, and `plotly.express as px`.
    5. To show a plot, assign it to a variable `fig`. Ex: `fig = px.histogram(df, x='age')`.
    6. To show a dataframe, assign it to a variable `df_result`. Ex: `df_result = df.describe()`.
    7. Use `print()` for text output.
    8. **NEVER** modify `df` in place. Use `df.copy()` if needed.
    9. Respond **ONLY** with a single, valid JSON object with keys "thought" and "code".

    **User Request:** "{user_message}"

    **Your JSON Response:**
    """

    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)

        # Take the outermost {...} span so markdown fences or stray prose
        # around the JSON do not break parsing, and backticks inside the
        # generated code are left intact.
        raw = response.text.strip()
        start, end = raw.find('{'), raw.rfind('}')
        payload = raw[start:end + 1] if start != -1 and end > start else raw
        response_json = json.loads(payload)

        thought = response_json.get("thought", "Thinking...")
        code_to_run = response_json.get("code", "print('No code generated.')")

        bot_thinking = f"🧠 **Thinking:** *{thought}*"
        history[-1] = (user_message, bot_thinking)
        yield history, None, None, gr.update(value=code_to_run)

        # NOTE(security): this executes model-generated code in-process with
        # no sandbox. The DataFrame is copied so that code which ignores the
        # "never modify df" instruction cannot corrupt the application state.
        local_vars = {'df': state['df_modified'].copy(), 'px': px, 'pd': pd, 'np': np}
        stdout, fig_result, df_result, error = safe_exec(code_to_run, local_vars)

        bot_response = bot_thinking + "\n\n---\n\n"
        if error:
            bot_response += f"💥 **Execution Error:**\n```\n{error}\n```"
        if stdout:
            bot_response += f"📄 **Output:**\n```\n{stdout}\n```"
        produced_nothing = (
            not error
            and not stdout
            and fig_result is None
            and not isinstance(df_result, pd.DataFrame)
        )
        if produced_nothing:
            bot_response += "✅ Code executed, but produced no direct output."

        history[-1] = (user_message, bot_response)
        yield history, fig_result, df_result, gr.update(value=code_to_run)

    except Exception as e:
        # Top-level UI boundary: report the failure in the chat.
        error_msg = f"A critical error occurred: {e}. The AI may have returned invalid JSON. Check the generated code."
        history[-1] = (user_message, error_msg)
        yield history, None, None, ""
|
|
|
|
|
|
|
# Application layout and event wiring. The tabs stay hidden until a CSV has
# been loaded by prime_data.
with gr.Blocks(theme=THEME, title="Phoenix AI Data Explorer") as demo:
    # Per-session application state, populated by prime_data on upload.
    global_state = gr.State({})

    gr.Markdown("# 🔥 Phoenix AI Data Explorer")
    gr.Markdown("The next-generation analytic tool. Upload your data to awaken the Phoenix.")

    with gr.Row():
        file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        api_key_input = gr.Textbox(label="🔑 Gemini API Key", type="password", placeholder="Enter Google AI Studio key...")

    with gr.Tabs(visible=False) as phoenix_tabs:
        with gr.Tab("🦅 Phoenix Eye"):
            phoenix_eye_output = gr.Markdown()

        with gr.Tab("🩺 Data Medic"):
            gr.Markdown("### Cleanse Your Data\nSelect a column with missing values and choose a method to fill them.")
            with gr.Row():
                medic_col_select = gr.Dropdown(label="Select Column to Clean")
                medic_method_select = gr.Radio(['mean', 'median', 'mode'], label="Imputation Method", value='mean')
            medic_preview_btn = gr.Button("🔍 Preview Changes")
            medic_plot = gr.Plot()
            with gr.Row():
                medic_apply_btn = gr.Button("✅ Apply & Save Changes", variant="primary")
                medic_status = gr.Textbox(label="Status", interactive=False)
            with gr.Accordion("Download Cleaned Data", open=False):
                download_btn = gr.Button("⬇️ Download Cleaned CSV")
                download_file_output = gr.File(label="Download Link", visible=False)

        with gr.Tab("🔮 The Oracle (Predictive Modeling)"):
            gr.Markdown("### Glimpse the Future\nTrain a simple model to see the predictive power of your data.")
            with gr.Row():
                oracle_target_select = gr.Dropdown(label="🎯 Select Target Variable")
                # Gradio has no Multiselect component; multi-selection is a
                # Dropdown option.
                oracle_feature_select = gr.Dropdown(label="✨ Select Features", multiselect=True)
                oracle_model_select = gr.Dropdown(choices=["Random Forest", "Logistic Regression", "Linear Regression"], label="🧠 Select Model")
            oracle_run_btn = gr.Button("🚀 Train Model!", variant="primary")
            oracle_status = gr.Markdown()
            with gr.Row():
                oracle_fig1 = gr.Plot()
                oracle_fig2 = gr.Plot()

        with gr.Tab("🤖 AI Co-pilot"):
            gr.Markdown("### Your Conversational Analyst\nAsk any question about your data in plain English.")
            copilot_chatbot = gr.Chatbot(label="Chat History", height=400)
            with gr.Accordion("AI Generated Results", open=True):
                copilot_fig_output = gr.Plot()
                copilot_df_output = gr.Dataframe(interactive=False)
            with gr.Accordion("Generated Code", open=False):
                copilot_code_output = gr.Code(language="python", interactive=False)

            with gr.Row():
                copilot_input = gr.Textbox(label="Your Question", placeholder="e.g., 'What's the correlation between age and salary?'", scale=4)
                copilot_submit_btn = gr.Button("Submit", variant="primary", scale=1)

    # --- Event wiring ---

    # Uploading a file analyses it and reveals/populates the tabs.
    file_input.upload(
        fn=prime_data,
        inputs=file_input,
        outputs=[global_state, phoenix_tabs, phoenix_eye_output, medic_col_select, oracle_target_select, oracle_feature_select],
        show_progress="full"
    )

    # Data Medic: preview, apply, and export.
    medic_preview_btn.click(medic_preview_imputation, [global_state, medic_col_select, medic_method_select], medic_plot)
    medic_apply_btn.click(medic_apply_imputation, [global_state, medic_col_select, medic_method_select], [global_state, medic_status, medic_col_select])
    download_btn.click(download_cleaned_data, [global_state], download_file_output)

    # Oracle: train the selected model.
    oracle_run_btn.click(
        oracle_run_model,
        [global_state, oracle_target_select, oracle_feature_select, oracle_model_select],
        [oracle_fig1, oracle_fig2, oracle_status],
        show_progress="full"
    )

    # Co-pilot: answer, then clear the question box. The clearing callback
    # takes no arguments, so it must be wired with no inputs.
    copilot_submit_btn.click(
        copilot_respond,
        [copilot_input, copilot_chatbot, global_state, api_key_input],
        [copilot_chatbot, copilot_fig_output, copilot_df_output, copilot_code_output]
    ).then(lambda: "", None, copilot_input)


if __name__ == "__main__":
    demo.launch(debug=True)