mgbam commited on
Commit
e207857
·
verified ·
1 Parent(s): 57776e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -321
app.py CHANGED
@@ -1,104 +1,64 @@
1
- # app.py
2
-
3
  import streamlit as st
4
  import numpy as np
5
- import pandas as pd
6
  from smolagents import CodeAgent, tool
7
  from typing import Union, List, Dict, Optional
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
 
10
  import os
11
  from groq import Groq
12
- from dataclasses import dataclass
13
  import tempfile
14
- import base64
15
- import io
16
 
17
- # ------------------------------
18
- # Language Model Interface
19
- # ------------------------------
 
20
  class GroqLLM:
21
- """Compatible LLM interface for smolagents CodeAgent"""
22
-
23
- def __init__(self, model_name: str = "llama-3.1-8B-Instant"):
24
- """
25
- Initialize the GroqLLM with the specified model.
26
-
27
- Args:
28
- model_name (str): The name of the language model to use.
29
- """
30
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
31
  self.model_name = model_name
32
-
33
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
34
- """
35
- Make the class callable as required by smolagents.
36
-
37
- Args:
38
- prompt (Union[str, dict, List[Dict]]): The input prompt for the language model.
39
-
40
- Returns:
41
- str: The generated response from the language model.
42
- """
43
  try:
44
- # Handle different prompt formats
45
  if isinstance(prompt, (dict, list)):
46
  prompt_str = str(prompt)
47
  else:
48
  prompt_str = str(prompt)
49
-
50
- # Create a properly formatted message
51
  completion = self.client.chat.completions.create(
52
  model=self.model_name,
53
- messages=[{
54
- "role": "user",
55
- "content": prompt_str
56
- }],
57
  temperature=0.7,
58
  max_tokens=1024,
59
- stream=False
60
  )
61
-
62
  return completion.choices[0].message.content if completion.choices else "Error: No response generated"
63
-
64
  except Exception as e:
65
- error_msg = f"Error generating response: {str(e)}"
66
- print(error_msg)
67
- return error_msg
68
 
69
- # ------------------------------
70
- # Data Analysis Agent
71
- # ------------------------------
72
  class DataAnalysisAgent(CodeAgent):
73
- """Extended CodeAgent with dataset awareness"""
74
-
75
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
76
- """
77
- Initialize the DataAnalysisAgent with the provided dataset.
78
-
79
- Args:
80
- dataset (pd.DataFrame): The dataset to analyze.
81
- *args: Variable length argument list.
82
- **kwargs: Arbitrary keyword arguments.
83
- """
84
  super().__init__(*args, **kwargs)
85
  self._dataset = dataset
86
-
87
  @property
88
  def dataset(self) -> pd.DataFrame:
89
  """Access the stored dataset."""
90
  return self._dataset
91
-
92
  def run(self, prompt: str) -> str:
93
- """
94
- Override run method to include dataset context.
95
-
96
- Args:
97
- prompt (str): The task prompt for analysis.
98
-
99
- Returns:
100
- str: The result of the analysis.
101
- """
102
  dataset_info = f"""
103
  Dataset Shape: {self.dataset.shape}
104
  Columns: {', '.join(self.dataset.columns)}
@@ -114,303 +74,125 @@ class DataAnalysisAgent(CodeAgent):
114
  """
115
  return super().run(enhanced_prompt)
116
 
117
- # ------------------------------
118
- # Tool Definitions
119
- # ------------------------------
120
 
 
 
 
121
  @tool
122
- def analyze_basic_stats(data: Optional[pd.DataFrame] = None) -> str:
123
- """
124
- Calculate basic statistical measures for numerical columns in the dataset.
125
-
126
- This function computes fundamental statistical metrics including mean, median,
127
- standard deviation, skewness, and counts of missing values for all numerical
128
- columns in the provided DataFrame.
129
-
130
- Args:
131
- data (Optional[pd.DataFrame], optional):
132
- A pandas DataFrame containing the dataset to analyze. The DataFrame
133
- should contain at least one numerical column for meaningful analysis.
134
-
135
- Returns:
136
- str: A string containing formatted basic statistics for each numerical column,
137
- including mean, median, standard deviation, skewness, and missing value counts.
138
- """
139
- # Access dataset from agent if no data provided
140
  if data is None:
141
  data = tool.agent.dataset
142
-
143
- stats = {}
144
- numeric_cols = data.select_dtypes(include=[np.number]).columns
145
-
146
- for col in numeric_cols:
147
- stats[col] = {
148
- 'mean': float(data[col].mean()),
149
- 'median': float(data[col].median()),
150
- 'std': float(data[col].std()),
151
- 'skew': float(data[col].skew()),
152
- 'missing': int(data[col].isnull().sum())
153
- }
154
-
155
- return str(stats)
156
 
157
  @tool
158
- def generate_correlation_matrix(data: Optional[pd.DataFrame] = None) -> str:
159
- """
160
- Generate a visual correlation matrix for numerical columns in the dataset.
161
-
162
- This function creates a heatmap visualization showing the correlations between
163
- all numerical columns in the dataset. The correlation values are displayed
164
- using a color-coded matrix for easy interpretation.
165
-
166
- Args:
167
- data (Optional[pd.DataFrame], optional):
168
- A pandas DataFrame containing the dataset to analyze. The DataFrame
169
- should contain at least two numerical columns for correlation analysis.
170
-
171
- Returns:
172
- str: A base64 encoded string representing the correlation matrix plot image,
173
- which can be displayed in a web interface or saved as an image file.
174
- """
175
- # Access dataset from agent if no data provided
176
  if data is None:
177
  data = tool.agent.dataset
178
-
179
  numeric_data = data.select_dtypes(include=[np.number])
180
-
181
  plt.figure(figsize=(10, 8))
182
- sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm')
183
- plt.title('Correlation Matrix')
184
-
185
  buf = io.BytesIO()
186
- plt.savefig(buf, format='png')
187
  plt.close()
188
  return base64.b64encode(buf.getvalue()).decode()
189
 
 
190
  @tool
191
- def analyze_categorical_columns(data: Optional[pd.DataFrame] = None) -> str:
192
- """
193
- Analyze categorical columns in the dataset for distribution and frequencies.
194
-
195
- This function examines categorical columns to identify unique values, top categories,
196
- and missing value counts, providing insights into the categorical data distribution.
197
-
198
- Args:
199
- data (Optional[pd.DataFrame], optional):
200
- A pandas DataFrame containing the dataset to analyze. The DataFrame
201
- should contain at least one categorical column for meaningful analysis.
202
-
203
- Returns:
204
- str: A string containing formatted analysis results for each categorical column,
205
- including unique value counts, top categories, and missing value counts.
206
- """
207
- # Access dataset from agent if no data provided
208
  if data is None:
209
  data = tool.agent.dataset
210
-
211
- categorical_cols = data.select_dtypes(include=['object', 'category']).columns
212
  analysis = {}
213
-
214
  for col in categorical_cols:
215
  analysis[col] = {
216
- 'unique_values': int(data[col].nunique()),
217
- 'top_categories': data[col].value_counts().head(5).to_dict(),
218
- 'missing': int(data[col].isnull().sum())
219
  }
220
-
221
  return str(analysis)
222
 
 
223
  @tool
224
- def suggest_features(data: Optional[pd.DataFrame] = None) -> str:
225
- """
226
- Suggest potential feature engineering steps based on data characteristics.
227
-
228
- This function analyzes the dataset's structure and statistical properties to
229
- recommend possible feature engineering steps that could improve model performance.
230
-
231
- Args:
232
- data (Optional[pd.DataFrame], optional):
233
- A pandas DataFrame containing the dataset to analyze. The DataFrame
234
- can contain both numerical and categorical columns.
235
-
236
- Returns:
237
- str: A string containing suggestions for feature engineering based on
238
- the characteristics of the input data.
239
- """
240
- # Access dataset from agent if no data provided
241
  if data is None:
242
  data = tool.agent.dataset
243
-
244
  suggestions = []
245
  numeric_cols = data.select_dtypes(include=[np.number]).columns
246
- categorical_cols = data.select_dtypes(include=['object', 'category']).columns
247
-
248
  if len(numeric_cols) >= 2:
249
  suggestions.append("Consider creating interaction terms between numerical features")
250
-
251
  if len(categorical_cols) > 0:
252
  suggestions.append("Consider one-hot encoding for categorical variables")
253
-
254
  for col in numeric_cols:
255
  if data[col].skew() > 1 or data[col].skew() < -1:
256
  suggestions.append(f"Consider log transformation for {col} due to skewness")
257
-
258
- return '\n'.join(suggestions)
259
 
260
- # ------------------------------
261
- # Report Exporting Function
262
- # ------------------------------
 
263
  def export_report(content: str, filename: str):
264
- """
265
- Export the given content as a PDF report.
266
-
267
- This function converts markdown content into a PDF file using pdfkit and provides
268
- a download button for users to obtain the report.
269
-
270
- Args:
271
- content (str): The markdown content to be included in the PDF report.
272
- filename (str): The desired name for the exported PDF file.
273
-
274
- Returns:
275
- None
276
- """
277
- # Save content to a temporary HTML file
278
- with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as tmp_file:
279
- tmp_file.write(content.encode('utf-8'))
280
- tmp_file_path = tmp_file.name
281
-
282
- # Define output PDF path
283
  pdf_path = f"{filename}.pdf"
284
-
285
- # Convert HTML to PDF using pdfkit
286
- try:
287
- # Configure pdfkit options for HuggingFace Spaces environment
288
- config = pdfkit.configuration()
289
- pdfkit.from_file(tmp_file_path, pdf_path, configuration=config)
290
- with open(pdf_path, "rb") as pdf_file:
291
- PDFbyte = pdf_file.read()
292
-
293
- # Provide download link
294
- st.download_button(label="📥 Download Report as PDF",
295
- data=PDFbyte,
296
- file_name=pdf_path,
297
- mime='application/octet-stream')
298
- except Exception as e:
299
- st.error(f"⚠️ Error exporting report: {str(e)}")
300
- finally:
301
- os.remove(tmp_file_path)
302
- if os.path.exists(pdf_path):
303
- os.remove(pdf_path)
304
 
305
- # ------------------------------
306
- # Main Application Function
307
- # ------------------------------
308
  def main():
309
- st.set_page_config(page_title="📊 Business Intelligence Assistant", layout="wide")
310
- st.title("📊 **Business Intelligence Assistant**")
311
  st.write("Upload your dataset and get automated analysis with natural language interaction.")
312
-
313
- # Initialize session state
314
- if 'data' not in st.session_state:
315
- st.session_state['data'] = None
316
- if 'agent' not in st.session_state:
317
- st.session_state['agent'] = None
318
- if 'report_content' not in st.session_state:
319
- st.session_state['report_content'] = ""
320
-
321
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
322
-
323
- try:
324
- if uploaded_file is not None:
325
- with st.spinner('🔄 Loading and processing your data...'):
326
- # Load the dataset
327
- data = pd.read_csv(uploaded_file)
328
- st.session_state['data'] = data
329
-
330
- # Initialize the agent with the dataset
331
- st.session_state['agent'] = DataAnalysisAgent(
332
- dataset=data,
333
- tools=[analyze_basic_stats, generate_correlation_matrix,
334
- analyze_categorical_columns, suggest_features],
335
- model=GroqLLM(),
336
- additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
337
- )
338
-
339
- st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
340
- st.subheader("🔍 **Data Preview**")
341
- st.dataframe(data.head())
342
-
343
- if st.session_state['data'] is not None:
344
- analysis_type = st.selectbox(
345
- "Choose analysis type",
346
- ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
347
- "Feature Engineering", "Custom Question"]
348
- )
349
-
350
- if analysis_type == "Basic Statistics":
351
- with st.spinner('Analyzing basic statistics...'):
352
- result = st.session_state['agent'].run(
353
- "Use the analyze_basic_stats tool to analyze this dataset and "
354
- "provide insights about the numerical distributions."
355
- )
356
- st.write(result)
357
- st.session_state['report_content'] += result + "\n\n"
358
-
359
- elif analysis_type == "Correlation Analysis":
360
- with st.spinner('Generating correlation matrix...'):
361
- result = st.session_state['agent'].run(
362
- "Use the generate_correlation_matrix tool to analyze correlations "
363
- "and explain any strong relationships found."
364
- )
365
- if isinstance(result, str) and 'base64' in result:
366
- # Extract base64 string and display the image
367
- image_data = f"data:image/png;base64,{result}"
368
- st.image(image_data, caption='Correlation Matrix')
369
- else:
370
- st.write(result)
371
- st.session_state['report_content'] += "### Correlation Analysis\n" + result + "\n\n"
372
-
373
- elif analysis_type == "Categorical Analysis":
374
- with st.spinner('Analyzing categorical columns...'):
375
- result = st.session_state['agent'].run(
376
- "Use the analyze_categorical_columns tool to examine the "
377
- "categorical variables and explain the distributions."
378
- )
379
- st.write(result)
380
- st.session_state['report_content'] += "### Categorical Analysis\n" + result + "\n\n"
381
-
382
- elif analysis_type == "Feature Engineering":
383
- with st.spinner('Generating feature suggestions...'):
384
- result = st.session_state['agent'].run(
385
- "Use the suggest_features tool to recommend potential "
386
- "feature engineering steps for this dataset."
387
- )
388
- st.write(result)
389
- st.session_state['report_content'] += "### Feature Engineering Suggestions\n" + result + "\n\n"
390
-
391
- elif analysis_type == "Custom Question":
392
- question = st.text_input("What would you like to know about your data?")
393
- if st.button("🔍 Get Answer"):
394
- if question:
395
- with st.spinner('Analyzing...'):
396
- result = st.session_state['agent'].run(question)
397
- st.write(result)
398
- st.session_state['report_content'] += f"### Custom Question: {question}\n{result}\n\n"
399
- else:
400
- st.warning("Please enter a question.")
401
-
402
- # Option to Export Report
403
- if st.session_state['report_content']:
404
- st.markdown("---")
405
- if st.button("📤 **Export Analysis Report**"):
406
- export_report(st.session_state['report_content'], "Business_Intelligence_Report")
407
- st.success("✅ Report exported successfully!")
408
-
409
- except Exception as e:
410
- st.error(f"⚠️ An error occurred: {str(e)}")
411
 
412
- # ------------------------------
413
- # Application Entry Point
414
- # ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  if __name__ == "__main__":
416
  main()
 
 
 
1
  import streamlit as st
2
  import numpy as np
3
+ import pandas as pd
4
  from smolagents import CodeAgent, tool
5
  from typing import Union, List, Dict, Optional
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
+ import base64
9
  import os
10
  from groq import Groq
11
+ import io
12
  import tempfile
13
+ import pdfkit
 
14
 
15
+
16
+ # --------------------------------------
17
+ # LLM Interface
18
+ # --------------------------------------
19
  class GroqLLM:
20
+ """Compatible LLM interface for smolagents CodeAgent."""
21
+
22
+ def __init__(self, model_name="llama-3.1-8B-Instant"):
 
 
 
 
 
 
23
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
24
  self.model_name = model_name
25
+
26
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
27
+ """Make the class callable as required by smolagents."""
 
 
 
 
 
 
 
 
28
  try:
 
29
  if isinstance(prompt, (dict, list)):
30
  prompt_str = str(prompt)
31
  else:
32
  prompt_str = str(prompt)
 
 
33
  completion = self.client.chat.completions.create(
34
  model=self.model_name,
35
+ messages=[{"role": "user", "content": prompt_str}],
 
 
 
36
  temperature=0.7,
37
  max_tokens=1024,
38
+ stream=False,
39
  )
 
40
  return completion.choices[0].message.content if completion.choices else "Error: No response generated"
 
41
  except Exception as e:
42
+ return f"Error generating response: {str(e)}"
43
+
 
44
 
45
+ # --------------------------------------
46
+ # Dataset-Aware Agent
47
+ # --------------------------------------
48
  class DataAnalysisAgent(CodeAgent):
49
+ """Extended CodeAgent with dataset awareness."""
50
+
51
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
 
 
 
 
 
 
 
 
52
  super().__init__(*args, **kwargs)
53
  self._dataset = dataset
54
+
55
  @property
56
  def dataset(self) -> pd.DataFrame:
57
  """Access the stored dataset."""
58
  return self._dataset
59
+
60
  def run(self, prompt: str) -> str:
61
+ """Override run method to include dataset context."""
 
 
 
 
 
 
 
 
62
  dataset_info = f"""
63
  Dataset Shape: {self.dataset.shape}
64
  Columns: {', '.join(self.dataset.columns)}
 
74
  """
75
  return super().run(enhanced_prompt)
76
 
 
 
 
77
 
78
+ # --------------------------------------
79
+ # Tools
80
+ # --------------------------------------
81
  @tool
82
+ def analyze_basic_stats(data: pd.DataFrame) -> str:
83
+ """Calculate basic statistical measures for numerical columns."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if data is None:
85
  data = tool.agent.dataset
86
+ stats = data.describe().to_markdown()
87
+ return f"### Basic Statistics\n{stats}"
88
+
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  @tool
91
+ def generate_correlation_matrix(data: pd.DataFrame) -> str:
92
+ """Generate a visual correlation matrix for numerical columns."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if data is None:
94
  data = tool.agent.dataset
 
95
  numeric_data = data.select_dtypes(include=[np.number])
 
96
  plt.figure(figsize=(10, 8))
97
+ sns.heatmap(numeric_data.corr(), annot=True, cmap="coolwarm")
98
+ plt.title("Correlation Matrix")
 
99
  buf = io.BytesIO()
100
+ plt.savefig(buf, format="png")
101
  plt.close()
102
  return base64.b64encode(buf.getvalue()).decode()
103
 
104
+
105
  @tool
106
+ def analyze_categorical_columns(data: pd.DataFrame) -> str:
107
+ """Analyze categorical columns in the dataset."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  if data is None:
109
  data = tool.agent.dataset
110
+ categorical_cols = data.select_dtypes(include=["object", "category"]).columns
 
111
  analysis = {}
 
112
  for col in categorical_cols:
113
  analysis[col] = {
114
+ "unique_values": data[col].nunique(),
115
+ "top_categories": data[col].value_counts().head(5).to_dict(),
116
+ "missing": data[col].isnull().sum(),
117
  }
 
118
  return str(analysis)
119
 
120
+
121
  @tool
122
+ def suggest_features(data: pd.DataFrame) -> str:
123
+ """Suggest potential feature engineering steps."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  if data is None:
125
  data = tool.agent.dataset
 
126
  suggestions = []
127
  numeric_cols = data.select_dtypes(include=[np.number]).columns
128
+ categorical_cols = data.select_dtypes(include=["object", "category"]).columns
 
129
  if len(numeric_cols) >= 2:
130
  suggestions.append("Consider creating interaction terms between numerical features")
 
131
  if len(categorical_cols) > 0:
132
  suggestions.append("Consider one-hot encoding for categorical variables")
 
133
  for col in numeric_cols:
134
  if data[col].skew() > 1 or data[col].skew() < -1:
135
  suggestions.append(f"Consider log transformation for {col} due to skewness")
136
+ return "\n".join(suggestions)
 
137
 
138
+
139
+ # --------------------------------------
140
+ # Export Report
141
+ # --------------------------------------
142
  def export_report(content: str, filename: str):
143
+ """Export analysis report as a PDF."""
144
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp:
145
+ tmp.write(content.encode("utf-8"))
146
+ tmp_path = tmp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  pdf_path = f"{filename}.pdf"
148
+ pdfkit.from_file(tmp_path, pdf_path)
149
+ with open(pdf_path, "rb") as pdf_file:
150
+ st.download_button(
151
+ label="Download Report as PDF",
152
+ data=pdf_file.read(),
153
+ file_name=pdf_path,
154
+ mime="application/pdf",
155
+ )
156
+ os.remove(tmp_path)
157
+ os.remove(pdf_path)
158
+
 
 
 
 
 
 
 
 
 
159
 
160
+ # --------------------------------------
161
+ # Streamlit App
162
+ # --------------------------------------
163
  def main():
164
+ st.title("Data Analysis Assistant")
 
165
  st.write("Upload your dataset and get automated analysis with natural language interaction.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ if "data" not in st.session_state:
168
+ st.session_state["data"] = None
169
+
170
+ uploaded_file = st.file_uploader("Upload CSV File", type="csv")
171
+ if uploaded_file:
172
+ st.session_state["data"] = pd.read_csv(uploaded_file)
173
+ st.success(f"Loaded dataset with {st.session_state['data'].shape[0]} rows and {st.session_state['data'].shape[1]} columns.")
174
+ st.dataframe(st.session_state["data"].head())
175
+
176
+ agent = DataAnalysisAgent(
177
+ dataset=st.session_state["data"],
178
+ tools=[analyze_basic_stats, generate_correlation_matrix, analyze_categorical_columns, suggest_features],
179
+ model=GroqLLM(),
180
+ )
181
+
182
+ analysis_type = st.selectbox("Choose Analysis Type", ["Basic Statistics", "Correlation Analysis", "Categorical Analysis", "Feature Suggestions"])
183
+ if analysis_type == "Basic Statistics":
184
+ st.markdown(agent.run("Analyze basic statistics."))
185
+ elif analysis_type == "Correlation Analysis":
186
+ result = agent.run("Generate a correlation matrix.")
187
+ st.image(f"data:image/png;base64,{result}")
188
+ elif analysis_type == "Categorical Analysis":
189
+ st.markdown(agent.run("Analyze categorical columns."))
190
+ elif analysis_type == "Feature Suggestions":
191
+ st.markdown(agent.run("Suggest feature engineering ideas."))
192
+
193
+ if st.button("Export Report"):
194
+ export_report(agent.run("Generate full report."), "data_analysis_report")
195
+
196
+
197
  if __name__ == "__main__":
198
  main()