Update app.py
app.py CHANGED
@@ -5,20 +5,33 @@ import plotly.express as px
 import plotly.graph_objects as go
 from plotly.subplots import make_subplots
 import io
-
 import warnings
 import google.generativeai as genai
 import os
- from dotenv import load_dotenv
 import logging
- import json
 from contextlib import redirect_stdout

 # --- Configuration ---
 warnings.filterwarnings('ignore')
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- …
- …

 def safe_exec(code_string: str, local_vars: dict) -> tuple:
     """Safely execute a string of Python code and capture its output."""
@@ -26,291 +39,282 @@ def safe_exec(code_string: str, local_vars: dict) -> tuple:
     try:
         with redirect_stdout(output_buffer):
             exec(code_string, globals(), local_vars)
- …
-         return …
     except Exception as e:
- …
-         logging.error(f"Error executing AI-generated code: {error_message}")
-         return None, None, error_message
-
- # --- Core Data Processing ---

- def …
-     """Loads …
-     if file_obj …
-         return …

     try:
         df = pd.read_csv(file_obj.name)

-         # …
         for col in df.select_dtypes(include=['object']).columns:
             try:
                 df[col] = pd.to_datetime(df[col], errors='raise')
-                 logging.info(f"Successfully converted column '{col}' to datetime.")
             except (ValueError, TypeError):
- …

         metadata = extract_dataset_metadata(df)

- …
             'metadata': metadata,
-             '…
         }
- …
         }
-
-         # Check for time series tab visibility
-         time_series_visible = len(metadata['datetime_cols']) > 0
-
-         return (
-             state_dict,
-             f"✅ Loaded `{state_dict['filename']}` ({metadata['shape'][0]} rows, {metadata['shape'][1]} cols)",
-             gr.update(**update_args), gr.update(**update_args), gr.update(**update_args),
-             gr.update(choices=metadata['numeric_cols'], value=None, interactive=True),
-             gr.update(choices=metadata['datetime_cols'], value=None, interactive=True),
-             gr.update(visible=time_series_visible),  # Show/hide Time Series tab
-             gr.update(visible=True)  # Show Chatbot tab
-         )
     except Exception as e:
-         logging.error(f"Error …
-         return …

- def extract_dataset_metadata(df…
-     """Extracts …
     rows, cols = df.shape
-     columns = df.columns.tolist()
-
-     numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
-     categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
-     datetime_cols = df.select_dtypes(include=['datetime64', 'datetime64[ns]']).columns.tolist()
-
-     missing_data = df.isnull().sum()
-     data_quality = round((df.notna().sum().sum() / (rows * cols)) * 100, 1) if rows * cols > 0 else 0
-
     return {
         'shape': (rows, cols),
-         'columns': columns,
-         'numeric_cols': …
-         'categorical_cols': …
-         'datetime_cols': …
-         'dtypes': df.dtypes.…
-         'missing_data': missing_data.to_dict(),
-         'data_quality': data_quality,
-         'head': df.head().to_string()
     }

- …
-         return "❌ Please upload a dataset first.", "", 0
-     if not api_key:
-         return "❌ Please enter your Gemini API key.", "", 0

- …
-     - **Categorical Columns:** {', '.join(metadata['categorical_cols'])}
-     - **Datetime Columns:** {', '.join(metadata['datetime_cols'])}
-     - **Data Quality (Non-missing values):** {metadata['data_quality']}%
-     - **First 5 rows:**
-     {metadata['head']}
-
-     **Your Task:**
-     Based on the metadata, generate a report in Markdown format. Use emojis to make it visually appealing. The report should have the following sections:
-
-     # 📊 AI-Powered Dataset Overview
-
-     ## 🤔 What is this dataset likely about?
-     (Predict the domain and purpose of the dataset, e.g., "This appears to be customer transaction data for an e-commerce platform.")
-
-     ## 💡 Potential Key Questions to Explore
-     - (Suggest 3-4 interesting business or research questions the data could answer.)
-     - (Example: "Which products are most frequently purchased together?")
-
-     ## 📈 Potential Analyses & Visualizations
-     - (List 3-4 types of analyses that would be valuable.)
-     - (Example: "Time series analysis of sales to identify seasonality.")
-
-     ## ⚠️ Data Quality & Potential Issues
-     - (Briefly comment on the data quality score and mention if the presence of datetime columns is a good sign for certain analyses.)
-     """
-
-     try:
-         genai.configure(api_key=api_key)
-         model = genai.GenerativeModel('gemini-1.5-flash')
-         response = model.generate_content(prompt)
-         story = response.text
-     except Exception as e:
-         story = f"## ⚠️ AI Generation Failed\n**Error:** {str(e)}\n\nPlease check your API key and network connection. A fallback analysis is provided below.\n\n" \
-                 f"### Fallback Analysis\nThis dataset contains **{metadata['shape'][0]}** records and **{metadata['shape'][1]}** features. " \
-                 f"It includes **{len(metadata['numeric_cols'])}** numeric, **{len(metadata['categorical_cols'])}** categorical, " \
-                 f"and **{len(metadata['datetime_cols'])}** time-based columns. The overall data quality is **{metadata['data_quality']}%**, " \
-                 f"which is a good starting point for analysis."
-
-     # Basic Info Summary
-     basic_info = f"""
-     📁 **File:** `{state_dict.get('filename', 'N/A')}`
-     📊 **Size:** {metadata['shape'][0]:,} rows × {metadata['shape'][1]} columns
-     🔢 **Features:**
-     • **Numeric:** {len(metadata['numeric_cols'])}
-     • **Categorical:** {len(metadata['categorical_cols'])}
-     • **DateTime:** {len(metadata['datetime_cols'])}
-     🎯 **Data Quality:** {metadata['data_quality']}%
-     """

- …
-         return None, "Select a column to analyze."

- …
-     summary = ""
-
-     if column_name in metadata['numeric_cols']:
-         fig = make_subplots(rows=1, cols=2, subplot_titles=("Histogram", "Box Plot"))
-         fig.add_trace(go.Histogram(x=df[column_name], name="Histogram"), row=1, col=1)
-         fig.add_trace(go.Box(y=df[column_name], name="Box Plot"), row=1, col=2)
-         fig.update_layout(title_text=f"Distribution of '{column_name}'", showlegend=False)
-         summary = df[column_name].describe().to_frame().to_markdown()

- …
-     return fig, summary
-
- # --- Tab 3: Bivariate Analysis ---
-
- def generate_bivariate_plot(x_col, y_col, state_dict):
-     """Generates plots to explore the relationship between two variables."""
-     if not x_col or not y_col or not state_dict:
-         return None, "Select two columns to analyze."
-     if x_col == y_col:
-         return None, "Please select two different columns."
-
-     df = state_dict['df']
-     metadata = state_dict['metadata']

- …
-
- # --- Tab 4: Time Series Analysis ---
-
- def generate_time_series_plot(time_col, value_col, resample_freq, state_dict):
-     """Generates a time series plot with resampling."""
-     if not time_col or not value_col or not state_dict:
-         return None, "Select Time and Value columns."
-
-     df = state_dict['df'].copy()

- …
     if not api_key:
-         history…
-         return history, None, ""
-
-     if not state_dict:
-         history.append((user_message, "Please upload a dataset before asking questions."))
-         return history, None, ""
-
-     history.append((user_message, None))

- …
-     # Construct a robust prompt for the AI
     prompt = f"""
-     You are …
-     You …
- …
     **Instructions:**
-     1. Analyze the user's …
-     2. …
-     3. …
-     4. …
-     5. …
-     6. …
-     7. …
- …
-     **…
-     - **Shape:** {df_metadata['shape'][0]} rows, {df_metadata['shape'][1]} columns
-     - **Columns and Data Types:**
-     {df_metadata['dtypes']}
-
-     ---
-     **User Question:** "{user_message}"
-     ---
-
     **Your JSON Response:**
     """
@@ -319,168 +323,118 @@ def respond_to_chat(user_message, history, state_dict, api_key):
         model = genai.GenerativeModel('gemini-1.5-flash')
         response = model.generate_content(prompt)

-         # Clean and parse …
- …
-         response_json = json.loads(response_text)
-
         thought = response_json.get("thought", "Thinking...")
-         code_to_run = response_json.get("code", "")

- …
-         # Execute …
-         local_vars = {'df': …
-         stdout, fig_result, error = safe_exec(code_to_run, local_vars)

         if error:
- …
-             history[-1] = (user_message, bot_message)
-             return history, None, ""
-
         if stdout:
- …

-         history[-1] = (user_message, bot_message)
-
-         return history, fig_result, ""
-
     except Exception as e:
-         error_msg = f"…
-         logging.error(f"Chatbot error: {error_msg}")
         history[-1] = (user_message, error_msg)
- …
-
- # --- Gradio …
- …
-         gr.…
-         gr.…
-
-         with gr.…
-
-         with …
- …
-                 uni_summary_output = gr.Markdown(label="Summary Statistics")
-
-             # Tab 3: Bivariate Analysis
-             with gr.Tab("📊 Bivariate Analysis", id=2):
-                 with gr.Row():
-                     bi_x_select = gr.Dropdown(label="Select X-Axis Column", interactive=False)
-                     bi_y_select = gr.Dropdown(label="Select Y-Axis Column", interactive=False)
-                 bi_btn = gr.Button("🎨 Generate Bivariate Plot", variant="secondary")
-                 with gr.Row():
-                     bi_plot_output = gr.Plot(label="Relationship Plot")
-                     bi_summary_output = gr.Markdown(label="Analysis Summary")
-
-             # Tab 4: Time Series (conditionally visible)
-             with gr.Tab("⏳ Time Series Analysis", id=3, visible=False) as ts_tab:
-                 with gr.Row():
-                     ts_time_col = gr.Dropdown(label="Select Time Column", interactive=False)
-                     ts_value_col = gr.Dropdown(label="Select Value Column", interactive=False)
-                     ts_resample = gr.Radio(['D', 'W', 'M', 'Q', 'Y'], label="Resample Frequency", value='M')
-                 ts_btn = gr.Button("📈 Plot Time Series", variant="secondary")
-                 ts_plot_output = gr.Plot(label="Time Series Plot")
-                 ts_status_output = gr.Markdown()
-
-             # Tab 5: AI Analyst Chat (conditionally visible)
-             with gr.Tab("💬 AI Analyst Chat", id=4, visible=False) as chat_tab:
-                 chatbot = gr.Chatbot(label="Chat with Gemini Analyst", height=500)
-                 chat_plot_output = gr.Plot(label="AI Generated Plot")
-                 with gr.Row():
-                     chat_input = gr.Textbox(label="Your Question", placeholder="e.g., 'Show me the distribution of age'", scale=4)
-                     chat_submit_btn = gr.Button("Submit", variant="primary", scale=1)
-                     chat_clear_btn = gr.Button("Clear Chat")
-
-         # --- Event Handlers ---
-
-         # File upload triggers data loading and UI updates
-         file_input.upload(
-             fn=load_and_process_file,
-             inputs=[file_input, global_state],
-             outputs=[global_state, status_output, uni_col_select, bi_x_select, bi_y_select, ts_value_col, ts_time_col, ts_tab, chat_tab]
-         )
-
-         # Tab 1: Overview
-         overview_btn.click(
-             fn=analyze_dataset_overview,
-             inputs=[global_state, api_key_input],
-             outputs=[story_output, basic_info_output, quality_score]
-         )
- …

 if __name__ == "__main__":
- …
-     # load_dotenv()
-     app = create_gradio_interface()
-     app.launch(debug=True)
 import plotly.graph_objects as go
 from plotly.subplots import make_subplots
 import io
+ import json
 import warnings
 import google.generativeai as genai
 import os
 import logging
 from contextlib import redirect_stdout
+ from sklearn.model_selection import train_test_split
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+ from sklearn.linear_model import LogisticRegression, LinearRegression
+ from sklearn.metrics import accuracy_score, confusion_matrix, r2_score, mean_squared_error
+ from sklearn.preprocessing import LabelEncoder

 # --- Configuration ---
 warnings.filterwarnings('ignore')
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ THEME = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan").set(
+     body_background_fill="rgba(0,0,0,0.8)",
+     block_background_fill="rgba(0,0,0,0.6)",
+     block_border_width="1px",
+     border_color_primary="rgba(255,255,255,0.1)"
+ )
+ MODEL_REGISTRY = {
+     "Classification": {"Random Forest": RandomForestClassifier, "Logistic Regression": LogisticRegression},
+     "Regression": {"Random Forest": RandomForestRegressor, "Linear Regression": LinearRegression}
+ }
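
The registry maps a problem type and a display name to an sklearn estimator class; instantiation happens later. A minimal, self-contained sketch of the lookup, noting that `LinearRegression` accepts no `random_state`, so constructor keywords must be chosen per class:

    from sklearn.linear_model import LinearRegression
    from sklearn.ensemble import RandomForestRegressor

    registry = {"Regression": {"Random Forest": RandomForestRegressor,
                               "Linear Regression": LinearRegression}}
    model_cls = registry["Regression"]["Random Forest"]
    model = model_cls(random_state=42)                        # ensemble models take random_state
    baseline = registry["Regression"]["Linear Regression"]()  # LinearRegression does not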
+
+ # --- Core Logic ---

 def safe_exec(code_string: str, local_vars: dict) -> tuple:
     """Safely execute a string of Python code and capture its output."""
     try:
         with redirect_stdout(output_buffer):
             exec(code_string, globals(), local_vars)
+         stdout = output_buffer.getvalue()
+         fig = local_vars.get('fig')
+         df_out = local_vars.get('df_result')
+         return stdout, fig, df_out, None
     except Exception as e:
+         return None, None, None, f"Execution Error: {str(e)}"
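
A minimal usage sketch (assuming it runs in this module, where `output_buffer` is the `io.StringIO()` set up in the function body): the executed snippet reports results by assigning to `fig` or `df_result`, or by printing.

    import pandas as pd
    snippet = "df_result = df.describe()\nprint('done')"
    out, fig, df_res, err = safe_exec(snippet, {'df': pd.DataFrame({'a': [1, 2, 3]})})
    # err is None, out contains 'done', df_res holds the describe() table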

+ def prime_data(file_obj):
+     """Loads, analyzes, and primes the entire application state upon file upload."""
+     if not file_obj:
+         return {phoenix_tabs: gr.update(visible=False)}

     try:
         df = pd.read_csv(file_obj.name)

+         # Smart type conversion
         for col in df.select_dtypes(include=['object']).columns:
             try:
                 df[col] = pd.to_datetime(df[col], errors='raise')
             except (ValueError, TypeError):
+                 if df[col].nunique() / len(df) < 0.5:  # If not too many unique values
+                     df[col] = df[col].astype('category')

+         # --- Phoenix Eye: Proactive Insights Engine ---
+         insights = {}
         metadata = extract_dataset_metadata(df)

+         # 1. Missing Data
+         missing = df.isnull().sum()
+         insights['missing'] = missing[missing > 0].sort_values(ascending=False)
+
+         # 2. High Cardinality
+         insights['high_cardinality'] = {c: df[c].nunique() for c in metadata['categorical_cols'] if df[c].nunique() > 50}
+
+         # 3. High Correlations
+         if len(metadata['numeric_cols']) > 1:
+             corr = df[metadata['numeric_cols']].corr().abs()
+             sol = corr.unstack()
+             so = sol.sort_values(kind="quicksort", ascending=False)
+             so = so[so < 1]  # Remove self-correlation
+             insights['high_correlations'] = so.head(5)
+
+         # 4. Outlier Detection (IQR method)
+         outliers = {}
+         for col in metadata['numeric_cols']:
+             Q1, Q3 = df[col].quantile(0.25), df[col].quantile(0.75)
+             IQR = Q3 - Q1
+             outlier_count = ((df[col] < (Q1 - 1.5 * IQR)) | (df[col] > (Q3 + 1.5 * IQR))).sum()
+             if outlier_count > 0:
+                 outliers[col] = outlier_count
+         insights['outliers'] = outliers
+
+         # 5. ML Target Suggestion
+         suggestions = []
+         for col in metadata['categorical_cols']:
+             if df[col].nunique() == 2:
+                 suggestions.append(f"{col} (Binary Classification)")
+         for col in metadata['numeric_cols']:
+             if df[col].nunique() > 20:  # Heuristic for continuous target
+                 suggestions.append(f"{col} (Regression)")
+         insights['ml_suggestions'] = suggestions
+
+         state = {
+             'df_original': df,
+             'df_modified': df.copy(),
+             'filename': os.path.basename(file_obj.name),
             'metadata': metadata,
+             'proactive_insights': insights
         }
+
+         # Generate UI updates
+         overview_md = generate_phoenix_eye_markdown(state)
+         all_cols = metadata['columns']
+         num_cols = metadata['numeric_cols']
+         cat_cols = metadata['categorical_cols']

+         return {
+             global_state: state,
+             phoenix_tabs: gr.update(visible=True),
+             phoenix_eye_output: overview_md,
+             # Data Medic updates
+             medic_col_select: gr.update(choices=insights['missing'].index.tolist() or [], interactive=True),
+             # Oracle updates
+             oracle_target_select: gr.update(choices=all_cols, interactive=True),
+             oracle_feature_select: gr.update(choices=all_cols, interactive=True),
         }
+
     except Exception as e:
+         logging.error(f"Priming Error: {e}")
+         return {phoenix_eye_output: gr.update(value=f"❌ **Error:** {e}")}
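
The outlier rule above is the classic 1.5×IQR fence; on a toy series it flags exactly the one extreme value:

    import pandas as pd
    s = pd.Series([1, 2, 2, 3, 3, 3, 4, 100])
    q1, q3 = s.quantile(0.25), s.quantile(0.75)
    iqr = q3 - q1
    print(((s < q1 - 1.5 * iqr) | (s > q3 + 1.5 * iqr)).sum())  # -> 1 (the value 100)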

+ def extract_dataset_metadata(df):
+     """Extracts typed metadata from a DataFrame."""
     rows, cols = df.shape
     return {
         'shape': (rows, cols),
+         'columns': df.columns.tolist(),
+         'numeric_cols': df.select_dtypes(include=np.number).columns.tolist(),
+         'categorical_cols': df.select_dtypes(include=['object', 'category']).columns.tolist(),
+         'datetime_cols': df.select_dtypes(include=['datetime64', 'datetime64[ns]']).columns.tolist(),
+         'dtypes': df.dtypes.apply(lambda x: x.name).to_dict()
     }
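
Called on a small frame, the helper returns plain Python values that are easy to serialize into prompts; a sketch:

    import pandas as pd
    tiny = pd.DataFrame({'age': [25, 31], 'city': pd.Categorical(['Pune', 'Oslo'])})
    meta = extract_dataset_metadata(tiny)
    # meta['shape'] == (2, 2); meta['numeric_cols'] == ['age']
    # meta['dtypes'] == {'age': 'int64', 'city': 'category'}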

+ def generate_phoenix_eye_markdown(state):
+     """Creates the markdown for the proactive insights dashboard."""
+     insights = state['proactive_insights']
+     md = f"## 🦅 Phoenix Eye: Proactive Insights for `{state['filename']}`\n"
+     md += f"Dataset has **{state['metadata']['shape'][0]} rows** and **{state['metadata']['shape'][1]} columns**.\n\n"

+     # ML Suggestions
+     md += "### 🔮 Potential ML Targets\n"
+     if insights['ml_suggestions']:
+         for s in insights['ml_suggestions']: md += f"- `{s}`\n"
+     else: md += "No obvious ML target columns found.\n"
+     md += "\n"
+
+     # Missing Data
+     md += "### 🔧 Missing Data\n"
+     if not insights['missing'].empty:
+         md += "Found missing values in these columns. Use the **Data Medic** tab to fix.\n"
+         md += insights['missing'].to_frame('Missing Count').to_markdown() + "\n"
+     else: md += "✅ No missing data found!\n"
+     md += "\n"
+
+     # High Correlation
+     md += "### 🔗 Top Correlations\n"
+     if 'high_correlations' in insights and not insights['high_correlations'].empty:
+         md += insights['high_correlations'].to_frame('Correlation').to_markdown() + "\n"
+     else: md += "No strong correlations found between numeric features.\n"
+     md += "\n"
+
+     # Outliers
+     md += "### 🚨 Outlier Alert\n"
+     if insights['outliers']:
+         for col, count in insights['outliers'].items(): md += f"- `{col}` has **{count}** potential outliers.\n"
+     else: md += "✅ No significant outliers detected.\n"
+     md += "\n"

+     # High Cardinality
+     md += "### 🔠 High Cardinality Warning\n"
+     if insights['high_cardinality']:
+         for col, count in insights['high_cardinality'].items(): md += f"- `{col}` has **{count}** unique values, which may be problematic for some models.\n"
+     else: md += "✅ No high-cardinality categorical columns found.\n"
+     md += "\n"
+
+     return md
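
The tables embedded in this report come from `Series.to_frame(...).to_markdown()`, which requires the `tabulate` package to be installed; for example:

    import pandas as pd
    missing = pd.Series({'age': 3, 'city': 1})
    print(missing.to_frame('Missing Count').to_markdown())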

+ # --- Tab Handlers ---

+ def medic_preview_imputation(state, col, method):
+     """Shows a before-and-after plot for data imputation."""
+     if not col: return None
+     df_orig = state['df_original']
+     df_mod = df_orig.copy()

+     if method == 'mean': value = df_mod[col].mean()
+     elif method == 'median': value = df_mod[col].median()
+     else: value = df_mod[col].mode()[0]

+     df_mod[col] = df_mod[col].fillna(value)

+     fig = go.Figure()
+     fig.add_trace(go.Histogram(x=df_orig[col], name='Before', opacity=0.7))
+     fig.add_trace(go.Histogram(x=df_mod[col], name='After', opacity=0.7))
+     fig.update_layout(barmode='overlay', title=f"'{col}' Distribution: Before vs. After Imputation", legend_title_text='Dataset')
+     return fig
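
The before/after preview relies on two overlaid, semi-transparent histograms; the same pattern in isolation:

    import pandas as pd
    import plotly.graph_objects as go
    before = pd.Series([1.0, 2.0, None, 4.0])
    after = before.fillna(before.median())
    fig = go.Figure([go.Histogram(x=before, name='Before', opacity=0.7),
                     go.Histogram(x=after, name='After', opacity=0.7)])
    fig.update_layout(barmode='overlay')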

+ def medic_apply_imputation(state, col, method):
+     """Applies imputation and updates the main state."""
+     if not col: return state, "No column selected.", gr.update()
+     df_mod = state['df_modified'].copy()
+
+     if method == 'mean': value = df_mod[col].mean()
+     elif method == 'median': value = df_mod[col].median()
+     else: value = df_mod[col].mode()[0]

+     df_mod[col] = df_mod[col].fillna(value)
+     state['df_modified'] = df_mod
+
+     # Re-run proactive insights on the modified df
+     state['proactive_insights']['missing'] = df_mod.isnull().sum()
+     state['proactive_insights']['missing'] = state['proactive_insights']['missing'][state['proactive_insights']['missing'] > 0]
+
+     return state, f"✅ Applied '{method}' imputation to '{col}'.", gr.update(choices=state['proactive_insights']['missing'].index.tolist())
+
+ def download_cleaned_data(state):
+     """Saves the modified dataframe to a csv and returns the path."""
+     if state:
+         df = state['df_modified']
+         # gr.File needs a filesystem path, so write the CSV to a temp file first
+         import tempfile
+         tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w")
+         df.to_csv(tmp.name, index=False)
+         return gr.File.update(value=tmp.name, visible=True)
+     return gr.File.update(visible=False)
+
+ def oracle_run_model(state, target, features, model_name):
+     """Trains a simple ML model and returns metrics and plots."""
+     if not target or not features: return None, None, "Please select a target and at least one feature."
+
+     df = state['df_modified'].copy()
+
+     # Preprocessing
+     df.dropna(subset=features + [target], inplace=True)
+     if df.empty: return None, None, "Not enough data after dropping NA values."
+
+     le = LabelEncoder()
+     for col in features + [target]:
+         if df[col].dtype == 'object' or df[col].dtype.name == 'category':
+             df[col] = le.fit_transform(df[col])
+
+     X = df[features]
+     y = df[target]
+
+     problem_type = "Classification" if y.nunique() <= 10 else "Regression"
+
+     if model_name not in MODEL_REGISTRY[problem_type]:
+         return None, None, f"Model {model_name} is not suitable for {problem_type}."

+     model_cls = MODEL_REGISTRY[problem_type][model_name]
+     # LinearRegression takes no random_state, so only pass it where supported
+     model = model_cls(random_state=42) if model_cls is not LinearRegression else model_cls()

+     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
+     model.fit(X_train, y_train)
+     preds = model.predict(X_test)
+
+     # Results
+     if problem_type == "Classification":
+         acc = accuracy_score(y_test, preds)
+         cm = confusion_matrix(y_test, preds)
+         cm_fig = px.imshow(cm, text_auto=True, title=f"Confusion Matrix (Accuracy: {acc:.2f})")

+         if hasattr(model, 'feature_importances_'):
+             fi = pd.Series(model.feature_importances_, index=features).sort_values(ascending=False)
+             fi_fig = px.bar(fi, title="Feature Importance")
+             return fi_fig, cm_fig, f"**Classification Report:**\n- Accuracy: {acc:.2f}"
+         else:
+             return None, cm_fig, f"**Classification Report:**\n- Accuracy: {acc:.2f}"
+
+     else:  # Regression
+         r2 = r2_score(y_test, preds)
+         rmse = np.sqrt(mean_squared_error(y_test, preds))
+
+         preds_fig = px.scatter(x=y_test, y=preds, labels={'x': 'Actual Values', 'y': 'Predicted Values'},
+                                title=f"Predictions vs. Actuals (R²: {r2:.2f})", trendline='ols')
+
+         if hasattr(model, 'feature_importances_'):
+             fi = pd.Series(model.feature_importances_, index=features).sort_values(ascending=False)
+             fi_fig = px.bar(fi, title="Feature Importance")
+             return fi_fig, preds_fig, f"**Regression Report:**\n- R² Score: {r2:.2f}\n- RMSE: {rmse:.2f}"
+         else:
+             return None, preds_fig, f"**Regression Report:**\n- R² Score: {r2:.2f}\n- RMSE: {rmse:.2f}"
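
End to end, the Oracle is the standard split/fit/score loop; a self-contained sketch on synthetic data:

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.metrics import r2_score
    from sklearn.model_selection import train_test_split

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3))
    y = X @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=200)
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=42)
    model = RandomForestRegressor(random_state=42).fit(X_tr, y_tr)
    print(round(r2_score(y_te, model.predict(X_te)), 2))  # typically high on this synthetic data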
+
+ def copilot_respond(user_message, history, state, api_key):
+     """Handles the AI Co-pilot chat interaction."""
     if not api_key:
+         # This function is a generator, so even the guard path must yield
+         yield history + [(user_message, "I need a Gemini API key to function.")], None, None, ""
+         return

+     history += [(user_message, None)]
     prompt = f"""
+     You are 'Phoenix Co-pilot', a world-class AI data analyst. Your goal is to help the user by writing and executing Python code.
+     You have access to a pandas DataFrame named `df`. This is the user's LATEST data, including any cleaning they've performed.
+
+     **DataFrame Info:**
+     - Columns and dtypes: {json.dumps(state['metadata']['dtypes'])}
+
     **Instructions:**
+     1. Analyze the user's request: '{user_message}'.
+     2. Formulate a plan (thought).
+     3. Write Python code to execute the plan.
+     4. Use `pandas`, `numpy`, and `plotly.express as px`.
+     5. To show a plot, assign it to a variable `fig`. Ex: `fig = px.histogram(df, x='age')`.
+     6. To show a dataframe, assign it to a variable `df_result`. Ex: `df_result = df.describe()`.
+     7. Use `print()` for text output.
+     8. **NEVER** modify `df` in place. Use `df.copy()` if needed.
+     9. Respond **ONLY** with a single, valid JSON object with keys "thought" and "code".
+
+     **User Request:** "{user_message}"
+
     **Your JSON Response:**
     """

         model = genai.GenerativeModel('gemini-1.5-flash')
         response = model.generate_content(prompt)

+         # Clean and parse JSON
+         response_json = json.loads(response.text.strip().replace("```json", "").replace("```", ""))
         thought = response_json.get("thought", "Thinking...")
+         code_to_run = response_json.get("code", "print('No code generated.')")

+         bot_thinking = f"🧠 **Thinking:** *{thought}*"
+         history[-1] = (user_message, bot_thinking)
+         yield history, None, None, gr.update(value=code_to_run)

+         # Execute Code
+         local_vars = {'df': state['df_modified'], 'px': px, 'pd': pd, 'np': np}
+         stdout, fig_result, df_result, error = safe_exec(code_to_run, local_vars)
+
+         bot_response = bot_thinking + "\n\n---\n\n"

         if error:
+             bot_response += f"💥 **Execution Error:**\n```\n{error}\n```"
         if stdout:
+             bot_response += f"📋 **Output:**\n```\n{stdout}\n```"
+         if not error and not stdout and not fig_result and not isinstance(df_result, pd.DataFrame):
+             bot_response += "✅ Code executed, but produced no direct output."
+
+         history[-1] = (user_message, bot_response)
+         yield history, fig_result, df_result, gr.update(value=code_to_run)

     except Exception as e:
+         error_msg = f"A critical error occurred: {e}. The AI may have returned invalid JSON. Check the generated code."
         history[-1] = (user_message, error_msg)
+         yield history, None, None, ""
+
+ # --- Gradio UI Construction ---
+
+ with gr.Blocks(theme=THEME, title="Phoenix AI Data Explorer") as demo:
+     global_state = gr.State({})
+
+     gr.Markdown("# 🔥 Phoenix AI Data Explorer")
+     gr.Markdown("The next-generation analytic tool. Upload your data to awaken the Phoenix.")
+
+     with gr.Row():
+         file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
+         api_key_input = gr.Textbox(label="🔑 Gemini API Key", type="password", placeholder="Enter Google AI Studio key...")
+
+     with gr.Tabs(visible=False) as phoenix_tabs:
+         with gr.Tab("🦅 Phoenix Eye"):
+             phoenix_eye_output = gr.Markdown()
+
+         with gr.Tab("🩺 Data Medic"):
+             gr.Markdown("### Cleanse Your Data\nSelect a column with missing values and choose a method to fill them.")
+             with gr.Row():
+                 medic_col_select = gr.Dropdown(label="Select Column to Clean")
+                 medic_method_select = gr.Radio(['mean', 'median', 'mode'], label="Imputation Method", value='mean')
+             medic_preview_btn = gr.Button("🔍 Preview Changes")
+             medic_plot = gr.Plot()
+             with gr.Row():
+                 medic_apply_btn = gr.Button("✅ Apply & Save Changes", variant="primary")
+                 medic_status = gr.Textbox(label="Status", interactive=False)
+             with gr.Accordion("Download Cleaned Data", open=False):
+                 download_btn = gr.Button("⬇️ Download Cleaned CSV")
+                 download_file_output = gr.File(label="Download Link", visible=False)
+
+         with gr.Tab("🔮 The Oracle (Predictive Modeling)"):
+             gr.Markdown("### Glimpse the Future\nTrain a simple model to see the predictive power of your data.")
+             with gr.Row():
+                 oracle_target_select = gr.Dropdown(label="🎯 Select Target Variable")
+                 oracle_feature_select = gr.Dropdown(label="✨ Select Features", multiselect=True)
+                 oracle_model_select = gr.Dropdown(choices=["Random Forest", "Logistic Regression", "Linear Regression"], label="🧠 Select Model")
+             oracle_run_btn = gr.Button("🚀 Train Model!", variant="primary")
+             oracle_status = gr.Markdown()
+             with gr.Row():
+                 oracle_fig1 = gr.Plot()
+                 oracle_fig2 = gr.Plot()

+         with gr.Tab("🤖 AI Co-pilot"):
+             gr.Markdown("### Your Conversational Analyst\nAsk any question about your data in plain English.")
+             copilot_chatbot = gr.Chatbot(label="Chat History", height=400)
+             with gr.Accordion("AI Generated Results", open=True):
+                 copilot_fig_output = gr.Plot()
+                 copilot_df_output = gr.Dataframe(interactive=False)
+             with gr.Accordion("Generated Code", open=False):
+                 copilot_code_output = gr.Code(language="python", interactive=False)
+
+             with gr.Row():
+                 copilot_input = gr.Textbox(label="Your Question", placeholder="e.g., 'What's the correlation between age and salary?'", scale=4)
+                 copilot_submit_btn = gr.Button("Submit", variant="primary", scale=1)
+
+     # --- Event Wiring ---
+     file_input.upload(
+         fn=prime_data,
+         inputs=file_input,
+         outputs=[global_state, phoenix_tabs, phoenix_eye_output, medic_col_select, oracle_target_select, oracle_feature_select],
+         show_progress="full"
+     )
+
+     # Data Medic
+     medic_preview_btn.click(medic_preview_imputation, [global_state, medic_col_select, medic_method_select], medic_plot)
+     medic_apply_btn.click(medic_apply_imputation, [global_state, medic_col_select, medic_method_select], [global_state, medic_status, medic_col_select])
+     download_btn.click(download_cleaned_data, [global_state], download_file_output)
+
+     # Oracle
+     oracle_run_btn.click(
+         oracle_run_model,
+         [global_state, oracle_target_select, oracle_feature_select, oracle_model_select],
+         [oracle_fig1, oracle_fig2, oracle_status],
+         show_progress="full"
+     )
+
+     # AI Co-pilot
+     copilot_submit_btn.click(
+         copilot_respond,
+         [copilot_input, copilot_chatbot, global_state, api_key_input],
+         [copilot_chatbot, copilot_fig_output, copilot_df_output, copilot_code_output]
+     ).then(lambda: "", None, copilot_input)  # Clear input after submit

 if __name__ == "__main__":
+     demo.launch(debug=True)