girishwangikar committed
Commit 7ef0257 · verified · 1 Parent(s): 4ed3667

Update app.py

Files changed (1)
  1. app.py +121 -101
app.py CHANGED
@@ -14,7 +14,7 @@ import io
 
 class GroqLLM:
     """Compatible LLM interface for smolagents CodeAgent"""
-    def __init__(self, model_name="llama-3.1-8B-Instant"):
+    def __init__(self, model_name="llama2-70b-3.5"):
         self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
         self.model_name = model_name
 
@@ -23,7 +23,6 @@ class GroqLLM:
         try:
             # Handle different prompt formats
             if isinstance(prompt, (dict, list)):
-                # If prompt is a dictionary or list, convert it to a string representation
                 prompt_str = str(prompt)
             else:
                 prompt_str = str(prompt)
@@ -40,17 +39,41 @@ class GroqLLM:
                 stream=False
             )
 
-            # Extract and return the response content
-            if completion.choices and len(completion.choices) > 0:
-                return completion.choices[0].message.content
-            return "Error: No response generated"
+            return completion.choices[0].message.content if completion.choices else "Error: No response generated"
 
         except Exception as e:
-            # Provide more detailed error handling
             error_msg = f"Error generating response: {str(e)}"
-            print(error_msg)  # Log the error
+            print(error_msg)
             return error_msg
 
+class DataAnalysisAgent(CodeAgent):
+    """Extended CodeAgent with dataset awareness"""
+    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._dataset = dataset
+
+    @property
+    def dataset(self) -> pd.DataFrame:
+        """Access the stored dataset"""
+        return self._dataset
+
+    def run(self, prompt: str) -> str:
+        """Override run method to include dataset context"""
+        dataset_info = f"""
+        Dataset Shape: {self.dataset.shape}
+        Columns: {', '.join(self.dataset.columns)}
+        Data Types: {self.dataset.dtypes.to_dict()}
+        """
+        enhanced_prompt = f"""
+        Analyze the following dataset:
+        {dataset_info}
+
+        Task: {prompt}
+
+        Use the provided tools to analyze this specific dataset and return detailed results.
+        """
+        return super().run(enhanced_prompt)
+
 @tool
 def analyze_basic_stats(data: pd.DataFrame) -> str:
     """Calculate basic statistical measures for numerical columns in the dataset.
@@ -67,16 +90,20 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
         str: A string containing formatted basic statistics for each numerical column,
             including mean, median, standard deviation, skewness, and missing value counts.
     """
+    # Access dataset from agent if no data provided
+    if data is None:
+        data = tool.agent.dataset
+
     stats = {}
     numeric_cols = data.select_dtypes(include=[np.number]).columns
 
     for col in numeric_cols:
         stats[col] = {
-            'mean': data[col].mean(),
-            'median': data[col].median(),
-            'std': data[col].std(),
-            'skew': data[col].skew(),
-            'missing': data[col].isnull().sum()
+            'mean': float(data[col].mean()),
+            'median': float(data[col].median()),
+            'std': float(data[col].std()),
+            'skew': float(data[col].skew()),
+            'missing': int(data[col].isnull().sum())
         }
 
     return str(stats)
@@ -97,6 +124,10 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
         str: A base64 encoded string representing the correlation matrix plot image,
             which can be displayed in a web interface or saved as an image file.
     """
+    # Access dataset from agent if no data provided
+    if data is None:
+        data = tool.agent.dataset
+
     numeric_data = data.select_dtypes(include=[np.number])
 
     plt.figure(figsize=(10, 8))
@@ -117,21 +148,24 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
 
     Args:
         data: A pandas DataFrame containing the dataset to analyze. The DataFrame
-            should contain at least one categorical column (object or category dtype)
-            for meaningful analysis.
+            should contain at least one categorical column for meaningful analysis.
 
     Returns:
         str: A string containing formatted analysis results for each categorical column,
             including unique value counts, top categories, and missing value counts.
     """
+    # Access dataset from agent if no data provided
+    if data is None:
+        data = tool.agent.dataset
+
     categorical_cols = data.select_dtypes(include=['object', 'category']).columns
     analysis = {}
 
     for col in categorical_cols:
         analysis[col] = {
-            'unique_values': data[col].nunique(),
+            'unique_values': int(data[col].nunique()),
             'top_categories': data[col].value_counts().head(5).to_dict(),
-            'missing': data[col].isnull().sum()
+            'missing': int(data[col].isnull().sum())
         }
 
     return str(analysis)
@@ -145,13 +179,16 @@ def suggest_features(data: pd.DataFrame) -> str:
 
     Args:
         data: A pandas DataFrame containing the dataset to analyze. The DataFrame
-            can contain both numerical and categorical columns for feature
-            engineering suggestions.
+            can contain both numerical and categorical columns.
 
     Returns:
-        str: A string containing line-separated suggestions for feature engineering,
-            based on the characteristics of the input data.
+        str: A string containing suggestions for feature engineering based on
+            the characteristics of the input data.
     """
+    # Access dataset from agent if no data provided
+    if data is None:
+        data = tool.agent.dataset
+
     suggestions = []
     numeric_cols = data.select_dtypes(include=[np.number]).columns
     categorical_cols = data.select_dtypes(include=['object', 'category']).columns
@@ -168,106 +205,89 @@ def suggest_features(data: pd.DataFrame) -> str:
 
     return '\n'.join(suggestions)
 
-# Initialize session state at the start
-if 'data' not in st.session_state:
-    st.session_state['data'] = None
-if 'file_uploaded' not in st.session_state:
-    st.session_state['file_uploaded'] = False
-if 'processing' not in st.session_state:
-    st.session_state['processing'] = False
-if 'agent' not in st.session_state:
-    st.session_state['agent'] = None
-
 def main():
     st.title("Data Analysis Assistant")
     st.write("Upload your dataset and get automated analysis with natural language interaction.")
 
-    # File uploader with error handling
+    # Initialize session state
+    if 'data' not in st.session_state:
+        st.session_state['data'] = None
+    if 'agent' not in st.session_state:
+        st.session_state['agent'] = None
+
     uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
 
     try:
-        if uploaded_file is not None and not st.session_state['file_uploaded']:
-            # Show loading spinner while processing the file
+        if uploaded_file is not None:
             with st.spinner('Loading and processing your data...'):
-                try:
-                    data = pd.read_csv(uploaded_file)
-                    st.session_state['data'] = data
-                    st.session_state['file_uploaded'] = True
-
-                    # Initialize agent with GroqLLM
-                    st.session_state['agent'] = CodeAgent(
-                        tools=[analyze_basic_stats, generate_correlation_matrix,
-                               analyze_categorical_columns, suggest_features],
-                        model=GroqLLM(),
-                        additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
-                    )
-
-                    # Show success message
-                    st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
-
-                    # Display data preview
-                    st.subheader("Data Preview")
-                    st.dataframe(data.head())
-
-                except Exception as e:
-                    st.error(f"Error loading file: {str(e)}")
-                    st.session_state['file_uploaded'] = False
-                    return
+                # Load the dataset
+                data = pd.read_csv(uploaded_file)
+                st.session_state['data'] = data
+
+                # Initialize the agent with the dataset
+                st.session_state['agent'] = DataAnalysisAgent(
+                    dataset=data,
+                    tools=[analyze_basic_stats, generate_correlation_matrix,
+                           analyze_categorical_columns, suggest_features],
+                    model=GroqLLM(),
+                    additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
+                )
+
+                st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
+                st.subheader("Data Preview")
+                st.dataframe(data.head())
 
-        # Only show analysis options if data is loaded
-        if st.session_state['file_uploaded'] and st.session_state['data'] is not None:
-            # Analysis options
+        if st.session_state['data'] is not None:
             analysis_type = st.selectbox(
                 "Choose analysis type",
                 ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
                  "Feature Engineering", "Custom Question"]
             )
 
-            # Process analysis with loading indicators
-            if analysis_type:
-                with st.spinner(f'Performing {analysis_type.lower()}...'):
-                    if analysis_type == "Basic Statistics":
-                        result = st.session_state['agent'].run(
-                            f"Analyze and explain the basic statistics of this dataset. "
-                            f"Dataset info: {st.session_state['data'].info()}\n"
-                            f"Use the analyze_basic_stats tool and provide natural language explanations."
-                        )
-                        st.write(result)
-
-                    elif analysis_type == "Correlation Analysis":
-                        correlation_plot = st.session_state['agent'].run(
-                            "Generate and explain correlations between numerical variables. "
-                            "Use the generate_correlation_matrix tool."
-                        )
-                        if correlation_plot:
-                            st.image(f"data:image/png;base64,{correlation_plot}")
-
-                    elif analysis_type == "Categorical Analysis":
-                        result = st.session_state['agent'].run(
-                            "Analyze categorical variables in the dataset. "
-                            "Use the analyze_categorical_columns tool and explain the findings."
-                        )
+            if analysis_type == "Basic Statistics":
+                with st.spinner('Analyzing basic statistics...'):
+                    result = st.session_state['agent'].run(
+                        "Use the analyze_basic_stats tool to analyze this dataset and "
+                        "provide insights about the numerical distributions."
+                    )
+                    st.write(result)
+
+            elif analysis_type == "Correlation Analysis":
+                with st.spinner('Generating correlation matrix...'):
+                    result = st.session_state['agent'].run(
+                        "Use the generate_correlation_matrix tool to analyze correlations "
+                        "and explain any strong relationships found."
+                    )
+                    if isinstance(result, str) and result.startswith('data:image') or ',' in result:
+                        st.image(f"data:image/png;base64,{result.split(',')[-1]}")
+                    else:
                         st.write(result)
-
-                    elif analysis_type == "Feature Engineering":
-                        result = st.session_state['agent'].run(
-                            "Suggest potential feature engineering steps for this dataset. "
-                            "Use the suggest_features tool and explain your suggestions."
-                        )
+
+            elif analysis_type == "Categorical Analysis":
+                with st.spinner('Analyzing categorical columns...'):
+                    result = st.session_state['agent'].run(
+                        "Use the analyze_categorical_columns tool to examine the "
+                        "categorical variables and explain the distributions."
+                    )
+                    st.write(result)
+
+            elif analysis_type == "Feature Engineering":
+                with st.spinner('Generating feature suggestions...'):
+                    result = st.session_state['agent'].run(
+                        "Use the suggest_features tool to recommend potential "
+                        "feature engineering steps for this dataset."
+                    )
+                    st.write(result)
+
+            elif analysis_type == "Custom Question":
+                question = st.text_input("What would you like to know about your data?")
+                if question:
+                    with st.spinner('Analyzing...'):
+                        result = st.session_state['agent'].run(question)
                         st.write(result)
 
-                    elif analysis_type == "Custom Question":
-                        question = st.text_input("What would you like to know about your data?")
-                        if question:
-                            result = st.session_state['agent'].run(
-                                f"Answer this question about the dataset: {question}\n"
-                                f"Use appropriate tools to analyze and explain."
-                            )
-                            st.write(result)
-
     except Exception as e:
         st.error(f"An error occurred: {str(e)}")
-        st.session_state['file_uploaded'] = False
 
 if __name__ == "__main__":
     main()
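
For reference, a minimal sketch of how the DataAnalysisAgent added in this commit is wired together outside Streamlit, assuming the classes and @tool functions from app.py are importable and GROQ_API_KEY is set in the environment; "example.csv" is a hypothetical file name, not part of the commit:

# Minimal usage sketch (illustrative only, not part of app.py)
import pandas as pd

df = pd.read_csv("example.csv")  # hypothetical input file
agent = DataAnalysisAgent(
    dataset=df,
    tools=[analyze_basic_stats, generate_correlation_matrix,
           analyze_categorical_columns, suggest_features],
    model=GroqLLM(),
    additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"],
)
# The overridden run() prepends the dataset's shape, columns, and dtypes
# before delegating to CodeAgent.run(), so the prompt only states the task.
print(agent.run("Use the analyze_basic_stats tool and summarise the numeric columns."))

The design choice recorded in the diff is that the dataset travels with the agent rather than with each prompt: the tools fall back to the agent's stored DataFrame when no data argument is provided, and the prompt strings in main() no longer embed dataset details themselves.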