mgbam committed on
Commit
f7b84f1
·
verified ·
1 Parent(s): c08faed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +358 -404
app.py CHANGED
@@ -5,20 +5,33 @@ import plotly.express as px
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
  import io
8
- from scipy import stats
9
  import warnings
10
  import google.generativeai as genai
11
  import os
12
- from dotenv import load_dotenv
13
  import logging
14
- import json
15
  from contextlib import redirect_stdout
 
 
 
 
 
16
 
17
  # --- Configuration ---
18
  warnings.filterwarnings('ignore')
19
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
-
21
- # --- Helper Functions ---
 
 
 
 
 
 
 
 
 
 
22
 
23
  def safe_exec(code_string: str, local_vars: dict) -> tuple:
24
  """Safely execute a string of Python code and capture its output."""
@@ -26,291 +39,282 @@ def safe_exec(code_string: str, local_vars: dict) -> tuple:
26
  try:
27
  with redirect_stdout(output_buffer):
28
  exec(code_string, globals(), local_vars)
29
-
30
- stdout_output = output_buffer.getvalue()
31
- fig = local_vars.get('fig', None)
32
- return stdout_output, fig, None
33
  except Exception as e:
34
- error_message = f"Execution Error: {str(e)}"
35
- logging.error(f"Error executing AI-generated code: {error_message}")
36
- return None, None, error_message
37
-
38
- # --- Core Data Processing ---
39
 
40
- def load_and_process_file(file_obj, state_dict):
41
- """Loads a CSV file and performs initial processing, updating the global state."""
42
- if file_obj is None:
43
- return None, "Please upload a file.", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
44
 
45
  try:
46
  df = pd.read_csv(file_obj.name)
47
 
48
- # Attempt to convert object columns to datetime
49
  for col in df.select_dtypes(include=['object']).columns:
50
  try:
51
  df[col] = pd.to_datetime(df[col], errors='raise')
52
- logging.info(f"Successfully converted column '{col}' to datetime.")
53
  except (ValueError, TypeError):
54
- continue
 
55
 
 
 
56
  metadata = extract_dataset_metadata(df)
57
 
58
- state_dict = {
59
- 'df': df,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  'metadata': metadata,
61
- 'filename': os.path.basename(file_obj.name)
62
  }
 
 
 
 
 
 
63
 
64
- # Update UI elements dynamically
65
- update_args = {
66
- 'choices': metadata['columns'],
67
- 'value': None,
68
- 'interactive': True
 
 
 
 
69
  }
70
-
71
- # Check for time series tab visibility
72
- time_series_visible = len(metadata['datetime_cols']) > 0
73
-
74
- return (
75
- state_dict,
76
- f"โœ… Loaded `{state_dict['filename']}` ({metadata['shape'][0]} rows, {metadata['shape'][1]} cols)",
77
- gr.update(**update_args), gr.update(**update_args), gr.update(**update_args),
78
- gr.update(choices=metadata['numeric_cols'], value=None, interactive=True),
79
- gr.update(choices=metadata['datetime_cols'], value=None, interactive=True),
80
- gr.update(visible=time_series_visible), # Show/hide Time Series tab
81
- gr.update(visible=True) # Show Chatbot tab
82
- )
83
  except Exception as e:
84
- logging.error(f"Error loading file: {e}")
85
- return state_dict, f"โŒ Error: {e}", gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), gr.update(visible=False)
86
 
87
- def extract_dataset_metadata(df: pd.DataFrame) -> dict:
88
- """Extracts comprehensive metadata from a DataFrame."""
89
  rows, cols = df.shape
90
- columns = df.columns.tolist()
91
-
92
- numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
93
- categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
94
- datetime_cols = df.select_dtypes(include=['datetime64', 'datetime64[ns]']).columns.tolist()
95
-
96
- missing_data = df.isnull().sum()
97
- data_quality = round((df.notna().sum().sum() / (rows * cols)) * 100, 1) if rows * cols > 0 else 0
98
-
99
  return {
100
  'shape': (rows, cols),
101
- 'columns': columns,
102
- 'numeric_cols': numeric_cols,
103
- 'categorical_cols': categorical_cols,
104
- 'datetime_cols': datetime_cols,
105
- 'dtypes': df.dtypes.to_string(),
106
- 'missing_data': missing_data.to_dict(),
107
- 'data_quality': data_quality,
108
- 'head': df.head().to_string()
109
  }
110
 
111
- # --- Tab 1: AI Overview ---
112
-
113
- def analyze_dataset_overview(state_dict, api_key: str):
114
- """Generates an AI-powered narrative overview of the dataset."""
115
- if not state_dict:
116
- return "โŒ Please upload a dataset first.", "", 0
117
- if not api_key:
118
- return "โŒ Please enter your Gemini API key.", "", 0
119
 
120
- metadata = state_dict['metadata']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- # Create prompt for Gemini
123
- prompt = f"""
124
- You are an expert data analyst and storyteller. Your task is to provide a high-level, engaging overview of a dataset based on its metadata.
125
-
126
- **Dataset Metadata:**
127
- - **Shape:** {metadata['shape'][0]} rows, {metadata['shape'][1]} columns
128
- - **Column Names:** {', '.join(metadata['columns'])}
129
- - **Numeric Columns:** {', '.join(metadata['numeric_cols'])}
130
- - **Categorical Columns:** {', '.join(metadata['categorical_cols'])}
131
- - **Datetime Columns:** {', '.join(metadata['datetime_cols'])}
132
- - **Data Quality (Non-missing values):** {metadata['data_quality']}%
133
- - **First 5 rows:**
134
- {metadata['head']}
135
-
136
- **Your Task:**
137
- Based on the metadata, generate a report in Markdown format. Use emojis to make it visually appealing. The report should have the following sections:
138
-
139
- # ๐Ÿš€ AI-Powered Dataset Overview
140
-
141
- ## ๐Ÿค” What is this dataset likely about?
142
- (Predict the domain and purpose of the dataset, e.g., "This appears to be customer transaction data for an e-commerce platform.")
143
-
144
- ## ๐Ÿ’ก Potential Key Questions to Explore
145
- - (Suggest 3-4 interesting business or research questions the data could answer.)
146
- - (Example: "Which products are most frequently purchased together?")
147
-
148
- ## ๐Ÿ“Š Potential Analyses & Visualizations
149
- - (List 3-4 types of analyses that would be valuable.)
150
- - (Example: "Time series analysis of sales to identify seasonality.")
151
-
152
- ## โš ๏ธ Data Quality & Potential Issues
153
- - (Briefly comment on the data quality score and mention if the presence of datetime columns is a good sign for certain analyses.)
154
- """
155
-
156
- try:
157
- genai.configure(api_key=api_key)
158
- model = genai.GenerativeModel('gemini-1.5-flash')
159
- response = model.generate_content(prompt)
160
- story = response.text
161
- except Exception as e:
162
- story = f"## โš ๏ธ AI Generation Failed\n**Error:** {str(e)}\n\nPlease check your API key and network connection. A fallback analysis is provided below.\n\n" \
163
- f"### Fallback Analysis\nThis dataset contains **{metadata['shape'][0]}** records and **{metadata['shape'][1]}** features. " \
164
- f"It includes **{len(metadata['numeric_cols'])}** numeric, **{len(metadata['categorical_cols'])}** categorical, " \
165
- f"and **{len(metadata['datetime_cols'])}** time-based columns. The overall data quality is **{metadata['data_quality']}%**, " \
166
- f"which is a good starting point for analysis."
167
-
168
- # Basic Info Summary
169
- basic_info = f"""
170
- ๐Ÿ“‹ **File:** `{state_dict.get('filename', 'N/A')}`
171
- ๐Ÿ“Š **Size:** {metadata['shape'][0]:,} rows ร— {metadata['shape'][1]} columns
172
- ๐Ÿ”ข **Features:**
173
- โ€ข **Numeric:** {len(metadata['numeric_cols'])}
174
- โ€ข **Categorical:** {len(metadata['categorical_cols'])}
175
- โ€ข **DateTime:** {len(metadata['datetime_cols'])}
176
- ๐ŸŽฏ **Data Quality:** {metadata['data_quality']}%
177
- """
178
 
179
- return story, basic_info, metadata['data_quality']
180
 
181
- # --- Tab 2: Univariate Analysis ---
182
-
183
- def generate_univariate_plot(column_name, state_dict):
184
- """Generates plots for a single selected variable."""
185
- if not column_name or not state_dict:
186
- return None, "Select a column to analyze."
187
 
188
- df = state_dict['df']
189
- metadata = state_dict['metadata']
 
190
 
191
- fig = None
192
- summary = ""
193
-
194
- if column_name in metadata['numeric_cols']:
195
- fig = make_subplots(rows=1, cols=2, subplot_titles=("Histogram", "Box Plot"))
196
- fig.add_trace(go.Histogram(x=df[column_name], name="Histogram"), row=1, col=1)
197
- fig.add_trace(go.Box(y=df[column_name], name="Box Plot"), row=1, col=2)
198
- fig.update_layout(title_text=f"Distribution of '{column_name}'", showlegend=False)
199
- summary = df[column_name].describe().to_frame().to_markdown()
200
 
201
- elif column_name in metadata['categorical_cols']:
202
- top_n = 20
203
- counts = df[column_name].value_counts()
204
- title = f"Top {min(top_n, len(counts))} Categories for '{column_name}'"
205
- fig = px.bar(counts.nlargest(top_n), title=title, labels={'index': column_name, 'value': 'Count'})
206
- fig.update_layout(showlegend=False)
207
- summary = counts.to_frame().to_markdown()
208
-
209
- elif column_name in metadata['datetime_cols']:
210
- counts = df[column_name].dt.to_period("M").value_counts().sort_index()
211
- fig = px.line(x=counts.index.to_timestamp(), y=counts.values, title=f"Records over Time for '{column_name}'")
212
- fig.update_layout(xaxis_title="Time", yaxis_title="Record Count")
213
- summary = df[column_name].describe(datetime_is_numeric=True).to_frame().to_markdown()
214
-
215
- return fig, summary
216
-
217
- # --- Tab 3: Bivariate Analysis ---
218
-
219
- def generate_bivariate_plot(x_col, y_col, state_dict):
220
- """Generates plots to explore the relationship between two variables."""
221
- if not x_col or not y_col or not state_dict:
222
- return None, "Select two columns to analyze."
223
- if x_col == y_col:
224
- return None, "Please select two different columns."
225
-
226
- df = state_dict['df']
227
- metadata = state_dict['metadata']
228
 
229
- x_type = 'numeric' if x_col in metadata['numeric_cols'] else 'categorical'
230
- y_type = 'numeric' if y_col in metadata['numeric_cols'] else 'categorical'
231
-
232
- fig = None
233
- title = f"{x_col} vs. {y_col}"
234
-
235
- if x_type == 'numeric' and y_type == 'numeric':
236
- fig = px.scatter(df, x=x_col, y=y_col, title=f"Scatter Plot: {title}", trendline="ols", trendline_color_override="red")
237
- summary = df[[x_col, y_col]].corr().to_markdown()
238
- elif x_type == 'numeric' and y_type == 'categorical':
239
- fig = px.box(df, x=x_col, y=y_col, title=f"Box Plot: {title}")
240
- summary = df.groupby(y_col)[x_col].describe().to_markdown()
241
- elif x_type == 'categorical' and y_type == 'numeric':
242
- fig = px.box(df, x=y_col, y=x_col, title=f"Box Plot: {title}")
243
- summary = df.groupby(x_col)[y_col].describe().to_markdown()
244
- else: # Both categorical
245
- crosstab = pd.crosstab(df[x_col], df[y_col])
246
- fig = px.imshow(crosstab, title=f"Heatmap of Counts: {title}", text_auto=True)
247
- summary = crosstab.to_markdown()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
- return fig, f"### Analysis Summary\n{summary}"
250
-
251
- # --- Tab 4: Time Series Analysis ---
252
-
253
- def generate_time_series_plot(time_col, value_col, resample_freq, state_dict):
254
- """Generates a time series plot with resampling."""
255
- if not time_col or not value_col or not state_dict:
256
- return None, "Select Time and Value columns."
257
-
258
- df = state_dict['df'].copy()
259
 
260
- try:
261
- df[time_col] = pd.to_datetime(df[time_col])
262
- df_resampled = df.set_index(time_col)[value_col].resample(resample_freq).mean().reset_index()
 
 
 
 
 
 
263
 
264
- fig = px.line(df_resampled, x=time_col, y=value_col,
265
- title=f"Time Series of {value_col} (Resampled to '{resample_freq}')")
266
- fig.update_layout(xaxis_title="Date", yaxis_title=f"Mean of {value_col}")
267
- return fig, f"Showing mean of '{value_col}' aggregated by '{resample_freq}'."
268
- except Exception as e:
269
- return None, f"Error: {e}"
270
-
271
- # --- Tab 5: AI Analyst Chat ---
272
-
273
- def respond_to_chat(user_message, history, state_dict, api_key):
274
- """Handles the chat interaction with the AI Analyst."""
 
 
 
 
 
 
 
 
 
 
 
 
275
  if not api_key:
276
- history.append((user_message, "I can't answer without a Gemini API key. Please enter it in the 'AI Overview' tab."))
277
- return history, None, ""
278
-
279
- if not state_dict:
280
- history.append((user_message, "Please upload a dataset before asking questions."))
281
- return history, None, ""
282
-
283
- history.append((user_message, None))
284
 
285
- df_metadata = state_dict['metadata']
286
 
287
- # Construct a robust prompt for the AI
288
  prompt = f"""
289
- You are an AI Data Analyst assistant. Your name is 'Gemini Analyst'.
290
- You are given a pandas DataFrame named `df`.
291
- Your goal is to answer the user's question about this DataFrame by writing and executing Python code.
292
-
 
 
293
  **Instructions:**
294
- 1. Analyze the user's question.
295
- 2. Write Python code to answer it.
296
- 3. You can use pandas, numpy, and plotly.express.
297
- 4. If you create a plot, you **MUST** assign it to a variable named `fig`. The plot will be displayed to the user.
298
- 5. If you are just calculating something or printing text, the `print()` output will be shown.
299
- 6. **DO NOT** write any code that modifies the DataFrame (e.g., `df.dropna(inplace=True)`). Use `df.copy()` if you need to modify data.
300
- 7. Respond **ONLY** with a JSON object containing two keys: "thought" and "code".
301
- - "thought": A short, one-sentence explanation of your plan.
302
- - "code": A string containing the Python code to execute.
303
-
304
- **DataFrame Metadata:**
305
- - **Filename:** {state_dict['filename']}
306
- - **Shape:** {df_metadata['shape'][0]} rows, {df_metadata['shape'][1]} columns
307
- - **Columns and Data Types:**
308
- {df_metadata['dtypes']}
309
-
310
- ---
311
- **User Question:** "{user_message}"
312
- ---
313
-
314
  **Your JSON Response:**
315
  """
316
 
@@ -319,168 +323,118 @@ def respond_to_chat(user_message, history, state_dict, api_key):
319
  model = genai.GenerativeModel('gemini-1.5-flash')
320
  response = model.generate_content(prompt)
321
 
322
- # Clean and parse the JSON response
323
- response_text = response.text.strip().replace("```json", "").replace("```", "")
324
- response_json = json.loads(response_text)
325
-
326
  thought = response_json.get("thought", "Thinking...")
327
- code_to_run = response_json.get("code", "")
328
 
329
- bot_message = f"๐Ÿง  **Thought:** {thought}\n\n"
 
 
330
 
331
- # Execute the code
332
- local_vars = {'df': state_dict['df'], 'px': px, 'pd': pd, 'np': np}
333
- stdout, fig_result, error = safe_exec(code_to_run, local_vars)
 
 
334
 
335
  if error:
336
- bot_message += f"๐Ÿ’ฅ **Error:**\n```\n{error}\n```"
337
- history[-1] = (user_message, bot_message)
338
- return history, None, ""
339
-
340
  if stdout:
341
- bot_message += f"๐Ÿ“‹ **Output:**\n```\n{stdout}\n```"
342
-
343
- if not fig_result and not stdout:
344
- bot_message += "โœ… Code executed successfully, but it produced no visible output."
 
 
345
 
346
- history[-1] = (user_message, bot_message)
347
-
348
- return history, fig_result, ""
349
-
350
  except Exception as e:
351
- error_msg = f"An unexpected error occurred: {e}. The AI might have returned an invalid response. Please try rephrasing your question."
352
- logging.error(f"Chatbot error: {error_msg}")
353
  history[-1] = (user_message, error_msg)
354
- return history, None, ""
355
-
356
- # --- Gradio Interface ---
357
-
358
- def create_gradio_interface():
359
- """Builds and returns the full Gradio application interface."""
360
- with gr.Blocks(title="๐Ÿš€ AI Data Explorer", theme=gr.themes.Soft()) as demo:
361
- # Global state to hold data
362
- global_state = gr.State({})
363
-
364
- # Header
365
- gr.Markdown("# ๐Ÿš€ AI Data Explorer: Your Advanced Analytic Tool")
366
- gr.Markdown("Upload a CSV, then explore your data with interactive tabs and a powerful AI Analyst.")
367
-
368
- # --- Top Row: File Upload and API Key ---
369
- with gr.Row():
370
- with gr.Column(scale=2):
371
- file_input = gr.File(label="๐Ÿ“ Upload CSV File", file_types=[".csv"])
372
- status_output = gr.Markdown("Status: Waiting for file...")
373
- with gr.Column(scale=1):
374
- api_key_input = gr.Textbox(
375
- label="๐Ÿ”‘ Gemini API Key",
376
- placeholder="Enter your key here...",
377
- type="password",
378
- info="Get your free key from Google AI Studio"
379
- )
380
-
381
- # --- Main Tabs ---
382
- with gr.Tabs() as tabs:
383
- # Tab 1: AI Overview
384
- with gr.Tab("๐Ÿค– AI Overview", id=0):
385
- overview_btn = gr.Button("๐Ÿง  Generate AI Overview", variant="primary")
386
- with gr.Row():
387
- story_output = gr.Markdown(label="๐Ÿ“– AI-Generated Story")
388
- with gr.Column():
389
- basic_info_output = gr.Markdown(label="๐Ÿ“‹ Basic Information")
390
- quality_score = gr.Number(label="๐ŸŽฏ Data Quality Score (%)", interactive=False)
391
-
392
- # Tab 2: Univariate Analysis
393
- with gr.Tab("๐Ÿ“Š Univariate Analysis", id=1):
394
- uni_col_select = gr.Dropdown(label="Select a Column to Analyze", interactive=False)
395
- with gr.Row():
396
- uni_plot_output = gr.Plot(label="Distribution Plot")
397
- uni_summary_output = gr.Markdown(label="Summary Statistics")
398
-
399
- # Tab 3: Bivariate Analysis
400
- with gr.Tab("๐Ÿ“ˆ Bivariate Analysis", id=2):
401
- with gr.Row():
402
- bi_x_select = gr.Dropdown(label="Select X-Axis Column", interactive=False)
403
- bi_y_select = gr.Dropdown(label="Select Y-Axis Column", interactive=False)
404
- bi_btn = gr.Button("๐ŸŽจ Generate Bivariate Plot", variant="secondary")
405
- with gr.Row():
406
- bi_plot_output = gr.Plot(label="Relationship Plot")
407
- bi_summary_output = gr.Markdown(label="Analysis Summary")
408
-
409
- # Tab 4: Time Series (conditionally visible)
410
- with gr.Tab("โณ Time Series Analysis", id=3, visible=False) as ts_tab:
411
- with gr.Row():
412
- ts_time_col = gr.Dropdown(label="Select Time Column", interactive=False)
413
- ts_value_col = gr.Dropdown(label="Select Value Column", interactive=False)
414
- ts_resample = gr.Radio(['D', 'W', 'M', 'Q', 'Y'], label="Resample Frequency", value='M')
415
- ts_btn = gr.Button("๐Ÿ“ˆ Plot Time Series", variant="secondary")
416
- ts_plot_output = gr.Plot(label="Time Series Plot")
417
- ts_status_output = gr.Markdown()
418
-
419
- # Tab 5: AI Analyst Chat (conditionally visible)
420
- with gr.Tab("๐Ÿ’ฌ AI Analyst Chat", id=4, visible=False) as chat_tab:
421
- chatbot = gr.Chatbot(label="Chat with Gemini Analyst", height=500)
422
- chat_plot_output = gr.Plot(label="AI Generated Plot")
423
- with gr.Row():
424
- chat_input = gr.Textbox(label="Your Question", placeholder="e.g., 'Show me the distribution of age'", scale=4)
425
- chat_submit_btn = gr.Button("Submit", variant="primary", scale=1)
426
- chat_clear_btn = gr.Button("Clear Chat")
427
-
428
- # --- Event Handlers ---
429
-
430
- # File upload triggers data loading and UI updates
431
- file_input.upload(
432
- fn=load_and_process_file,
433
- inputs=[file_input, global_state],
434
- outputs=[global_state, status_output, uni_col_select, bi_x_select, bi_y_select, ts_value_col, ts_time_col, ts_tab, chat_tab]
435
- )
436
-
437
- # Tab 1: Overview
438
- overview_btn.click(
439
- fn=analyze_dataset_overview,
440
- inputs=[global_state, api_key_input],
441
- outputs=[story_output, basic_info_output, quality_score]
442
- )
443
 
444
- # Tab 2: Univariate
445
- uni_col_select.change(
446
- fn=generate_univariate_plot,
447
- inputs=[uni_col_select, global_state],
448
- outputs=[uni_plot_output, uni_summary_output]
449
- )
450
-
451
- # Tab 3: Bivariate
452
- bi_btn.click(
453
- fn=generate_bivariate_plot,
454
- inputs=[bi_x_select, bi_y_select, global_state],
455
- outputs=[bi_plot_output, bi_summary_output]
456
- )
457
-
458
- # Tab 4: Time Series
459
- ts_btn.click(
460
- fn=generate_time_series_plot,
461
- inputs=[ts_time_col, ts_value_col, ts_resample, global_state],
462
- outputs=[ts_plot_output, ts_status_output]
463
- )
464
-
465
- # Tab 5: AI Chat
466
- chat_submit_btn.click(
467
- fn=respond_to_chat,
468
- inputs=[chat_input, chatbot, global_state, api_key_input],
469
- outputs=[chatbot, chat_plot_output, chat_input]
470
- )
471
- chat_input.submit(
472
- fn=respond_to_chat,
473
- inputs=[chat_input, chatbot, global_state, api_key_input],
474
- outputs=[chatbot, chat_plot_output, chat_input]
475
- )
476
- chat_clear_btn.click(lambda: ([], None, ""), None, [chatbot, chat_plot_output, chat_input])
477
-
478
-
479
- return demo
480
-
481
- # --- Main Application Execution ---
 
 
 
482
  if __name__ == "__main__":
483
- # For local development, you might use load_dotenv()
484
- # load_dotenv()
485
- app = create_gradio_interface()
486
- app.launch(debug=True)
 
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
  import io
8
+ import json
9
  import warnings
10
  import google.generativeai as genai
11
  import os
 
12
  import logging
 
13
  from contextlib import redirect_stdout
14
+ from sklearn.model_selection import train_test_split
15
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
16
+ from sklearn.linear_model import LogisticRegression, LinearRegression
17
+ from sklearn.metrics import accuracy_score, confusion_matrix, r2_score, mean_squared_error
18
+ from sklearn.preprocessing import LabelEncoder
19
 
20
  # --- Configuration ---
21
  warnings.filterwarnings('ignore')
22
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ THEME = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan").set(
24
+ body_background_fill="rgba(0,0,0,0.8)",
25
+ block_background_fill="rgba(0,0,0,0.6)",
26
+ block_border_width="1px",
27
+ border_color_primary="rgba(255,255,255,0.1)"
28
+ )
29
+ MODEL_REGISTRY = {
30
+ "Classification": {"Random Forest": RandomForestClassifier, "Logistic Regression": LogisticRegression},
31
+ "Regression": {"Random Forest": RandomForestRegressor, "Linear Regression": LinearRegression}
32
+ }
33
+
34
+ # --- Core Logic ---
35
 
36
  def safe_exec(code_string: str, local_vars: dict) -> tuple:
37
  """Safely execute a string of Python code and capture its output."""
 
39
  try:
40
  with redirect_stdout(output_buffer):
41
  exec(code_string, globals(), local_vars)
42
+ stdout = output_buffer.getvalue()
43
+ fig = local_vars.get('fig')
44
+ df_out = local_vars.get('df_result')
45
+ return stdout, fig, df_out, None
46
  except Exception as e:
47
+ return None, None, None, f"Execution Error: {str(e)}"
 
 
 
 
48
 
49
def _compute_proactive_insights(df, metadata):
    """Derive the proactive-insight dictionary for ``df``.

    Args:
        df: The loaded DataFrame.
        metadata: The dict produced by ``extract_dataset_metadata(df)``; only
            its ``'numeric_cols'`` and ``'categorical_cols'`` lists are read.

    Returns:
        A dict with the keys ``'missing'`` (Series of per-column null counts,
        only columns with at least one null, largest first),
        ``'high_cardinality'`` (column -> unique count, for categorical
        columns with more than 50 distinct values), ``'outliers'``
        (column -> count of values outside 1.5 * IQR), ``'ml_suggestions'``
        (list of display strings) and, only when there are at least two
        numeric columns, ``'high_correlations'`` (top five absolute pairwise
        correlations).
    """
    insights = {}
    numeric_cols = metadata['numeric_cols']
    categorical_cols = metadata['categorical_cols']

    # 1. Missing data
    null_counts = df.isnull().sum()
    insights['missing'] = null_counts[null_counts > 0].sort_values(ascending=False)

    # 2. High cardinality (unique counts are reused for the ML suggestions)
    cardinality = {col: df[col].nunique() for col in categorical_cols}
    insights['high_cardinality'] = {
        col: count for col, count in cardinality.items() if count > 50
    }

    # 3. High correlations
    if len(numeric_cols) > 1:
        corr = df[numeric_cols].corr().abs()
        # Keep only the strict upper triangle: this removes the diagonal
        # (self-correlation) and the mirrored duplicate of every pair without
        # discarding pairs that genuinely correlate at exactly 1.0.
        upper_triangle = np.triu(np.ones(corr.shape, dtype=bool), k=1)
        pairs = corr.where(upper_triangle).stack().dropna()
        insights['high_correlations'] = pairs.sort_values(ascending=False).head(5)

    # 4. Outlier detection (IQR method)
    outliers = {}
    for col in numeric_cols:
        q1, q3 = df[col].quantile(0.25), df[col].quantile(0.75)
        iqr = q3 - q1
        is_outlier = (df[col] < (q1 - 1.5 * iqr)) | (df[col] > (q3 + 1.5 * iqr))
        outlier_count = int(is_outlier.sum())
        if outlier_count > 0:
            outliers[col] = outlier_count
    insights['outliers'] = outliers

    # 5. ML target suggestions
    suggestions = [
        f"{col} (Binary Classification)"
        for col, count in cardinality.items()
        if count == 2
    ]
    suggestions.extend(
        f"{col} (Regression)"
        for col in numeric_cols
        if df[col].nunique() > 20  # Heuristic for a continuous target
    )
    insights['ml_suggestions'] = suggestions

    return insights


def prime_data(file_obj):
    """Load an uploaded CSV, analyze it, and prime the application state.

    Args:
        file_obj: The uploaded file object; its ``.name`` attribute is read as
            the path of the CSV. A falsy value means no file was provided.

    Returns:
        A dict mapping UI components to their new values / updates. On success
        it carries the new state, makes the tabs visible, fills the overview
        markdown and populates the Data Medic and Oracle dropdowns. Without a
        file it only hides the tabs. On failure it only writes the error
        message into the overview output.
    """
    if not file_obj:
        # A gr.update() dict cannot be used as a dictionary key; the update
        # has to be addressed to a component.
        return {phoenix_tabs: gr.update(visible=False)}

    try:
        df = pd.read_csv(file_obj.name)
        row_count = len(df)

        # Smart type conversion: try datetime first, then fall back to
        # 'category' for text columns with comparatively few distinct values.
        for col in df.select_dtypes(include=['object']).columns:
            try:
                df[col] = pd.to_datetime(df[col], errors='raise')
            except (ValueError, TypeError):
                # row_count is 0 for a header-only CSV; skip the ratio then.
                if row_count and df[col].nunique() / row_count < 0.5:
                    df[col] = df[col].astype('category')

        # --- Phoenix Eye: Proactive Insights Engine ---
        metadata = extract_dataset_metadata(df)
        insights = _compute_proactive_insights(df, metadata)

        state = {
            'df_original': df,
            'df_modified': df.copy(),
            'filename': os.path.basename(file_obj.name),
            'metadata': metadata,
            'proactive_insights': insights,
        }

        # Generate UI updates
        overview_md = generate_phoenix_eye_markdown(state)
        all_cols = metadata['columns']

        return {
            global_state: state,
            phoenix_tabs: gr.update(visible=True),
            phoenix_eye_output: overview_md,
            # Data Medic updates
            medic_col_select: gr.update(
                choices=insights['missing'].index.tolist(), interactive=True
            ),
            # Oracle updates
            oracle_target_select: gr.update(choices=all_cols, interactive=True),
            oracle_feature_select: gr.update(choices=all_cols, interactive=True),
        }

    except Exception as e:
        # UI boundary: report any load/analysis failure to the user instead of
        # letting the event handler crash.
        logging.error(f"Priming Error: {e}")
        return {phoenix_eye_output: gr.update(value=f"❌ **Error:** {e}")}
132
 
133
def extract_dataset_metadata(df):
    """Extract typed metadata from a DataFrame.

    Args:
        df: The pandas DataFrame to describe.

    Returns:
        A dict with:
            'shape': ``(rows, cols)`` tuple.
            'columns': all column names, in DataFrame order.
            'numeric_cols': names of numeric columns.
            'categorical_cols': names of ``object`` / ``category`` columns.
            'datetime_cols': names of datetime columns, both naive and
                timezone-aware.
            'dtypes': mapping of column name to dtype name (plain strings, so
                the mapping can be serialized with ``json.dumps``).
    """
    rows, cols = df.shape
    return {
        'shape': (rows, cols),
        'columns': df.columns.tolist(),
        'numeric_cols': df.select_dtypes(include=np.number).columns.tolist(),
        'categorical_cols': df.select_dtypes(include=['object', 'category']).columns.tolist(),
        # 'datetime64' only matches naive datetimes; timezone-aware columns
        # need the separate 'datetimetz' selector or they fall into no
        # category at all.
        'datetime_cols': df.select_dtypes(include=['datetime64', 'datetimetz']).columns.tolist(),
        'dtypes': {col: dtype.name for col, dtype in df.dtypes.items()},
    }
144
 
145
def generate_phoenix_eye_markdown(state):
    """Create the markdown for the proactive insights dashboard.

    Args:
        state: The application state dict. Reads ``state['filename']``,
            ``state['metadata']['shape']`` and ``state['proactive_insights']``
            (keys ``'ml_suggestions'``, ``'missing'``, ``'outliers'``,
            ``'high_cardinality'`` and, optionally, ``'high_correlations'``).

    Returns:
        A markdown string with one section per insight category.
    """
    insights = state['proactive_insights']
    rows, cols = state['metadata']['shape']

    # Collect fragments and join once instead of repeatedly extending a string.
    parts = [
        f"## 🦅 Phoenix Eye: Proactive Insights for `{state['filename']}`\n",
        f"Dataset has **{rows} rows** and **{cols} columns**.\n\n",
    ]

    # ML Suggestions
    parts.append("### 🔮 Potential ML Targets\n")
    if insights['ml_suggestions']:
        parts.extend(f"- `{s}`\n" for s in insights['ml_suggestions'])
    else:
        parts.append("No obvious ML target columns found.\n")
    parts.append("\n")

    # Missing Data
    parts.append("### 💧 Missing Data\n")
    if not insights['missing'].empty:
        parts.append("Found missing values in these columns. Use the **Data Medic** tab to fix.\n")
        parts.append(insights['missing'].to_frame('Missing Count').to_markdown() + "\n")
    else:
        parts.append("✅ No missing data found!\n")
    parts.append("\n")

    # High Correlation (the key is absent when there are fewer than two
    # numeric columns)
    parts.append("### 🔗 Top Correlations\n")
    top_correlations = insights.get('high_correlations')
    if top_correlations is not None and not top_correlations.empty:
        parts.append(top_correlations.to_frame('Correlation').to_markdown() + "\n")
    else:
        parts.append("No strong correlations found between numeric features.\n")
    parts.append("\n")

    # Outliers
    parts.append("### 📈 Outlier Alert\n")
    if insights['outliers']:
        parts.extend(
            f"- `{col}` has **{count}** potential outliers.\n"
            for col, count in insights['outliers'].items()
        )
    else:
        parts.append("✅ No significant outliers detected.\n")
    parts.append("\n")

    # High Cardinality
    parts.append("### ⚠️ High Cardinality Warning\n")
    if insights['high_cardinality']:
        parts.extend(
            f"- `{col}` has **{count}** unique values, which may be problematic for some models.\n"
            for col, count in insights['high_cardinality'].items()
        )
    else:
        parts.append("✅ No high-cardinality categorical columns found.\n")
    parts.append("\n")

    return "".join(parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ # --- Tab Handlers ---
190
 
191
def _imputation_value(series, method):
    """Return the fill value for ``series`` under ``method``.

    ``'mean'`` and ``'median'`` use the corresponding statistic; any other
    method falls back to the most frequent value (mode).

    Returns:
        The fill value, or ``None`` when it cannot be computed (for example
        mean of a text column, or any method on a column with no non-null
        values).
    """
    if method in ('mean', 'median'):
        try:
            value = series.mean() if method == 'mean' else series.median()
        except (TypeError, ValueError):
            # Text / categorical columns have no mean or median.
            return None
        return None if pd.isna(value) else value

    modes = series.mode()
    return None if modes.empty else modes.iloc[0]


def medic_preview_imputation(state, col, method):
    """Show a before-and-after plot for data imputation.

    Args:
        state: Application state dict; reads ``state['df_original']``.
        col: Name of the column to impute. Falsy means nothing is selected.
        method: ``'mean'``, ``'median'``, or anything else for mode.

    Returns:
        A plotly Figure overlaying the column's histogram before and after
        filling its missing values, or ``None`` when no column is selected,
        no data is loaded, or the fill value cannot be computed.
    """
    if not col or not state:
        return None

    before = state['df_original'][col]
    fill_value = _imputation_value(before, method)
    if fill_value is None:
        return None
    after = before.fillna(fill_value)

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=before, name='Before', opacity=0.7))
    fig.add_trace(go.Histogram(x=after, name='After', opacity=0.7))
    fig.update_layout(
        barmode='overlay',
        title=f"'{col}' Distribution: Before vs. After Imputation",
        legend_title_text='Dataset',
    )
    return fig


def medic_apply_imputation(state, col, method):
    """Apply imputation to the working DataFrame and update the state.

    Args:
        state: Application state dict; reads and replaces
            ``state['df_modified']`` and refreshes
            ``state['proactive_insights']['missing']``.
        col: Name of the column to impute. Falsy means nothing is selected.
        method: ``'mean'``, ``'median'``, or anything else for mode.

    Returns:
        A 3-tuple ``(state, status_message, dropdown_update)``. Every path
        returns all three values so the handler always matches its outputs;
        when nothing was changed the dropdown update is a no-op.
    """
    if not col or not state:
        return state, "No column selected.", gr.update()

    df_mod = state['df_modified'].copy()
    fill_value = _imputation_value(df_mod[col], method)
    if fill_value is None:
        return state, f"❌ Cannot apply '{method}' imputation to '{col}'.", gr.update()

    df_mod[col] = df_mod[col].fillna(fill_value)
    state['df_modified'] = df_mod

    # Re-run the missing-data insight on the modified df
    null_counts = df_mod.isnull().sum()
    still_missing = null_counts[null_counts > 0].sort_values(ascending=False)
    state['proactive_insights']['missing'] = still_missing

    return (
        state,
        f"✅ Applied '{method}' imputation to '{col}'.",
        # The imputed column drops out of the choices, so clear the selection.
        gr.update(choices=still_missing.index.tolist(), value=None),
    )
226
+
227
def download_cleaned_data(state):
    """Save the modified DataFrame to a CSV file and expose it for download.

    Args:
        state: Application state dict; reads ``state['df_modified']``. A falsy
            state means no dataset is loaded.

    Returns:
        A component update whose value is the path of the written CSV (and
        which makes the file component visible), or an update that hides the
        component when there is nothing to download.
    """
    import tempfile

    if not state:
        return gr.update(visible=False)

    # The file component needs a path, not the CSV text itself, so the data is
    # written to a named temporary file that outlives this call.
    with tempfile.NamedTemporaryFile(
        mode='w',
        prefix='cleaned_',
        suffix='.csv',
        delete=False,
        newline='',
        encoding='utf-8',
    ) as handle:
        state['df_modified'].to_csv(handle, index=False)

    return gr.update(value=handle.name, visible=True)
234
+
235
def oracle_run_model(state, target, features, model_name):
    """Train a simple ML model and return plots and a metrics summary.

    Args:
        state: Application state dict; reads ``state['df_modified']``.
        target: Name of the target column.
        features: List of feature column names.
        model_name: Key into ``MODEL_REGISTRY[problem_type]``.

    Returns:
        A 3-tuple ``(feature_importance_fig, result_fig, report_markdown)``.
        The first element is ``None`` for models without
        ``feature_importances_``. On invalid input or a training failure both
        figures are ``None`` and the third element is the message to show.
    """
    import inspect

    if not target or not features:
        return None, None, "Please select a target and at least one feature."
    if not state:
        return None, None, "Please upload a dataset first."

    # A target that is also listed as a feature would leak the answer.
    features = [name for name in features if name != target]
    if not features:
        return None, None, "Please select a target and at least one feature."
    used_cols = features + [target]

    # Preprocessing
    df = state['df_modified'].dropna(subset=used_cols).copy()
    if df.empty:
        return None, None, "Not enough data after dropping NA values."

    target_was_encoded = False
    for col in used_cols:
        if df[col].dtype == 'object' or df[col].dtype.name == 'category':
            # A fresh encoder per column; astype(str) keeps mixed-type object
            # columns sortable for the encoder.
            df[col] = LabelEncoder().fit_transform(df[col].astype(str))
            target_was_encoded = target_was_encoded or col == target

    X = df[features]
    y = df[target]

    # A label-encoded target is categorical no matter how many classes it
    # has; numeric targets fall back to the unique-count heuristic.
    is_classification = target_was_encoded or y.nunique() <= 10
    problem_type = "Classification" if is_classification else "Regression"

    if model_name not in MODEL_REGISTRY[problem_type]:
        return None, None, f"Model {model_name} not suitable for {problem_type}."

    model_cls = MODEL_REGISTRY[problem_type][model_name]
    # Not every estimator takes a seed (LinearRegression does not).
    accepts_seed = 'random_state' in inspect.signature(model_cls).parameters
    model = model_cls(random_state=42) if accepts_seed else model_cls()

    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=42
        )
        model.fit(X_train, y_train)
        preds = model.predict(X_test)
    except (ValueError, TypeError) as exc:
        # E.g. too few rows to split, or a column type the estimator cannot use.
        logging.error(f"Model training failed: {exc}")
        return None, None, f"Model training failed: {exc}"

    # Feature importance is only available for some estimators.
    fi_fig = None
    if hasattr(model, 'feature_importances_'):
        fi = pd.Series(model.feature_importances_, index=features).sort_values(ascending=False)
        fi_fig = px.bar(fi, title="Feature Importance")

    # Results
    if problem_type == "Classification":
        acc = accuracy_score(y_test, preds)
        cm = confusion_matrix(y_test, preds)
        cm_fig = px.imshow(cm, text_auto=True, title=f"Confusion Matrix (Accuracy: {acc:.2f})")
        return fi_fig, cm_fig, f"**Classification Report:**\n- Accuracy: {acc:.2f}"

    # Regression
    r2 = r2_score(y_test, preds)
    rmse = np.sqrt(mean_squared_error(y_test, preds))
    preds_fig = px.scatter(
        x=y_test,
        y=preds,
        labels={'x': 'Actual Values', 'y': 'Predicted Values'},
        title=f"Predictions vs. Actuals (R²: {r2:.2f})",
        trendline='ols',
    )
    return fi_fig, preds_fig, f"**Regression Report:**\n- R² Score: {r2:.2f}\n- RMSE: {rmse:.2f}"
290
+
291
def _parse_copilot_reply(raw_text):
    """Extract the JSON object from a model reply that may be wrapped in a Markdown fence."""
    start = raw_text.find('{')
    end = raw_text.rfind('}')
    if start == -1 or end < start:
        raise ValueError("No JSON object found in the model response.")
    # Slicing between the outermost braces drops any ```json fence without
    # touching backticks that legitimately appear inside the generated code.
    return json.loads(raw_text[start:end + 1])


def copilot_respond(user_message, history, state, api_key):
    """Handle one AI Co-pilot chat turn, streaming UI updates as they are ready.

    This function is a generator: every UI update, including the early exits,
    must be *yielded*, because a plain ``return value`` inside a generator is
    never delivered to the caller's iteration.

    Args:
        user_message: The question typed by the user.
        history: Chat history as a list of ``(user, bot)`` tuples, or ``None``.
        state: Shared app state dict; reads ``df_modified`` and, when present,
            ``metadata['dtypes']``.
        api_key: Gemini API key entered in the UI.

    Yields:
        ``(history, figure, dataframe, code)`` tuples for the chatbot, plot,
        dataframe and code components respectively.
    """
    # Work on a copy so the list owned by the Chatbot component is not mutated.
    history = list(history or [])

    if not api_key:
        yield history + [(user_message, "I need a Gemini API key to function.")], None, None, ""
        return
    if not state or 'df_modified' not in state:
        yield history + [(user_message, "Please upload a dataset first.")], None, None, ""
        return

    history.append((user_message, None))

    df = state['df_modified']
    dtypes = (state.get('metadata') or {}).get('dtypes')
    if dtypes is None:
        dtypes = {str(col): str(dtype) for col, dtype in df.dtypes.items()}

    prompt = f"""
    You are 'Phoenix Co-pilot', a world-class AI data analyst. Your goal is to help the user by writing and executing Python code.
    You have access to a pandas DataFrame named `df`. This is the user's LATEST data, including any cleaning they've performed.

    **DataFrame Info:**
    - Columns and dtypes: {json.dumps(dtypes, default=str)}

    **Instructions:**
    1. Analyze the user's request: '{user_message}'.
    2. Formulate a plan (thought).
    3. Write Python code to execute the plan.
    4. Use `pandas`, `numpy`, and `plotly.express as px`.
    5. To show a plot, assign it to a variable `fig`. Ex: `fig = px.histogram(df, x='age')`.
    6. To show a dataframe, assign it to a variable `df_result`. Ex: `df_result = df.describe()`.
    7. Use `print()` for text output.
    8. **NEVER** modify `df` in place. Use `df.copy()` if needed.
    9. Respond **ONLY** with a single, valid JSON object with keys "thought" and "code".

    **User Request:** "{user_message}"

    **Your JSON Response:**
    """

    try:
        # NOTE(review): the key is applied through genai.configure here;
        # confirm this matches how the client was configured originally.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)

        # Clean and parse JSON
        response_json = _parse_copilot_reply(response.text)
        thought = response_json.get("thought", "Thinking...")
        code_to_run = response_json.get("code", "print('No code generated.')")

        bot_thinking = f"🧠 **Thinking:** *{thought}*"
        history[-1] = (user_message, bot_thinking)
        yield history, None, None, gr.update(value=code_to_run)

        # Execute Code
        # SECURITY: this runs model-generated Python through exec() with no
        # sandbox. The frame is copied so generated code cannot corrupt the
        # app's cleaned data, but it can still do anything the process can.
        local_vars = {'df': df.copy(), 'px': px, 'pd': pd, 'np': np}
        # NOTE(review): assumes safe_exec returns (stdout, fig, df_result, error).
        stdout, fig_result, df_result, error = safe_exec(code_to_run, local_vars)

        bot_response = bot_thinking + "\n\n---\n\n"

        if error:
            bot_response += f"💥 **Execution Error:**\n```\n{error}\n```"
        if stdout:
            bot_response += f"📋 **Output:**\n```\n{stdout}\n```"
        produced_nothing = (
            not error
            and not stdout
            and fig_result is None
            and not isinstance(df_result, pd.DataFrame)
        )
        if produced_nothing:
            bot_response += "✅ Code executed, but produced no direct output."

        history[-1] = (user_message, bot_response)
        yield history, fig_result, df_result, gr.update(value=code_to_run)

    except Exception as e:
        # Top-level boundary for the chat turn: surface the failure in the
        # chat window instead of letting the event handler crash.
        logging.exception("AI Co-pilot request failed")
        error_msg = f"A critical error occurred: {e}. The AI may have returned invalid JSON. Check the generated code."
        history[-1] = (user_message, error_msg)
        yield history, None, None, ""
355
+
356
+ # --- Gradio UI Construction ---
357
+
358
# Builds the whole Gradio interface. The analysis tabs stay hidden until a CSV
# has been uploaded and processed by `prime_data`.
with gr.Blocks(theme=THEME, title="Phoenix AI Data Explorer") as demo:
    # Per-session state dict shared by every event handler.
    global_state = gr.State({})

    gr.Markdown("# 🔥 Phoenix AI Data Explorer")
    gr.Markdown("The next-generation analytic tool. Upload your data to awaken the Phoenix.")

    with gr.Row():
        file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        api_key_input = gr.Textbox(label="🔑 Gemini API Key", type="password", placeholder="Enter Google AI Studio key...")

    with gr.Tabs(visible=False) as phoenix_tabs:
        with gr.Tab("🦅 Phoenix Eye"):
            phoenix_eye_output = gr.Markdown()

        with gr.Tab("🩺 Data Medic"):
            gr.Markdown("### Cleanse Your Data\nSelect a column with missing values and choose a method to fill them.")
            with gr.Row():
                medic_col_select = gr.Dropdown(label="Select Column to Clean")
                medic_method_select = gr.Radio(['mean', 'median', 'mode'], label="Imputation Method", value='mean')
            medic_preview_btn = gr.Button("📊 Preview Changes")
            medic_plot = gr.Plot()
            with gr.Row():
                medic_apply_btn = gr.Button("✅ Apply & Save Changes", variant="primary")
                medic_status = gr.Textbox(label="Status", interactive=False)
            with gr.Accordion("Download Cleaned Data", open=False):
                download_btn = gr.Button("⬇️ Download Cleaned CSV")
                download_file_output = gr.File(label="Download Link", visible=False)

        with gr.Tab("🔮 The Oracle (Predictive Modeling)"):
            gr.Markdown("### Glimpse the Future\nTrain a simple model to see the predictive power of your data.")
            with gr.Row():
                oracle_target_select = gr.Dropdown(label="🎯 Select Target Variable")
                # Gradio has no `Multiselect` component; multi-selection is a
                # mode of Dropdown.
                oracle_feature_select = gr.Dropdown(label="✨ Select Features", multiselect=True)
                oracle_model_select = gr.Dropdown(choices=["Random Forest", "Logistic Regression", "Linear Regression"], label="🧠 Select Model")
            oracle_run_btn = gr.Button("🚀 Train Model!", variant="primary")
            oracle_status = gr.Markdown()
            with gr.Row():
                oracle_fig1 = gr.Plot()
                oracle_fig2 = gr.Plot()

        with gr.Tab("🤖 AI Co-pilot"):
            gr.Markdown("### Your Conversational Analyst\nAsk any question about your data in plain English.")
            copilot_chatbot = gr.Chatbot(label="Chat History", height=400)
            with gr.Accordion("AI Generated Results", open=True):
                copilot_fig_output = gr.Plot()
                copilot_df_output = gr.Dataframe(interactive=False)
            with gr.Accordion("Generated Code", open=False):
                copilot_code_output = gr.Code(language="python", interactive=False)

            with gr.Row():
                copilot_input = gr.Textbox(label="Your Question", placeholder="e.g., 'What's the correlation between age and salary?'", scale=4)
                copilot_submit_btn = gr.Button("Submit", variant="primary", scale=1)

    # --- Event Wiring ---
    file_input.upload(
        fn=prime_data,
        inputs=file_input,
        outputs=[global_state, phoenix_tabs, phoenix_eye_output, medic_col_select, oracle_target_select, oracle_feature_select],
        show_progress="full"
    )

    # Data Medic
    medic_preview_btn.click(medic_preview_imputation, [global_state, medic_col_select, medic_method_select], medic_plot)
    medic_apply_btn.click(medic_apply_imputation, [global_state, medic_col_select, medic_method_select], [global_state, medic_status, medic_col_select])
    download_btn.click(download_cleaned_data, [global_state], download_file_output)

    # Oracle
    oracle_run_btn.click(
        oracle_run_model,
        [global_state, oracle_target_select, oracle_feature_select, oracle_model_select],
        [oracle_fig1, oracle_fig2, oracle_status],
        show_progress="full"
    )

    # AI Co-pilot
    copilot_submit_btn.click(
        copilot_respond,
        [copilot_input, copilot_chatbot, global_state, api_key_input],
        [copilot_chatbot, copilot_fig_output, copilot_df_output, copilot_code_output]
    ).then(
        # Clear input after submit. The callback takes no arguments, so it
        # must be wired with no inputs.
        lambda: "", None, copilot_input
    )

if __name__ == "__main__":
    demo.launch(debug=True)