Update app.py
Browse files
app.py
CHANGED
@@ -11,13 +11,13 @@ import base64
|
|
11 |
import io
|
12 |
|
13 |
class GroqLLM:
|
14 |
-
"""Compatible LLM interface for smolagents CodeAgent"""
|
15 |
def __init__(self, model_name="llama-3.1-8B-Instant"):
|
16 |
self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
17 |
self.model_name = model_name
|
18 |
|
19 |
def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
|
20 |
-
"""Make the class callable as required by smolagents"""
|
21 |
try:
|
22 |
# Handle different prompt formats
|
23 |
if isinstance(prompt, (dict, list)):
|
@@ -49,18 +49,18 @@ class GroqLLM:
|
|
49 |
return error_msg
|
50 |
|
51 |
class DataAnalysisAgent(CodeAgent):
|
52 |
-
"""Extended CodeAgent with dataset awareness"""
|
53 |
def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
|
54 |
super().__init__(*args, **kwargs)
|
55 |
self._dataset = dataset
|
56 |
|
57 |
@property
|
58 |
def dataset(self) -> pd.DataFrame:
|
59 |
-
"""Access the stored dataset"""
|
60 |
return self._dataset
|
61 |
|
62 |
def run(self, prompt: str) -> str:
|
63 |
-
"""Override run method to include dataset context"""
|
64 |
dataset_info = f"""
|
65 |
Dataset Shape: {self.dataset.shape}
|
66 |
Columns: {', '.join(self.dataset.columns)}
|
@@ -78,7 +78,15 @@ class DataAnalysisAgent(CodeAgent):
|
|
78 |
|
79 |
@tool
|
80 |
def analyze_basic_stats(data: pd.DataFrame) -> str:
|
81 |
-
"""Calculate basic statistical measures for numerical columns in the dataset.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
if data is None:
|
83 |
data = tool.agent.dataset
|
84 |
|
@@ -98,7 +106,14 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
|
|
98 |
|
99 |
@tool
|
100 |
def generate_correlation_matrix(data: pd.DataFrame) -> str:
|
101 |
-
"""Generate a visual correlation matrix for numerical columns in the dataset.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
if data is None:
|
103 |
data = tool.agent.dataset
|
104 |
|
@@ -115,7 +130,15 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
|
|
115 |
|
116 |
@tool
|
117 |
def analyze_categorical_columns(data: pd.DataFrame) -> str:
|
118 |
-
"""Analyze categorical columns in the dataset for distribution and frequencies.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
if data is None:
|
120 |
data = tool.agent.dataset
|
121 |
|
@@ -133,7 +156,15 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
|
|
133 |
|
134 |
@tool
|
135 |
def suggest_features(data: pd.DataFrame) -> str:
|
136 |
-
"""Suggest potential feature engineering steps based on data characteristics.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
if data is None:
|
138 |
data = tool.agent.dataset
|
139 |
|
|
|
11 |
import io
|
12 |
|
13 |
class GroqLLM:
|
14 |
+
"""Compatible LLM interface for smolagents CodeAgent."""
|
15 |
def __init__(self, model_name="llama-3.1-8B-Instant"):
|
16 |
self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
17 |
self.model_name = model_name
|
18 |
|
19 |
def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
|
20 |
+
"""Make the class callable as required by smolagents."""
|
21 |
try:
|
22 |
# Handle different prompt formats
|
23 |
if isinstance(prompt, (dict, list)):
|
|
|
49 |
return error_msg
|
50 |
|
51 |
class DataAnalysisAgent(CodeAgent):
|
52 |
+
"""Extended CodeAgent with dataset awareness."""
|
53 |
def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
|
54 |
super().__init__(*args, **kwargs)
|
55 |
self._dataset = dataset
|
56 |
|
57 |
@property
|
58 |
def dataset(self) -> pd.DataFrame:
|
59 |
+
"""Access the stored dataset."""
|
60 |
return self._dataset
|
61 |
|
62 |
def run(self, prompt: str) -> str:
|
63 |
+
"""Override run method to include dataset context."""
|
64 |
dataset_info = f"""
|
65 |
Dataset Shape: {self.dataset.shape}
|
66 |
Columns: {', '.join(self.dataset.columns)}
|
|
|
78 |
|
79 |
@tool
|
80 |
def analyze_basic_stats(data: pd.DataFrame) -> str:
|
81 |
+
"""Calculate basic statistical measures for numerical columns in the dataset.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
data (pd.DataFrame): The dataset to analyze. It should contain at least one numerical column.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
str: A string containing formatted basic statistics for each numerical column,
|
88 |
+
including mean, median, standard deviation, skewness, and missing value counts.
|
89 |
+
"""
|
90 |
if data is None:
|
91 |
data = tool.agent.dataset
|
92 |
|
|
|
106 |
|
107 |
@tool
|
108 |
def generate_correlation_matrix(data: pd.DataFrame) -> str:
|
109 |
+
"""Generate a visual correlation matrix for numerical columns in the dataset.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
data (pd.DataFrame): The dataset to analyze. It should contain at least two numerical columns.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
str: A base64 encoded string representing the correlation matrix plot image.
|
116 |
+
"""
|
117 |
if data is None:
|
118 |
data = tool.agent.dataset
|
119 |
|
|
|
130 |
|
131 |
@tool
|
132 |
def analyze_categorical_columns(data: pd.DataFrame) -> str:
|
133 |
+
"""Analyze categorical columns in the dataset for distribution and frequencies.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
data (pd.DataFrame): The dataset to analyze. It should contain at least one categorical column.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
str: A string containing formatted analysis results for each categorical column,
|
140 |
+
including unique value counts, top categories, and missing value counts.
|
141 |
+
"""
|
142 |
if data is None:
|
143 |
data = tool.agent.dataset
|
144 |
|
|
|
156 |
|
157 |
@tool
|
158 |
def suggest_features(data: pd.DataFrame) -> str:
|
159 |
+
"""Suggest potential feature engineering steps based on data characteristics.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
data (pd.DataFrame): The dataset to analyze. It can contain both numerical and categorical columns.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
str: A string containing suggestions for feature engineering based on
|
166 |
+
the characteristics of the input data.
|
167 |
+
"""
|
168 |
if data is None:
|
169 |
data = tool.agent.dataset
|
170 |
|