mgbam committed on
Commit
659fba8
·
verified ·
1 Parent(s): 16f65e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -9
app.py CHANGED
@@ -11,13 +11,13 @@ import base64
11
  import io
12
 
13
  class GroqLLM:
14
- """Compatible LLM interface for smolagents CodeAgent"""
15
  def __init__(self, model_name="llama-3.1-8B-Instant"):
16
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
17
  self.model_name = model_name
18
 
19
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
20
- """Make the class callable as required by smolagents"""
21
  try:
22
  # Handle different prompt formats
23
  if isinstance(prompt, (dict, list)):
@@ -49,18 +49,18 @@ class GroqLLM:
49
  return error_msg
50
 
51
  class DataAnalysisAgent(CodeAgent):
52
- """Extended CodeAgent with dataset awareness"""
53
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
54
  super().__init__(*args, **kwargs)
55
  self._dataset = dataset
56
 
57
  @property
58
  def dataset(self) -> pd.DataFrame:
59
- """Access the stored dataset"""
60
  return self._dataset
61
 
62
  def run(self, prompt: str) -> str:
63
- """Override run method to include dataset context"""
64
  dataset_info = f"""
65
  Dataset Shape: {self.dataset.shape}
66
  Columns: {', '.join(self.dataset.columns)}
@@ -78,7 +78,15 @@ class DataAnalysisAgent(CodeAgent):
78
 
79
  @tool
80
  def analyze_basic_stats(data: pd.DataFrame) -> str:
81
- """Calculate basic statistical measures for numerical columns in the dataset."""
 
 
 
 
 
 
 
 
82
  if data is None:
83
  data = tool.agent.dataset
84
 
@@ -98,7 +106,14 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
98
 
99
  @tool
100
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
101
- """Generate a visual correlation matrix for numerical columns in the dataset."""
 
 
 
 
 
 
 
102
  if data is None:
103
  data = tool.agent.dataset
104
 
@@ -115,7 +130,15 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
115
 
116
  @tool
117
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
118
- """Analyze categorical columns in the dataset for distribution and frequencies."""
 
 
 
 
 
 
 
 
119
  if data is None:
120
  data = tool.agent.dataset
121
 
@@ -133,7 +156,15 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
133
 
134
  @tool
135
  def suggest_features(data: pd.DataFrame) -> str:
136
- """Suggest potential feature engineering steps based on data characteristics."""
 
 
 
 
 
 
 
 
137
  if data is None:
138
  data = tool.agent.dataset
139
 
 
11
  import io
12
 
13
  class GroqLLM:
14
+ """Compatible LLM interface for smolagents CodeAgent."""
15
  def __init__(self, model_name="llama-3.1-8B-Instant"):
16
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
17
  self.model_name = model_name
18
 
19
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
20
+ """Make the class callable as required by smolagents."""
21
  try:
22
  # Handle different prompt formats
23
  if isinstance(prompt, (dict, list)):
 
49
  return error_msg
50
 
51
  class DataAnalysisAgent(CodeAgent):
52
+ """Extended CodeAgent with dataset awareness."""
53
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
54
  super().__init__(*args, **kwargs)
55
  self._dataset = dataset
56
 
57
  @property
58
  def dataset(self) -> pd.DataFrame:
59
+ """Access the stored dataset."""
60
  return self._dataset
61
 
62
  def run(self, prompt: str) -> str:
63
+ """Override run method to include dataset context."""
64
  dataset_info = f"""
65
  Dataset Shape: {self.dataset.shape}
66
  Columns: {', '.join(self.dataset.columns)}
 
78
 
79
  @tool
80
  def analyze_basic_stats(data: pd.DataFrame) -> str:
81
+ """Calculate basic statistical measures for numerical columns in the dataset.
82
+
83
+ Args:
84
+ data (pd.DataFrame): The dataset to analyze. It should contain at least one numerical column.
85
+
86
+ Returns:
87
+ str: A string containing formatted basic statistics for each numerical column,
88
+ including mean, median, standard deviation, skewness, and missing value counts.
89
+ """
90
  if data is None:
91
  data = tool.agent.dataset
92
 
 
106
 
107
  @tool
108
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
109
+ """Generate a visual correlation matrix for numerical columns in the dataset.
110
+
111
+ Args:
112
+ data (pd.DataFrame): The dataset to analyze. It should contain at least two numerical columns.
113
+
114
+ Returns:
115
+ str: A base64 encoded string representing the correlation matrix plot image.
116
+ """
117
  if data is None:
118
  data = tool.agent.dataset
119
 
 
130
 
131
  @tool
132
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
133
+ """Analyze categorical columns in the dataset for distribution and frequencies.
134
+
135
+ Args:
136
+ data (pd.DataFrame): The dataset to analyze. It should contain at least one categorical column.
137
+
138
+ Returns:
139
+ str: A string containing formatted analysis results for each categorical column,
140
+ including unique value counts, top categories, and missing value counts.
141
+ """
142
  if data is None:
143
  data = tool.agent.dataset
144
 
 
156
 
157
  @tool
158
  def suggest_features(data: pd.DataFrame) -> str:
159
+ """Suggest potential feature engineering steps based on data characteristics.
160
+
161
+ Args:
162
+ data (pd.DataFrame): The dataset to analyze. It can contain both numerical and categorical columns.
163
+
164
+ Returns:
165
+ str: A string containing suggestions for feature engineering based on
166
+ the characteristics of the input data.
167
+ """
168
  if data is None:
169
  data = tool.agent.dataset
170