mgbam committed on
Commit
28e2398
·
verified ·
1 Parent(s): e207857

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -101
app.py CHANGED
@@ -5,60 +5,60 @@ from smolagents import CodeAgent, tool
5
  from typing import Union, List, Dict, Optional
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
- import base64
9
  import os
10
  from groq import Groq
11
- import io
12
  import tempfile
13
- import pdfkit
14
-
15
 
16
- # --------------------------------------
17
- # LLM Interface
18
- # --------------------------------------
19
  class GroqLLM:
20
- """Compatible LLM interface for smolagents CodeAgent."""
21
-
22
  def __init__(self, model_name="llama-3.1-8B-Instant"):
23
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
24
  self.model_name = model_name
25
-
26
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
27
- """Make the class callable as required by smolagents."""
28
  try:
 
29
  if isinstance(prompt, (dict, list)):
30
  prompt_str = str(prompt)
31
  else:
32
  prompt_str = str(prompt)
 
 
33
  completion = self.client.chat.completions.create(
34
  model=self.model_name,
35
- messages=[{"role": "user", "content": prompt_str}],
 
 
 
36
  temperature=0.7,
37
  max_tokens=1024,
38
- stream=False,
39
  )
 
40
  return completion.choices[0].message.content if completion.choices else "Error: No response generated"
 
41
  except Exception as e:
42
- return f"Error generating response: {str(e)}"
43
-
 
44
 
45
- # --------------------------------------
46
- # Dataset-Aware Agent
47
- # --------------------------------------
48
  class DataAnalysisAgent(CodeAgent):
49
- """Extended CodeAgent with dataset awareness."""
50
-
51
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
52
  super().__init__(*args, **kwargs)
53
  self._dataset = dataset
54
-
55
  @property
56
  def dataset(self) -> pd.DataFrame:
57
- """Access the stored dataset."""
58
  return self._dataset
59
 
60
  def run(self, prompt: str) -> str:
61
- """Override run method to include dataset context."""
62
  dataset_info = f"""
63
  Dataset Shape: {self.dataset.shape}
64
  Columns: {', '.join(self.dataset.columns)}
@@ -74,125 +74,220 @@ class DataAnalysisAgent(CodeAgent):
74
  """
75
  return super().run(enhanced_prompt)
76
 
77
-
78
- # --------------------------------------
79
- # Tools
80
- # --------------------------------------
81
  @tool
82
  def analyze_basic_stats(data: pd.DataFrame) -> str:
83
- """Calculate basic statistical measures for numerical columns."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if data is None:
85
  data = tool.agent.dataset
86
- stats = data.describe().to_markdown()
87
- return f"### Basic Statistics\n{stats}"
88
-
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  @tool
91
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
92
- """Generate a visual correlation matrix for numerical columns."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if data is None:
94
  data = tool.agent.dataset
 
95
  numeric_data = data.select_dtypes(include=[np.number])
 
96
  plt.figure(figsize=(10, 8))
97
- sns.heatmap(numeric_data.corr(), annot=True, cmap="coolwarm")
98
- plt.title("Correlation Matrix")
 
99
  buf = io.BytesIO()
100
- plt.savefig(buf, format="png")
101
  plt.close()
102
  return base64.b64encode(buf.getvalue()).decode()
103
 
104
-
105
  @tool
106
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
107
- """Analyze categorical columns in the dataset."""
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  if data is None:
109
  data = tool.agent.dataset
110
- categorical_cols = data.select_dtypes(include=["object", "category"]).columns
 
111
  analysis = {}
 
112
  for col in categorical_cols:
113
  analysis[col] = {
114
- "unique_values": data[col].nunique(),
115
- "top_categories": data[col].value_counts().head(5).to_dict(),
116
- "missing": data[col].isnull().sum(),
117
  }
 
118
  return str(analysis)
119
 
120
-
121
  @tool
122
  def suggest_features(data: pd.DataFrame) -> str:
123
- """Suggest potential feature engineering steps."""
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  if data is None:
125
  data = tool.agent.dataset
 
126
  suggestions = []
127
  numeric_cols = data.select_dtypes(include=[np.number]).columns
128
- categorical_cols = data.select_dtypes(include=["object", "category"]).columns
 
129
  if len(numeric_cols) >= 2:
130
  suggestions.append("Consider creating interaction terms between numerical features")
 
131
  if len(categorical_cols) > 0:
132
  suggestions.append("Consider one-hot encoding for categorical variables")
 
133
  for col in numeric_cols:
134
  if data[col].skew() > 1 or data[col].skew() < -1:
135
  suggestions.append(f"Consider log transformation for {col} due to skewness")
136
- return "\n".join(suggestions)
137
-
138
-
139
- # --------------------------------------
140
- # Export Report
141
- # --------------------------------------
142
- def export_report(content: str, filename: str):
143
- """Export analysis report as a PDF."""
144
- with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp:
145
- tmp.write(content.encode("utf-8"))
146
- tmp_path = tmp.name
147
- pdf_path = f"{filename}.pdf"
148
- pdfkit.from_file(tmp_path, pdf_path)
149
- with open(pdf_path, "rb") as pdf_file:
150
- st.download_button(
151
- label="Download Report as PDF",
152
- data=pdf_file.read(),
153
- file_name=pdf_path,
154
- mime="application/pdf",
155
- )
156
- os.remove(tmp_path)
157
- os.remove(pdf_path)
158
-
159
 
160
- # --------------------------------------
161
- # Streamlit App
162
- # --------------------------------------
163
  def main():
164
  st.title("Data Analysis Assistant")
165
  st.write("Upload your dataset and get automated analysis with natural language interaction.")
166
-
167
- if "data" not in st.session_state:
168
- st.session_state["data"] = None
169
-
170
- uploaded_file = st.file_uploader("Upload CSV File", type="csv")
171
- if uploaded_file:
172
- st.session_state["data"] = pd.read_csv(uploaded_file)
173
- st.success(f"Loaded dataset with {st.session_state['data'].shape[0]} rows and {st.session_state['data'].shape[1]} columns.")
174
- st.dataframe(st.session_state["data"].head())
175
-
176
- agent = DataAnalysisAgent(
177
- dataset=st.session_state["data"],
178
- tools=[analyze_basic_stats, generate_correlation_matrix, analyze_categorical_columns, suggest_features],
179
- model=GroqLLM(),
180
- )
181
-
182
- analysis_type = st.selectbox("Choose Analysis Type", ["Basic Statistics", "Correlation Analysis", "Categorical Analysis", "Feature Suggestions"])
183
- if analysis_type == "Basic Statistics":
184
- st.markdown(agent.run("Analyze basic statistics."))
185
- elif analysis_type == "Correlation Analysis":
186
- result = agent.run("Generate a correlation matrix.")
187
- st.image(f"data:image/png;base64,{result}")
188
- elif analysis_type == "Categorical Analysis":
189
- st.markdown(agent.run("Analyze categorical columns."))
190
- elif analysis_type == "Feature Suggestions":
191
- st.markdown(agent.run("Suggest feature engineering ideas."))
192
-
193
- if st.button("Export Report"):
194
- export_report(agent.run("Generate full report."), "data_analysis_report")
195
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  if __name__ == "__main__":
198
- main()
 
5
  from typing import Union, List, Dict, Optional
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
 
8
  import os
9
  from groq import Groq
10
+ from dataclasses import dataclass
11
  import tempfile
12
+ import base64
13
+ import io
14
 
 
 
 
15
class GroqLLM:
    """Compatible LLM interface for smolagents CodeAgent.

    Wraps a Groq chat-completions client so that an instance can be called
    with a prompt and returns the model's reply as text.
    """

    def __init__(self, model_name="llama-3.1-8B-Instant"):
        # The API key is read from the GROQ_API_KEY environment variable;
        # os.environ.get yields None when it is unset.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
        """Send ``prompt`` as one user message and return the reply text.

        Args:
            prompt: The prompt to send. Whatever its type (plain string, dict
                or list of dicts) it is converted with ``str()`` and sent as
                the content of a single ``user`` message.

        Returns:
            The content of the first completion choice. If the response has no
            choices, or the choice has no content, the string
            ``"Error: No response generated"``. If any exception is raised
            while building or sending the request, a string of the form
            ``"Error generating response: <exception text>"`` (also printed).
        """
        try:
            # Every prompt type was stringified identically by the former
            # if/else, so a single conversion is equivalent.
            prompt_str = str(prompt)

            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt_str}],
                temperature=0.7,
                max_tokens=1024,
                stream=False,
            )

            if not completion.choices:
                return "Error: No response generated"

            content = completion.choices[0].message.content
            # Keep the declared ``-> str`` contract even when the API hands
            # back a choice without text content.
            return content if content is not None else "Error: No response generated"

        except Exception as e:
            # Deliberate best-effort: the caller always receives a string.
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg
48
 
 
 
 
49
  class DataAnalysisAgent(CodeAgent):
50
+ """Extended CodeAgent with dataset awareness"""
 
51
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
52
  super().__init__(*args, **kwargs)
53
  self._dataset = dataset
54
+
55
  @property
56
  def dataset(self) -> pd.DataFrame:
57
+ """Access the stored dataset"""
58
  return self._dataset
59
 
60
  def run(self, prompt: str) -> str:
61
+ """Override run method to include dataset context"""
62
  dataset_info = f"""
63
  Dataset Shape: {self.dataset.shape}
64
  Columns: {', '.join(self.dataset.columns)}
 
74
  """
75
  return super().run(enhanced_prompt)
76
 
 
 
 
 
77
@tool
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Calculate basic statistical measures for numerical columns in the dataset.

    Computes the mean, median, standard deviation, skewness and number of
    missing values for every numerical column of the DataFrame.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one numerical column for meaningful analysis.

    Returns:
        str: The string form of a dict mapping each numerical column name to its
            'mean', 'median', 'std', 'skew' and 'missing' values. The dict is
            empty when the DataFrame has no numerical columns.

    Raises:
        ValueError: If ``data`` is None and no agent dataset can be found.
    """
    if data is None:
        # Fall back to a dataset attached as ``tool.agent.dataset`` if the
        # application provided one. NOTE(review): the smolagents ``tool``
        # decorator does not appear to expose ``.agent`` itself, so look it up
        # defensively and fail with a clear message instead of AttributeError.
        data = getattr(getattr(tool, "agent", None), "dataset", None)
        if data is None:
            raise ValueError("No dataset provided and no agent dataset is available")

    numeric_cols = data.select_dtypes(include=[np.number]).columns

    # Cast to builtin float/int so the stringified dict shows plain numbers.
    stats = {
        col: {
            'mean': float(data[col].mean()),
            'median': float(data[col].median()),
            'std': float(data[col].std()),
            'skew': float(data[col].skew()),
            'missing': int(data[col].isnull().sum()),
        }
        for col in numeric_cols
    }

    return str(stats)
110
 
111
@tool
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Generate a visual correlation matrix for numerical columns in the dataset.

    Draws an annotated heatmap of the pairwise correlations between all
    numerical columns and returns the rendered PNG as base64 text.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least two numerical columns for correlation analysis.

    Returns:
        str: A base64 encoded string representing the correlation matrix plot image
            (PNG bytes, without any ``data:`` URI prefix).

    Raises:
        ValueError: If ``data`` is None and no agent dataset can be found, or if
            the dataset has no numerical columns.
    """
    if data is None:
        # Fall back to a dataset attached as ``tool.agent.dataset`` if the
        # application provided one. NOTE(review): the smolagents ``tool``
        # decorator does not appear to expose ``.agent`` itself, so look it up
        # defensively and fail with a clear message instead of AttributeError.
        data = getattr(getattr(tool, "agent", None), "dataset", None)
        if data is None:
            raise ValueError("No dataset provided and no agent dataset is available")

    numeric_data = data.select_dtypes(include=[np.number])
    if numeric_data.shape[1] == 0:
        # Without numerical columns there is nothing to plot.
        raise ValueError("Dataset has no numerical columns to correlate")

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm', ax=ax)
        ax.set_title('Correlation Matrix')

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    return base64.b64encode(buf.getvalue()).decode()
141
 
 
142
@tool
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Analyze categorical columns in the dataset for distribution and frequencies.

    For every column of dtype ``object`` or ``category`` this reports the number
    of unique values, the five most frequent values with their counts, and the
    number of missing values.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one categorical column for meaningful analysis.

    Returns:
        str: The string form of a dict mapping each categorical column name to its
            'unique_values', 'top_categories' and 'missing' entries. The dict is
            empty when the DataFrame has no categorical columns.

    Raises:
        ValueError: If ``data`` is None and no agent dataset can be found.
    """
    if data is None:
        # Fall back to a dataset attached as ``tool.agent.dataset`` if the
        # application provided one. NOTE(review): the smolagents ``tool``
        # decorator does not appear to expose ``.agent`` itself, so look it up
        # defensively and fail with a clear message instead of AttributeError.
        data = getattr(getattr(tool, "agent", None), "dataset", None)
        if data is None:
            raise ValueError("No dataset provided and no agent dataset is available")

    categorical_cols = data.select_dtypes(include=['object', 'category']).columns
    analysis = {}

    for col in categorical_cols:
        top_counts = data[col].value_counts().head(5)
        analysis[col] = {
            'unique_values': int(data[col].nunique()),
            # Cast counts to builtin int, like the other fields, so the
            # stringified result contains plain numbers.
            'top_categories': {category: int(count) for category, count in top_counts.items()},
            'missing': int(data[col].isnull().sum()),
        }

    return str(analysis)
172
 
 
173
@tool
def suggest_features(data: pd.DataFrame) -> str:
    """Suggest potential feature engineering steps based on data characteristics.

    Looks at how many numerical and categorical columns the dataset has and at
    the skewness of each numerical column, and turns that into a list of
    suggestions.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            can contain both numerical and categorical columns.

    Returns:
        str: The suggestions joined by newlines, one per line. The string is
            empty when no rule applies.

    Raises:
        ValueError: If ``data`` is None and no agent dataset can be found.
    """
    if data is None:
        # Fall back to a dataset attached as ``tool.agent.dataset`` if the
        # application provided one. NOTE(review): the smolagents ``tool``
        # decorator does not appear to expose ``.agent`` itself, so look it up
        # defensively and fail with a clear message instead of AttributeError.
        data = getattr(getattr(tool, "agent", None), "dataset", None)
        if data is None:
            raise ValueError("No dataset provided and no agent dataset is available")

    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")

    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")

    for col in numeric_cols:
        # Compute the skewness once per column. A NaN skew compares False,
        # exactly as it did with the two separate comparisons.
        if abs(data[col].skew()) > 1:
            suggestions.append(f"Consider log transformation for {col} due to skewness")

    return '\n'.join(suggestions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
 
 
 
208
def _extract_png_base64(result):
    """Return the base64 payload if ``result`` encodes a PNG image, else None.

    Accepts either raw base64 text or a ``data:image...`` URI whose payload
    follows the first comma. Anything that is not a string, is not valid
    base64, or does not decode to bytes starting with the PNG signature
    yields None.
    """
    import base64  # local import keeps this block self-contained

    if not isinstance(result, str):
        return None

    payload = result.strip()
    if payload.startswith('data:image'):
        _, separator, payload = payload.partition(',')
        if not separator:
            return None

    try:
        # validate=True rejects ordinary prose (spaces, commas, ...);
        # binascii.Error is a ValueError subclass.
        decoded = base64.b64decode(payload, validate=True)
    except ValueError:
        return None

    return payload if decoded.startswith(b"\x89PNG\r\n\x1a\n") else None


def main():
    """Run the Streamlit data-analysis app.

    Lets the user upload a CSV file, builds a DataAnalysisAgent for it (kept in
    ``st.session_state``), and runs either one of the canned analyses or a
    free-form question through the agent. Any exception raised while loading
    or analysing is shown with ``st.error`` instead of propagating.
    """
    st.title("Data Analysis Assistant")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")

    # Initialize session state
    for key in ('data', 'agent', 'file_key'):
        if key not in st.session_state:
            st.session_state[key] = None

    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    # Canned analyses: selectbox label -> (spinner text, prompt for the agent).
    canned_analyses = {
        "Basic Statistics": (
            'Analyzing basic statistics...',
            "Use the analyze_basic_stats tool to analyze this dataset and "
            "provide insights about the numerical distributions.",
        ),
        "Correlation Analysis": (
            'Generating correlation matrix...',
            "Use the generate_correlation_matrix tool to analyze correlations "
            "and explain any strong relationships found.",
        ),
        "Categorical Analysis": (
            'Analyzing categorical columns...',
            "Use the analyze_categorical_columns tool to examine the "
            "categorical variables and explain the distributions.",
        ),
        "Feature Engineering": (
            'Generating feature suggestions...',
            "Use the suggest_features tool to recommend potential "
            "feature engineering steps for this dataset.",
        ),
    }

    try:
        if uploaded_file is not None:
            # The uploader hands back the file on every rerun; only reload the
            # CSV and rebuild the agent when a different file shows up.
            # NOTE(review): name + size is a heuristic identity for the upload.
            file_key = (uploaded_file.name, uploaded_file.size)
            if st.session_state['file_key'] != file_key or st.session_state['agent'] is None:
                with st.spinner('Loading and processing your data...'):
                    # Load the dataset (rewind in case the buffer was read before)
                    uploaded_file.seek(0)
                    data = pd.read_csv(uploaded_file)

                    # Initialize the agent with the dataset
                    agent = DataAnalysisAgent(
                        dataset=data,
                        tools=[analyze_basic_stats, generate_correlation_matrix,
                               analyze_categorical_columns, suggest_features],
                        model=GroqLLM(),
                        additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
                    )

                # Commit only after both steps succeeded, so session state
                # never holds a dataset without a matching agent.
                st.session_state['data'] = data
                st.session_state['agent'] = agent
                st.session_state['file_key'] = file_key

            data = st.session_state['data']
            st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
            st.subheader("Data Preview")
            st.dataframe(data.head())

        agent = st.session_state['agent']
        if st.session_state['data'] is not None and agent is not None:
            analysis_type = st.selectbox(
                "Choose analysis type",
                ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
                 "Feature Engineering", "Custom Question"]
            )

            if analysis_type == "Custom Question":
                question = st.text_input("What would you like to know about your data?")
                if question:
                    with st.spinner('Analyzing...'):
                        result = agent.run(question)
                    st.write(result)

            elif analysis_type in canned_analyses:
                spinner_text, prompt = canned_analyses[analysis_type]
                with st.spinner(spinner_text):
                    result = agent.run(prompt)

                # Only the correlation analysis may come back as an image.
                # Show it as one only when the result really is PNG data;
                # otherwise fall through to plain output.
                image_payload = None
                if analysis_type == "Correlation Analysis":
                    image_payload = _extract_png_base64(result)

                if image_payload is not None:
                    st.image(f"data:image/png;base64,{image_payload}")
                else:
                    st.write(result)

    except Exception as e:
        # Top-level UI boundary: surface the problem to the user.
        st.error(f"An error occurred: {str(e)}")
291
 
292
  if __name__ == "__main__":
293
+ main()