mgbam committed
Commit 57776e0 · verified · 1 Parent(s): a9bdee1

Update app.py

Files changed (1)
  1. app.py +118 -315
app.py CHANGED
@@ -2,48 +2,41 @@
2
 
3
  import streamlit as st
4
  import numpy as np
5
- import pandas as pd
6
  from smolagents import CodeAgent, tool
7
  from typing import Union, List, Dict, Optional
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
10
- import plotly.express as px
11
- import plotly.graph_objects as go
12
  import os
13
  from groq import Groq
14
  from dataclasses import dataclass
15
  import tempfile
16
  import base64
17
- import io
18
- from sklearn.model_selection import train_test_split
19
- from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
20
- import joblib
21
- import pdfkit # Ensure wkhtmltopdf is available in the environment
22
- import uuid # For generating unique report IDs
23
 
24
  # ------------------------------
25
  # Language Model Interface
26
  # ------------------------------
27
  class GroqLLM:
28
- """Enhanced LLM interface with support for generating natural language summaries."""
29
-
30
  def __init__(self, model_name: str = "llama-3.1-8b-instant"):
31
  """
32
- Initialize the GroqLLM with a specified model.
33
-
34
  Args:
35
  model_name (str): The name of the language model to use.
36
  """
37
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
38
  self.model_name = model_name
39
-
40
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
41
  """
42
  Make the class callable as required by smolagents.
43
-
44
  Args:
45
  prompt (Union[str, dict, List[Dict]]): The input prompt for the language model.
46
-
47
  Returns:
48
  str: The generated response from the language model.
49
  """
@@ -62,7 +55,7 @@ class GroqLLM:
62
  "content": prompt_str
63
  }],
64
  temperature=0.7,
65
- max_tokens=1500, # Increased tokens for detailed responses
66
  stream=False
67
  )
68
 
@@ -77,12 +70,12 @@ class GroqLLM:
77
  # Data Analysis Agent
78
  # ------------------------------
79
  class DataAnalysisAgent(CodeAgent):
80
- """Extended CodeAgent with dataset awareness and predictive analytics capabilities."""
81
-
82
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
83
  """
84
  Initialize the DataAnalysisAgent with the provided dataset.
85
-
86
  Args:
87
  dataset (pd.DataFrame): The dataset to analyze.
88
  *args: Variable length argument list.
@@ -90,24 +83,19 @@ class DataAnalysisAgent(CodeAgent):
90
  """
91
  super().__init__(*args, **kwargs)
92
  self._dataset = dataset
93
- self.models = {} # To store trained models
94
-
95
  @property
96
  def dataset(self) -> pd.DataFrame:
97
- """Access the stored dataset.
98
-
99
- Returns:
100
- pd.DataFrame: The dataset stored in the agent.
101
- """
102
  return self._dataset
103
-
104
  def run(self, prompt: str) -> str:
105
  """
106
- Override the run method to include dataset context and support predictive tasks.
107
-
108
  Args:
109
  prompt (str): The task prompt for analysis.
110
-
111
  Returns:
112
  str: The result of the analysis.
113
  """
@@ -133,23 +121,22 @@ class DataAnalysisAgent(CodeAgent):
133
  @tool
134
  def analyze_basic_stats(data: Optional[pd.DataFrame] = None) -> str:
135
  """
136
- Calculate and visualize basic statistical measures for numerical columns.
137
-
138
  This function computes fundamental statistical metrics including mean, median,
139
  standard deviation, skewness, and counts of missing values for all numerical
140
- columns in the provided DataFrame. It also generates a bar chart visualizing
141
- the mean, median, and standard deviation for each numerical feature.
142
-
143
  Args:
144
  data (Optional[pd.DataFrame], optional):
145
- A pandas DataFrame containing the dataset to analyze.
146
- If None, the agent's stored dataset will be used.
147
- The DataFrame should contain at least one numerical column
148
- for meaningful analysis.
149
-
150
  Returns:
151
- str: A markdown-formatted string containing the statistics and the generated plot.
 
152
  """
 
153
  if data is None:
154
  data = tool.agent.dataset
155
 
@@ -165,280 +152,110 @@ def analyze_basic_stats(data: Optional[pd.DataFrame] = None) -> str:
165
  'missing': int(data[col].isnull().sum())
166
  }
167
 
168
- # Generate a summary DataFrame
169
- stats_df = pd.DataFrame(stats).T
170
- stats_df.reset_index(inplace=True)
171
- stats_df.rename(columns={'index': 'Feature'}, inplace=True)
172
-
173
- # Plotting basic statistics
174
- fig, ax = plt.subplots(figsize=(10, 6))
175
- stats_df.set_index('Feature')[['mean', 'median', 'std']].plot(kind='bar', ax=ax)
176
- plt.title('Basic Statistics')
177
- plt.ylabel('Values')
178
- plt.tight_layout()
179
-
180
- # Save plot to buffer
181
- buf = io.BytesIO()
182
- plt.savefig(buf, format='png')
183
- plt.close()
184
- stats_plot = base64.b64encode(buf.getvalue()).decode()
185
-
186
- return f"### Basic Statistics\n{stats_df.to_markdown()} \n\n![Basic Statistics](data:image/png;base64,{stats_plot})"
187
 
188
  @tool
189
  def generate_correlation_matrix(data: Optional[pd.DataFrame] = None) -> str:
190
  """
191
- Generate an interactive correlation matrix using Plotly.
192
-
193
- This function creates an interactive heatmap visualization showing the correlations between
194
- all numerical columns in the dataset. Users can hover over cells to see correlation values
195
- and interact with the plot (zoom, pan).
196
-
197
  Args:
198
  data (Optional[pd.DataFrame], optional):
199
- A pandas DataFrame containing the dataset to analyze.
200
- If None, the agent's stored dataset will be used.
201
- The DataFrame should contain at least two numerical columns
202
- for correlation analysis.
203
-
204
  Returns:
205
- str: An HTML string representing the interactive correlation matrix plot.
 
206
  """
 
207
  if data is None:
208
  data = tool.agent.dataset
209
-
210
  numeric_data = data.select_dtypes(include=[np.number])
211
- corr = numeric_data.corr()
212
-
213
- fig = px.imshow(corr,
214
- text_auto=True,
215
- aspect="auto",
216
- color_continuous_scale='RdBu',
217
- title='Correlation Matrix')
218
 
219
- fig.update_layout(width=800, height=600)
 
 
220
 
221
- # Convert Plotly figure to HTML div
222
- correlation_html = fig.to_html(full_html=False)
223
-
224
- return correlation_html
225
 
226
  @tool
227
  def analyze_categorical_columns(data: Optional[pd.DataFrame] = None) -> str:
228
  """
229
- Analyze categorical columns with visualizations.
230
-
231
  This function examines categorical columns to identify unique values, top categories,
232
- and missing value counts. It also generates bar charts for the top 5 categories in each
233
- categorical feature.
234
-
235
  Args:
236
  data (Optional[pd.DataFrame], optional):
237
- A pandas DataFrame containing the dataset to analyze.
238
- If None, the agent's stored dataset will be used.
239
- The DataFrame should contain at least one categorical column
240
- for meaningful analysis.
241
-
242
  Returns:
243
- str: A markdown-formatted string containing analysis results and embedded plots.
 
244
  """
 
245
  if data is None:
246
  data = tool.agent.dataset
247
-
248
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
249
  analysis = {}
250
- plots = ""
251
 
252
  for col in categorical_cols:
253
- unique_vals = data[col].nunique()
254
- top_categories = data[col].value_counts().head(5).to_dict()
255
- missing = data[col].isnull().sum()
256
-
257
  analysis[col] = {
258
- 'unique_values': int(unique_vals),
259
- 'top_categories': top_categories,
260
- 'missing': int(missing)
261
  }
262
-
263
- # Generate bar chart for top categories
264
- fig, ax = plt.subplots(figsize=(8, 4))
265
- sns.countplot(data=data, x=col, order=data[col].value_counts().iloc[:5].index, ax=ax)
266
- plt.title(f'Top 5 Categories in {col}')
267
- plt.xticks(rotation=45)
268
- plt.tight_layout()
269
-
270
- buf = io.BytesIO()
271
- plt.savefig(buf, format='png')
272
- plt.close()
273
- plot_img = base64.b64encode(buf.getvalue()).decode()
274
-
275
- plots += f"### {col}\n"
276
- plots += f"- **Unique Values:** {unique_vals}\n"
277
- plots += f"- **Missing Values:** {missing}\n"
278
- plots += f"- **Top Categories:** {top_categories}\n"
279
- plots += f"![Top Categories in {col}](data:image/png;base64,{plot_img})\n\n"
280
 
281
- return plots + f"### Categorical Columns Analysis\n{pd.DataFrame(analysis).T.to_markdown()}"
282
 
283
  @tool
284
  def suggest_features(data: Optional[pd.DataFrame] = None) -> str:
285
  """
286
  Suggest potential feature engineering steps based on data characteristics.
287
-
288
  This function analyzes the dataset's structure and statistical properties to
289
  recommend possible feature engineering steps that could improve model performance.
290
-
291
  Args:
292
  data (Optional[pd.DataFrame], optional):
293
- A pandas DataFrame containing the dataset to analyze.
294
- If None, the agent's stored dataset will be used.
295
- The DataFrame can contain both numerical and categorical columns.
296
-
297
  Returns:
298
  str: A string containing suggestions for feature engineering based on
299
  the characteristics of the input data.
300
  """
 
301
  if data is None:
302
  data = tool.agent.dataset
303
-
304
  suggestions = []
305
  numeric_cols = data.select_dtypes(include=[np.number]).columns
306
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
307
 
308
- # Interaction terms
309
  if len(numeric_cols) >= 2:
310
- suggestions.append("• **Interaction Terms:** Consider creating interaction terms between numerical features to capture combined effects.")
311
 
312
- # Encoding categorical variables
313
  if len(categorical_cols) > 0:
314
- suggestions.append(" **One-Hot Encoding:** Apply one-hot encoding to categorical variables to convert them into numerical format.")
315
- suggestions.append("• **Label Encoding:** For ordinal categorical variables, consider label encoding to maintain order information.")
316
-
317
- # Handling skewness
318
  for col in numeric_cols:
319
  if data[col].skew() > 1 or data[col].skew() < -1:
320
- suggestions.append(f" **Log Transformation:** Apply log transformation to `{col}` to reduce skewness and stabilize variance.")
321
-
322
- # Missing value imputation
323
- for col in data.columns:
324
- if data[col].isnull().sum() > 0:
325
- suggestions.append(f"• **Imputation:** Consider imputing missing values in `{col}` using mean, median, or advanced imputation techniques.")
326
-
327
- # Feature scaling
328
- suggestions.append("• **Feature Scaling:** Apply feature scaling (Standardization or Normalization) to numerical features to ensure uniformity.")
329
-
330
- return "\n".join(suggestions)
331
-
332
- @tool
333
- def predictive_analysis(data: Optional[pd.DataFrame] = None, target: Optional[str] = None) -> str:
334
- """
335
- Perform predictive analytics by training a classification model.
336
-
337
- This function builds a classification model using Random Forest, evaluates its performance,
338
- and provides detailed metrics and visualizations such as the confusion matrix and ROC curve.
339
-
340
- Args:
341
- data (Optional[pd.DataFrame], optional):
342
- A pandas DataFrame containing the dataset to analyze.
343
- If None, the agent's stored dataset will be used.
344
- The DataFrame should contain the target variable for prediction.
345
- target (Optional[str], optional):
346
- The name of the target variable column in the dataset.
347
- If None, the agent must provide the target variable through the prompt.
348
-
349
- Returns:
350
- str: A markdown-formatted string containing the classification report, confusion matrix,
351
- ROC curve, AUC score, and a unique Model ID.
352
- """
353
- if data is None:
354
- data = tool.agent.dataset
355
-
356
- if target is None or target not in data.columns:
357
- return f"Error: Target column not specified or `{target}` not found in the dataset."
358
-
359
- # Handle categorical target
360
- if data[target].dtype == 'object' or data[target].dtype.name == 'category':
361
- data[target] = data[target].astype('category').cat.codes
362
-
363
- # Drop rows with missing target
364
- data = data.dropna(subset=[target])
365
-
366
- # Separate features and target
367
- X = data.drop(columns=[target])
368
- y = data[target]
369
-
370
- # Handle missing values (simple imputation)
371
- X = X.fillna(X.median())
372
-
373
- # Encode categorical variables
374
- X = pd.get_dummies(X, drop_first=True)
375
 
376
- # Split data
377
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
378
-
379
- # Train a Random Forest Classifier (as an example)
380
- from sklearn.ensemble import RandomForestClassifier
381
- clf = RandomForestClassifier(n_estimators=100, random_state=42)
382
- clf.fit(X_train, y_train)
383
-
384
- # Predictions
385
- y_pred = clf.predict(X_test)
386
- y_proba = clf.predict_proba(X_test)[:,1]
387
-
388
- # Evaluation
389
- report = classification_report(y_test, y_pred, output_dict=True)
390
- report_df = pd.DataFrame(report).transpose()
391
-
392
- # Confusion Matrix
393
- cm = confusion_matrix(y_test, y_pred)
394
- fig_cm = px.imshow(cm, text_auto=True, labels=dict(x="Predicted", y="Actual", color="Count"),
395
- x=["Negative", "Positive"], y=["Negative", "Positive"],
396
- title="Confusion Matrix")
397
-
398
- # ROC Curve
399
- fpr, tpr, thresholds = roc_curve(y_test, y_proba)
400
- roc_auc = auc(fpr, tpr)
401
- fig_roc = go.Figure()
402
- fig_roc.add_trace(go.Scatter(x=fpr, y=tpr, mode='lines', name=f'ROC Curve (AUC = {roc_auc:.2f})'))
403
- fig_roc.add_trace(go.Scatter(x=[0,1], y=[0,1], mode='lines', name='Random Guess', line=dict(dash='dash')))
404
- fig_roc.update_layout(title='Receiver Operating Characteristic (ROC) Curve',
405
- xaxis_title='False Positive Rate',
406
- yaxis_title='True Positive Rate')
407
-
408
- # Save models for potential future use
409
- model_id = str(uuid.uuid4())
410
- with tempfile.NamedTemporaryFile(delete=False, suffix='.joblib') as tmp_model_file:
411
- joblib.dump(clf, tmp_model_file.name)
412
- # In a real-world scenario, you'd store this in a persistent storage
413
- tool.agent.models[model_id] = clf # Storing in agent's models dict
414
-
415
- # Generate HTML for plots
416
- cm_html = fig_cm.to_html(full_html=False)
417
- roc_html = fig_roc.to_html(full_html=False)
418
-
419
- # Generate report summary
420
- summary = f"""
421
- ### Predictive Analytics Report for Target: `{target}`
422
-
423
- **Model Used:** Random Forest Classifier
424
-
425
- **Classification Report:**
426
- {report_df.to_markdown()}
427
-
428
- **Confusion Matrix:**
429
- {cm_html}
430
-
431
- **ROC Curve:**
432
- {roc_html}
433
-
434
- **AUC Score:** {roc_auc:.2f}
435
-
436
- **Model ID:** `{model_id}`
437
-
438
- *You can use this Model ID to retrieve or update the model in future analyses.*
439
- """
440
-
441
- return summary
442
 
443
  # ------------------------------
444
  # Report Exporting Function
@@ -446,14 +263,14 @@ def predictive_analysis(data: Optional[pd.DataFrame] = None, target: Optional[st
446
  def export_report(content: str, filename: str):
447
  """
448
  Export the given content as a PDF report.
449
-
450
  This function converts markdown content into a PDF file using pdfkit and provides
451
  a download button for users to obtain the report.
452
-
453
  Args:
454
  content (str): The markdown content to be included in the PDF report.
455
  filename (str): The desired name for the exported PDF file.
456
-
457
  Returns:
458
  None
459
  """
@@ -491,7 +308,7 @@ def export_report(content: str, filename: str):
491
  def main():
492
  st.set_page_config(page_title="📊 Business Intelligence Assistant", layout="wide")
493
  st.title("📊 **Business Intelligence Assistant**")
494
- st.write("Upload your dataset and receive comprehensive analyses, interactive visualizations, and predictive insights.")
495
 
496
  # Initialize session state
497
  if 'data' not in st.session_state:
@@ -501,8 +318,7 @@ def main():
501
  if 'report_content' not in st.session_state:
502
  st.session_state['report_content'] = ""
503
 
504
- # File Uploader
505
- uploaded_file = st.file_uploader("📥 **Upload a CSV file**", type="csv")
506
 
507
  try:
508
  if uploaded_file is not None:
@@ -515,93 +331,80 @@ def main():
515
  st.session_state['agent'] = DataAnalysisAgent(
516
  dataset=data,
517
  tools=[analyze_basic_stats, generate_correlation_matrix,
518
- analyze_categorical_columns, suggest_features, predictive_analysis],
519
  model=GroqLLM(),
520
- additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn", "plotly"]
521
  )
522
 
523
- st.success(f"✅ Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns")
524
  st.subheader("🔍 **Data Preview**")
525
  st.dataframe(data.head())
526
 
527
  if st.session_state['data'] is not None:
528
- # Sidebar for Analysis Selection
529
- st.sidebar.header("🛠️ **Select Analysis Type**")
530
- analysis_type = st.sidebar.selectbox(
531
  "Choose analysis type",
532
  ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
533
- "Feature Engineering", "Predictive Analytics", "Custom Question"]
534
  )
535
 
536
  if analysis_type == "Basic Statistics":
537
- with st.spinner('📈 Analyzing basic statistics...'):
538
  result = st.session_state['agent'].run(
539
  "Use the analyze_basic_stats tool to analyze this dataset and "
540
  "provide insights about the numerical distributions."
541
  )
542
- st.markdown(result, unsafe_allow_html=True)
543
  st.session_state['report_content'] += result + "\n\n"
544
-
545
  elif analysis_type == "Correlation Analysis":
546
- with st.spinner('📊 Generating correlation matrix...'):
547
  result = st.session_state['agent'].run(
548
  "Use the generate_correlation_matrix tool to analyze correlations "
549
  "and explain any strong relationships found."
550
  )
551
- st.components.v1.html(result, height=600)
552
  st.session_state['report_content'] += "### Correlation Analysis\n" + result + "\n\n"
553
-
554
  elif analysis_type == "Categorical Analysis":
555
- with st.spinner('📊 Analyzing categorical columns...'):
556
  result = st.session_state['agent'].run(
557
  "Use the analyze_categorical_columns tool to examine the "
558
  "categorical variables and explain the distributions."
559
  )
560
- st.markdown(result, unsafe_allow_html=True)
561
- st.session_state['report_content'] += result + "\n\n"
562
-
563
  elif analysis_type == "Feature Engineering":
564
- with st.spinner('🔧 Generating feature suggestions...'):
565
  result = st.session_state['agent'].run(
566
  "Use the suggest_features tool to recommend potential "
567
  "feature engineering steps for this dataset."
568
  )
569
- st.markdown(result, unsafe_allow_html=True)
570
- st.session_state['report_content'] += result + "\n\n"
571
-
572
- elif analysis_type == "Predictive Analytics":
573
- with st.form("Predictive Analytics Form"):
574
- st.write("🔮 **Predictive Analytics**")
575
- target = st.selectbox("Select the target variable for prediction:", options=st.session_state['data'].columns)
576
- submit = st.form_submit_button("🚀 Run Predictive Analysis")
577
-
578
- if submit:
579
- with st.spinner('🚀 Performing predictive analysis...'):
580
- result = st.session_state['agent'].run(
581
- f"Use the predictive_analysis tool to build a classification model with `{target}` as the target variable."
582
- )
583
- st.markdown(result, unsafe_allow_html=True)
584
- st.session_state['report_content'] += result + "\n\n"
585
- export_report(result, "Predictive_Analysis_Report")
586
-
587
  elif analysis_type == "Custom Question":
588
- with st.expander("📝 **Ask a Custom Question**"):
589
- question = st.text_input("What would you like to know about your data?")
590
- if st.button("🔍 Get Answer"):
591
- if question:
592
- with st.spinner('🧠 Processing your question...'):
593
- result = st.session_state['agent'].run(question)
594
- st.markdown(result, unsafe_allow_html=True)
595
- st.session_state['report_content'] += f"### Custom Question: {question}\n{result}\n\n"
596
- else:
597
- st.warning("Please enter a question.")
598
 
599
  # Option to Export Report
600
  if st.session_state['report_content']:
601
- st.sidebar.markdown("---")
602
- if st.sidebar.button("📤 **Export Analysis Report**"):
603
  export_report(st.session_state['report_content'], "Business_Intelligence_Report")
604
- st.sidebar.success("✅ Report exported successfully!")
605
 
606
  except Exception as e:
607
  st.error(f"⚠️ An error occurred: {str(e)}")
 
2
 
3
  import streamlit as st
4
  import numpy as np
5
+ import pandas as pd
6
  from smolagents import CodeAgent, tool
7
  from typing import Union, List, Dict, Optional
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
 
 
10
  import os
11
  from groq import Groq
12
  from dataclasses import dataclass
13
  import tempfile
14
  import base64
15
+ import io
16
 
17
  # ------------------------------
18
  # Language Model Interface
19
  # ------------------------------
20
  class GroqLLM:
21
+ """Compatible LLM interface for smolagents CodeAgent"""
22
+
23
  def __init__(self, model_name: str = "llama-3.1-8b-instant"):
24
  """
25
+ Initialize the GroqLLM with the specified model.
26
+
27
  Args:
28
  model_name (str): The name of the language model to use.
29
  """
30
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
31
  self.model_name = model_name
32
+
33
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
34
  """
35
  Make the class callable as required by smolagents.
36
+
37
  Args:
38
  prompt (Union[str, dict, List[Dict]]): The input prompt for the language model.
39
+
40
  Returns:
41
  str: The generated response from the language model.
42
  """
 
55
  "content": prompt_str
56
  }],
57
  temperature=0.7,
58
+ max_tokens=1024,
59
  stream=False
60
  )
61
 
 
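For reviewers who want to exercise the wrapper outside the app, a minimal smoke test (a sketch, not part of the commit; it assumes `GROQ_API_KEY` is exported and that the elided body of `__call__` coerces dict/list prompts into the `prompt_str` used above):

```python
# Hypothetical smoke test for GroqLLM -- not part of this commit.
llm = GroqLLM()

# Plain string prompt
print(llm("Summarize: revenue grew 12% quarter over quarter."))

# Chat-style prompt, as smolagents may hand over a list of message dicts
print(llm([{"role": "user", "content": "List three EDA steps for a CSV file."}]))
```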
70
  # Data Analysis Agent
71
  # ------------------------------
72
  class DataAnalysisAgent(CodeAgent):
73
+ """Extended CodeAgent with dataset awareness"""
74
+
75
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
76
  """
77
  Initialize the DataAnalysisAgent with the provided dataset.
78
+
79
  Args:
80
  dataset (pd.DataFrame): The dataset to analyze.
81
  *args: Variable length argument list.
 
83
  """
84
  super().__init__(*args, **kwargs)
85
  self._dataset = dataset
86
+
 
87
  @property
88
  def dataset(self) -> pd.DataFrame:
89
+ """Access the stored dataset."""
90
  return self._dataset
91
+
92
  def run(self, prompt: str) -> str:
93
  """
94
+ Override run method to include dataset context.
95
+
96
  Args:
97
  prompt (str): The task prompt for analysis.
98
+
99
  Returns:
100
  str: The result of the analysis.
101
  """
 
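The body of `run` is elided between these hunks. Judging by the docstring, it prepends dataset context before delegating to `CodeAgent.run`. A speculative, self-contained sketch of that pattern (illustration only, not the committed code):

```python
import pandas as pd

class DatasetContextExample:
    """Illustration of how a run() override can inject dataset context."""

    def __init__(self, dataset: pd.DataFrame):
        self._dataset = dataset

    def run(self, prompt: str) -> str:
        # Describe the data so generated analysis code knows what it is working with
        dataset_info = (
            f"Dataset shape: {self._dataset.shape}\n"
            f"Columns: {', '.join(map(str, self._dataset.columns))}"
        )
        # The real agent would hand this on to super().run(); here we just return it
        return f"{dataset_info}\n\nTask: {prompt}"

print(DatasetContextExample(pd.DataFrame({"a": [1, 2]})).run("Describe the data"))
```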
121
  @tool
122
  def analyze_basic_stats(data: Optional[pd.DataFrame] = None) -> str:
123
  """
124
+ Calculate basic statistical measures for numerical columns in the dataset.
125
+
126
  This function computes fundamental statistical metrics including mean, median,
127
  standard deviation, skewness, and counts of missing values for all numerical
128
+ columns in the provided DataFrame.
129
+
 
130
  Args:
131
  data (Optional[pd.DataFrame], optional):
132
+ A pandas DataFrame containing the dataset to analyze. The DataFrame
133
+ should contain at least one numerical column for meaningful analysis.
134
+
 
 
135
  Returns:
136
+ str: A string containing formatted basic statistics for each numerical column,
137
+ including mean, median, standard deviation, skewness, and missing value counts.
138
  """
139
+ # Access dataset from agent if no data provided
140
  if data is None:
141
  data = tool.agent.dataset
142
 
 
152
  'missing': int(data[col].isnull().sum())
153
  }
154
 
155
+ return str(stats)
 
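With the plotting stripped out, the tool now returns `str(stats)`, a plain dict repr. A toy call showing the expected shape (hypothetical; it assumes the smolagents `@tool` wrapper stays directly callable and that the elided loop fills the metrics named in the docstring):

```python
import pandas as pd

df = pd.DataFrame({"price": [10.0, 12.5, 15.0, None]})
print(analyze_basic_stats(df))
# Roughly: {'price': {'mean': 12.5, 'median': 12.5, 'std': 2.5,
#                     'skewness': 0.0, 'missing': 1}}
```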
156
 
157
  @tool
158
  def generate_correlation_matrix(data: Optional[pd.DataFrame] = None) -> str:
159
  """
160
+ Generate a visual correlation matrix for numerical columns in the dataset.
161
+
162
+ This function creates a heatmap visualization showing the correlations between
163
+ all numerical columns in the dataset. The correlation values are displayed
164
+ using a color-coded matrix for easy interpretation.
165
+
166
  Args:
167
  data (Optional[pd.DataFrame], optional):
168
+ A pandas DataFrame containing the dataset to analyze. The DataFrame
169
+ should contain at least two numerical columns for correlation analysis.
170
+
 
 
171
  Returns:
172
+ str: A base64 encoded string representing the correlation matrix plot image,
173
+ which can be displayed in a web interface or saved as an image file.
174
  """
175
+ # Access dataset from agent if no data provided
176
  if data is None:
177
  data = tool.agent.dataset
178
+
179
  numeric_data = data.select_dtypes(include=[np.number])
180
 
181
+ plt.figure(figsize=(10, 8))
182
+ sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm')
183
+ plt.title('Correlation Matrix')
184
 
185
+ buf = io.BytesIO()
186
+ plt.savefig(buf, format='png')
187
+ plt.close()
188
+ return base64.b64encode(buf.getvalue()).decode()
189
 
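The tool returns a bare base64-encoded PNG, so the caller must wrap it for display; `main()` below does exactly this with a data URI. The same pattern in isolation (sketch, reusing the app's own approach):

```python
import pandas as pd
import streamlit as st

df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [2, 1, 4, 3]})
b64 = generate_correlation_matrix(df)  # bare base64 string, no data-URI prefix
st.image(f"data:image/png;base64,{b64}", caption="Correlation Matrix")
```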
190
  @tool
191
  def analyze_categorical_columns(data: Optional[pd.DataFrame] = None) -> str:
192
  """
193
+ Analyze categorical columns in the dataset for distribution and frequencies.
194
+
195
  This function examines categorical columns to identify unique values, top categories,
196
+ and missing value counts, providing insights into the categorical data distribution.
197
+
 
198
  Args:
199
  data (Optional[pd.DataFrame], optional):
200
+ A pandas DataFrame containing the dataset to analyze. The DataFrame
201
+ should contain at least one categorical column for meaningful analysis.
202
+
 
 
203
  Returns:
204
+ str: A string containing formatted analysis results for each categorical column,
205
+ including unique value counts, top categories, and missing value counts.
206
  """
207
+ # Access dataset from agent if no data provided
208
  if data is None:
209
  data = tool.agent.dataset
210
+
211
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
212
  analysis = {}
 
213
 
214
  for col in categorical_cols:
215
  analysis[col] = {
216
+ 'unique_values': int(data[col].nunique()),
217
+ 'top_categories': data[col].value_counts().head(5).to_dict(),
218
+ 'missing': int(data[col].isnull().sum())
219
  }
220
 
221
+ return str(analysis)
222
 
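As with the stats tool, the return value is now a dict repr rather than markdown. A toy example of the output shape (hypothetical direct call):

```python
import pandas as pd

df = pd.DataFrame({"city": ["NY", "NY", "LA", None]})
print(analyze_categorical_columns(df))
# {'city': {'unique_values': 2, 'top_categories': {'NY': 2, 'LA': 1}, 'missing': 1}}
```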
223
  @tool
224
  def suggest_features(data: Optional[pd.DataFrame] = None) -> str:
225
  """
226
  Suggest potential feature engineering steps based on data characteristics.
227
+
228
  This function analyzes the dataset's structure and statistical properties to
229
  recommend possible feature engineering steps that could improve model performance.
230
+
231
  Args:
232
  data (Optional[pd.DataFrame], optional):
233
+ A pandas DataFrame containing the dataset to analyze. The DataFrame
234
+ can contain both numerical and categorical columns.
235
+
 
236
  Returns:
237
  str: A string containing suggestions for feature engineering based on
238
  the characteristics of the input data.
239
  """
240
+ # Access dataset from agent if no data provided
241
  if data is None:
242
  data = tool.agent.dataset
243
+
244
  suggestions = []
245
  numeric_cols = data.select_dtypes(include=[np.number]).columns
246
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
247
 
 
248
  if len(numeric_cols) >= 2:
249
+ suggestions.append("Consider creating interaction terms between numerical features")
250
 
 
251
  if len(categorical_cols) > 0:
252
+ suggestions.append("Consider one-hot encoding for categorical variables")
253
+
 
 
254
  for col in numeric_cols:
255
  if data[col].skew() > 1 or data[col].skew() < -1:
256
+ suggestions.append(f"Consider log transformation for {col} due to skewness")
257
 
258
+ return '\n'.join(suggestions)
 
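The `|skew| > 1` cutoff used above is a common rule of thumb for "highly skewed". A quick check of what trips it (toy data; again assumes the `@tool` wrapper is directly callable):

```python
import pandas as pd

s = pd.Series([1, 1, 2, 2, 3, 50])  # heavy right tail
print(s.skew())                     # well above 1.0
print(suggest_features(pd.DataFrame({"x": s})))
# -> "Consider log transformation for x due to skewness"
```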
259
 
260
  # ------------------------------
261
  # Report Exporting Function
 
263
  def export_report(content: str, filename: str):
264
  """
265
  Export the given content as a PDF report.
266
+
267
  This function converts markdown content into a PDF file using pdfkit and provides
268
  a download button for users to obtain the report.
269
+
270
  Args:
271
  content (str): The markdown content to be included in the PDF report.
272
  filename (str): The desired name for the exported PDF file.
273
+
274
  Returns:
275
  None
276
  """
 
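The body of `export_report` sits outside this hunk, and note that `import pdfkit` was removed at the top of the commit, so either the function re-imports it locally or this docstring is now stale. A plausible sketch of what the docstring describes (assumes the `markdown` package and a working `wkhtmltopdf` install; not the committed implementation):

```python
import markdown
import pdfkit
import streamlit as st

def export_report_sketch(content: str, filename: str) -> None:
    # Markdown -> HTML -> PDF; output_path=False makes pdfkit return the bytes
    html = markdown.markdown(content)
    pdf_bytes = pdfkit.from_string(html, False)
    st.download_button(
        label="Download PDF",
        data=pdf_bytes,
        file_name=f"{filename}.pdf",
        mime="application/pdf",
    )
```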
308
  def main():
309
  st.set_page_config(page_title="📊 Business Intelligence Assistant", layout="wide")
310
  st.title("📊 **Business Intelligence Assistant**")
311
+ st.write("Upload your dataset and get automated analysis with natural language interaction.")
312
 
313
  # Initialize session state
314
  if 'data' not in st.session_state:
 
318
  if 'report_content' not in st.session_state:
319
  st.session_state['report_content'] = ""
320
 
321
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
 
322
 
323
  try:
324
  if uploaded_file is not None:
 
331
  st.session_state['agent'] = DataAnalysisAgent(
332
  dataset=data,
333
  tools=[analyze_basic_stats, generate_correlation_matrix,
334
+ analyze_categorical_columns, suggest_features],
335
  model=GroqLLM(),
336
+ additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
337
  )
338
 
339
+ st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
340
  st.subheader("🔍 **Data Preview**")
341
  st.dataframe(data.head())
342
 
343
  if st.session_state['data'] is not None:
344
+ analysis_type = st.selectbox(
 
 
345
  "Choose analysis type",
346
  ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
347
+ "Feature Engineering", "Custom Question"]
348
  )
349
 
350
  if analysis_type == "Basic Statistics":
351
+ with st.spinner('Analyzing basic statistics...'):
352
  result = st.session_state['agent'].run(
353
  "Use the analyze_basic_stats tool to analyze this dataset and "
354
  "provide insights about the numerical distributions."
355
  )
356
+ st.write(result)
357
  st.session_state['report_content'] += result + "\n\n"
358
+
359
  elif analysis_type == "Correlation Analysis":
360
+ with st.spinner('Generating correlation matrix...'):
361
  result = st.session_state['agent'].run(
362
  "Use the generate_correlation_matrix tool to analyze correlations "
363
  "and explain any strong relationships found."
364
  )
365
+ if isinstance(result, str) and 'base64' in result:
366
+ # Treat the agent's result as a base64-encoded PNG and display it
367
+ image_data = f"data:image/png;base64,{result}"
368
+ st.image(image_data, caption='Correlation Matrix')
369
+ else:
370
+ st.write(result)
371
  st.session_state['report_content'] += "### Correlation Analysis\n" + result + "\n\n"
372
+
373
  elif analysis_type == "Categorical Analysis":
374
+ with st.spinner('Analyzing categorical columns...'):
375
  result = st.session_state['agent'].run(
376
  "Use the analyze_categorical_columns tool to examine the "
377
  "categorical variables and explain the distributions."
378
  )
379
+ st.write(result)
380
+ st.session_state['report_content'] += "### Categorical Analysis\n" + result + "\n\n"
381
+
382
  elif analysis_type == "Feature Engineering":
383
+ with st.spinner('Generating feature suggestions...'):
384
  result = st.session_state['agent'].run(
385
  "Use the suggest_features tool to recommend potential "
386
  "feature engineering steps for this dataset."
387
  )
388
+ st.write(result)
389
+ st.session_state['report_content'] += "### Feature Engineering Suggestions\n" + result + "\n\n"
390
+
391
  elif analysis_type == "Custom Question":
392
+ question = st.text_input("What would you like to know about your data?")
393
+ if st.button("🔍 Get Answer"):
394
+ if question:
395
+ with st.spinner('Analyzing...'):
396
+ result = st.session_state['agent'].run(question)
397
+ st.write(result)
398
+ st.session_state['report_content'] += f"### Custom Question: {question}\n{result}\n\n"
399
+ else:
400
+ st.warning("Please enter a question.")
 
401
 
402
  # Option to Export Report
403
  if st.session_state['report_content']:
404
+ st.markdown("---")
405
+ if st.button("📤 **Export Analysis Report**"):
406
  export_report(st.session_state['report_content'], "Business_Intelligence_Report")
407
+ st.success("✅ Report exported successfully!")
408
 
409
  except Exception as e:
410
  st.error(f"⚠️ An error occurred: {str(e)}")