mgbam committed on
Commit
f60d18c
·
verified ·
1 Parent(s): b1ccc30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -21
app.py CHANGED
@@ -131,7 +131,7 @@ class DataAnalysisAgent(CodeAgent):
131
  # ------------------------------
132
 
133
  @tool
134
- def analyze_basic_stats(data: pd.DataFrame) -> str:
135
  """
136
  Calculate and visualize basic statistical measures for numerical columns.
137
 
@@ -141,9 +141,11 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
141
  the mean, median, and standard deviation for each numerical feature.
142
 
143
  Args:
144
- data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
145
- The DataFrame should contain at least one numerical column
146
- for meaningful analysis.
 
 
147
 
148
  Returns:
149
  str: A markdown-formatted string containing the statistics and the generated plot.
@@ -184,7 +186,7 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
184
  return f"### Basic Statistics\n{stats_df.to_markdown()} \n\n![Basic Statistics](data:image/png;base64,{stats_plot})"
185
 
186
  @tool
187
- def generate_correlation_matrix(data: pd.DataFrame) -> str:
188
  """
189
  Generate an interactive correlation matrix using Plotly.
190
 
@@ -193,9 +195,11 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
193
  and interact with the plot (zoom, pan).
194
 
195
  Args:
196
- data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
197
- The DataFrame should contain at least two numerical columns
198
- for correlation analysis.
 
 
199
 
200
  Returns:
201
  str: An HTML string representing the interactive correlation matrix plot.
@@ -220,7 +224,7 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
220
  return correlation_html
221
 
222
  @tool
223
- def analyze_categorical_columns(data: pd.DataFrame) -> str:
224
  """
225
  Analyze categorical columns with visualizations.
226
 
@@ -229,9 +233,11 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
229
  categorical feature.
230
 
231
  Args:
232
- data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
233
- The DataFrame should contain at least one categorical column
234
- for meaningful analysis.
 
 
235
 
236
  Returns:
237
  str: A markdown-formatted string containing analysis results and embedded plots.
@@ -275,7 +281,7 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
275
  return plots + f"### Categorical Columns Analysis\n{pd.DataFrame(analysis).T.to_markdown()}"
276
 
277
  @tool
278
- def suggest_features(data: pd.DataFrame) -> str:
279
  """
280
  Suggest potential feature engineering steps based on data characteristics.
281
 
@@ -283,8 +289,10 @@ def suggest_features(data: pd.DataFrame) -> str:
283
  recommend possible feature engineering steps that could improve model performance.
284
 
285
  Args:
286
- data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
287
- The DataFrame can contain both numerical and categorical columns.
 
 
288
 
289
  Returns:
290
  str: A string containing suggestions for feature engineering based on
@@ -322,7 +330,7 @@ def suggest_features(data: pd.DataFrame) -> str:
322
  return "\n".join(suggestions)
323
 
324
  @tool
325
- def predictive_analysis(data: pd.DataFrame, target: str) -> str:
326
  """
327
  Perform predictive analytics by training a classification model.
328
 
@@ -330,9 +338,13 @@ def predictive_analysis(data: pd.DataFrame, target: str) -> str:
330
  and provides detailed metrics and visualizations such as the confusion matrix and ROC curve.
331
 
332
  Args:
333
- data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
334
- The DataFrame should contain the target variable for prediction.
335
- target (str): The name of the target variable column in the dataset.
 
 
 
 
336
 
337
  Returns:
338
  str: A markdown-formatted string containing the classification report, confusion matrix,
@@ -341,8 +353,8 @@ def predictive_analysis(data: pd.DataFrame, target: str) -> str:
341
  if data is None:
342
  data = tool.agent.dataset
343
 
344
- if target not in data.columns:
345
- return f"Error: Target column `{target}` not found in the dataset."
346
 
347
  # Handle categorical target
348
  if data[target].dtype == 'object' or data[target].dtype.name == 'category':
 
131
  # ------------------------------
132
 
133
  @tool
134
+ def analyze_basic_stats(data: Optional[pd.DataFrame] = None) -> str:
135
  """
136
  Calculate and visualize basic statistical measures for numerical columns.
137
 
 
141
  the mean, median, and standard deviation for each numerical feature.
142
 
143
  Args:
144
+ data (Optional[pd.DataFrame]):
145
+ A pandas DataFrame containing the dataset to analyze.
146
+ If None, the agent's stored dataset will be used.
147
+ The DataFrame should contain at least one numerical column
148
+ for meaningful analysis.
149
 
150
  Returns:
151
  str: A markdown-formatted string containing the statistics and the generated plot.
 
186
  return f"### Basic Statistics\n{stats_df.to_markdown()} \n\n![Basic Statistics](data:image/png;base64,{stats_plot})"
187
 
188
  @tool
189
+ def generate_correlation_matrix(data: Optional[pd.DataFrame] = None) -> str:
190
  """
191
  Generate an interactive correlation matrix using Plotly.
192
 
 
195
  and interact with the plot (zoom, pan).
196
 
197
  Args:
198
+ data (Optional[pd.DataFrame]):
199
+ A pandas DataFrame containing the dataset to analyze.
200
+ If None, the agent's stored dataset will be used.
201
+ The DataFrame should contain at least two numerical columns
202
+ for correlation analysis.
203
 
204
  Returns:
205
  str: An HTML string representing the interactive correlation matrix plot.
 
224
  return correlation_html
225
 
226
  @tool
227
+ def analyze_categorical_columns(data: Optional[pd.DataFrame] = None) -> str:
228
  """
229
  Analyze categorical columns with visualizations.
230
 
 
233
  categorical feature.
234
 
235
  Args:
236
+ data (Optional[pd.DataFrame]):
237
+ A pandas DataFrame containing the dataset to analyze.
238
+ If None, the agent's stored dataset will be used.
239
+ The DataFrame should contain at least one categorical column
240
+ for meaningful analysis.
241
 
242
  Returns:
243
  str: A markdown-formatted string containing analysis results and embedded plots.
 
281
  return plots + f"### Categorical Columns Analysis\n{pd.DataFrame(analysis).T.to_markdown()}"
282
 
283
  @tool
284
+ def suggest_features(data: Optional[pd.DataFrame] = None) -> str:
285
  """
286
  Suggest potential feature engineering steps based on data characteristics.
287
 
 
289
  recommend possible feature engineering steps that could improve model performance.
290
 
291
  Args:
292
+ data (Optional[pd.DataFrame]):
293
+ A pandas DataFrame containing the dataset to analyze.
294
+ If None, the agent's stored dataset will be used.
295
+ The DataFrame can contain both numerical and categorical columns.
296
 
297
  Returns:
298
  str: A string containing suggestions for feature engineering based on
 
330
  return "\n".join(suggestions)
331
 
332
  @tool
333
+ def predictive_analysis(data: Optional[pd.DataFrame] = None, target: Optional[str] = None) -> str:
334
  """
335
  Perform predictive analytics by training a classification model.
336
 
 
338
  and provides detailed metrics and visualizations such as the confusion matrix and ROC curve.
339
 
340
  Args:
341
+ data (Optional[pd.DataFrame]):
342
+ A pandas DataFrame containing the dataset to analyze.
343
+ If None, the agent's stored dataset will be used.
344
+ The DataFrame should contain the target variable for prediction.
345
+ target (Optional[str]):
346
+ The name of the target variable column in the dataset.
347
+ If None, the agent must provide the target variable through the prompt; otherwise an error message is returned.
348
 
349
  Returns:
350
  str: A markdown-formatted string containing the classification report, confusion matrix,
 
353
  if data is None:
354
  data = tool.agent.dataset
355
 
356
+ if target is None or target not in data.columns:
357
+ return f"Error: Target column not specified or `{target}` not found in the dataset."
358
 
359
  # Handle categorical target
360
  if data[target].dtype == 'object' or data[target].dtype.name == 'category':