mgbam committed
Commit 102a9b5 · verified · Parent: 06300b8

Update app.py

Files changed (1): app.py (+347 -167)
app.py CHANGED
@@ -1,216 +1,396 @@
 import streamlit as st
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 import seaborn as sns
 import os
 import base64
 import io
-from groq import Groq
-from langchain.tools import tool
-from langchain.agents import AgentType, initialize_agent
-from langchain.chains import LLMChain
-from langchain.prompts import PromptTemplate
-from typing import Optional, Dict, List
-
-# Initialize Groq Client
-client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
-
-class GroqAnalyst:
-    """Advanced AI Researcher & Data Analyst using Groq"""
-    def __init__(self, model_name="mixtral-8x7b-32768"):
         self.model_name = model_name
-        self.system_prompt = """
-        You are an expert AI research assistant and data scientist.
-        Provide detailed, technical analysis with professional visualizations.
-        """
 
-    def analyze(self, prompt: str, data: pd.DataFrame) -> str:
-        """Execute complex data analysis using Groq"""
         try:
-            dataset_info = f"""
-            Dataset Shape: {data.shape}
-            Columns: {', '.join(data.columns)}
-            Data Types: {data.dtypes.to_dict()}
-            Sample Data: {data.head(3).to_dict()}
-            """
-
-            completion = client.chat.completions.create(
-                messages=[
-                    {"role": "system", "content": self.system_prompt},
-                    {"role": "user", "content": f"{dataset_info}\n\nTask: {prompt}"}
-                ],
                 model=self.model_name,
-                temperature=0.3,
-                max_tokens=4096,
-                stream=False
             )
-
-            return completion.choices[0].message.content
-
         except Exception as e:
-            return f"Analysis Error: {str(e)}"
 
 @tool
-def advanced_eda(data: pd.DataFrame) -> Dict:
-    """Perform comprehensive exploratory data analysis.
-
     Args:
-        data (pd.DataFrame): Input dataset for analysis
-
     Returns:
-        Dict: Contains statistical summary, missing values, and data quality report
     """
-    analysis = {
-        "statistical_summary": data.describe().to_dict(),
-        "missing_values": data.isnull().sum().to_dict(),
-        "data_quality": {
-            "duplicates": data.duplicated().sum(),
-            "zero_values": (data == 0).sum().to_dict()
         }
-    }
-    return analysis
 
 @tool
-def visualize_distributions(data: pd.DataFrame, columns: List[str]) -> str:
-    """Generate distribution plots for specified numerical columns.
-
     Args:
-        data (pd.DataFrame): Input dataset
-        columns (List[str]): List of numerical columns to visualize
-
     Returns:
-        str: Base64 encoded image of the visualization
     """
-    plt.figure(figsize=(12, 6))
-    for i, col in enumerate(columns, 1):
-        plt.subplot(1, len(columns), i)
-        sns.histplot(data[col], kde=True)
-        plt.title(f'Distribution of {col}')
-    plt.tight_layout()
-
     buf = io.BytesIO()
-    plt.savefig(buf, format='png')
     plt.close()
     return base64.b64encode(buf.getvalue()).decode()
 
 @tool
-def temporal_analysis(data: pd.DataFrame, time_col: str, value_col: str) -> str:
-    """Analyze time series data and generate trend visualization.
-
     Args:
-        data (pd.DataFrame): Dataset containing time series
-        time_col (str): Name of timestamp column
-        value_col (str): Name of value column to analyze
-
     Returns:
-        str: Base64 encoded image of time series plot
     """
-    plt.figure(figsize=(12, 6))
-    data[time_col] = pd.to_datetime(data[time_col])
-    data.set_index(time_col)[value_col].plot()
-    plt.title(f'Temporal Trend of {value_col}')
-    plt.xlabel('Date')
-    plt.ylabel('Value')
-
-    buf = io.BytesIO()
-    plt.savefig(buf, format='png')
-    plt.close()
-    return base64.b64encode(buf.getvalue()).decode()
 
 @tool
-def hypothesis_testing(data: pd.DataFrame, group_col: str, value_col: str) -> Dict:
-    """Perform statistical hypothesis testing between groups.
-
     Args:
-        data (pd.DataFrame): Input dataset
-        group_col (str): Categorical column defining groups
-        value_col (str): Numerical column to compare
-
     Returns:
-        Dict: Contains test results, p-value, and conclusion
     """
-    from scipy.stats import ttest_ind
-
-    groups = data[group_col].unique()
-    if len(groups) != 2:
-        return {"error": "Hypothesis testing requires exactly two groups"}
 
-    group1 = data[data[group_col] == groups[0]][value_col]
-    group2 = data[data[group_col] == groups[1]][value_col]
 
-    t_stat, p_value = ttest_ind(group1, group2)
 
-    return {
-        "t_statistic": t_stat,
-        "p_value": p_value,
-        "conclusion": "Significant difference" if p_value < 0.05 else "No significant difference"
-    }
 
 def main():
-    st.title("🔬 AI Research Assistant with Groq")
-    st.markdown("Advanced data analysis powered by Groq's accelerated computing")
-
     # Initialize session state
-    if 'data' not in st.session_state:
-        st.session_state.data = None
-    if 'analyst' not in st.session_state:
-        st.session_state.analyst = GroqAnalyst()
-
-    # File upload section
-    with st.sidebar:
-        st.header("Data Upload")
-        uploaded_file = st.file_uploader("Upload dataset (CSV)", type="csv")
-        if uploaded_file:
-            with st.spinner("Analyzing dataset..."):
-                st.session_state.data = pd.read_csv(uploaded_file)
-                st.success(f"Loaded {len(st.session_state.data)} records")
 
-    # Main analysis interface
-    if st.session_state.data is not None:
-        st.subheader("Dataset Overview")
-        st.dataframe(st.session_state.data.head(), use_container_width=True)
-
-        analysis_type = st.selectbox("Select Analysis Type", [
-            "Exploratory Data Analysis",
-            "Temporal Analysis",
-            "Statistical Testing",
-            "Custom Research Query"
-        ])
-
-        if analysis_type == "Exploratory Data Analysis":
-            with st.expander("Advanced EDA"):
-                eda_result = advanced_eda(st.session_state.data)
-                st.json(eda_result)
 
-            num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
-            if num_cols:
-                selected_cols = st.multiselect("Select columns for distribution analysis", num_cols)
-                if selected_cols:
-                    img_data = visualize_distributions(st.session_state.data, selected_cols)
-                    st.image(f"data:image/png;base64,{img_data}")
-
-        elif analysis_type == "Temporal Analysis":
-            time_col = st.selectbox("Select time column", st.session_state.data.columns)
-            value_col = st.selectbox("Select value column", st.session_state.data.select_dtypes(include=np.number).columns)
-            if time_col and value_col:
-                img_data = temporal_analysis(st.session_state.data, time_col, value_col)
-                st.image(f"data:image/png;base64,{img_data}")
-
-        elif analysis_type == "Statistical Testing":
-            group_col = st.selectbox("Select group column", st.session_state.data.select_dtypes(include='object').columns)
-            value_col = st.selectbox("Select metric to compare", st.session_state.data.select_dtypes(include=np.number).columns)
-            if group_col and value_col:
-                test_result = hypothesis_testing(st.session_state.data, group_col, value_col)
-                st.json(test_result)
-
-        elif analysis_type == "Custom Research Query":
-            research_query = st.text_area("Enter your research question:")
-            if research_query:
-                with st.spinner("Conducting advanced analysis..."):
-                    result = st.session_state.analyst.analyze(research_query, st.session_state.data)
-                    st.markdown("### Research Findings")
-                    st.markdown(result)
 
 if __name__ == "__main__":
     main()
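For context on the interface this commit removes: GroqAnalyst.analyze() embeds the DataFrame's shape, columns, dtypes, and a three-row sample into a Groq chat prompt and returns the model's reply. A minimal sketch of driving it outside Streamlit (the toy DataFrame is illustrative and GROQ_API_KEY is assumed to be exported; this snippet is not part of the commit):

    import pandas as pd

    # Toy data; any DataFrame works, since analyze() summarizes it for the prompt.
    df = pd.DataFrame({"group": ["a", "a", "b", "b"], "value": [1.0, 2.5, 3.1, 4.0]})

    analyst = GroqAnalyst(model_name="mixtral-8x7b-32768")
    print(analyst.analyze("Compare the two groups and summarize any difference.", df))

The rewritten app.py follows.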
 
 import streamlit as st
 import numpy as np
 import pandas as pd
+from smolagents import CodeAgent, tool
+from typing import Union, List, Dict, Optional
 import matplotlib.pyplot as plt
 import seaborn as sns
 import os
+from groq import Groq
+from dataclasses import dataclass
+import tempfile
 import base64
 import io
+import json
+from streamlit_ace import st_ace
+from contextlib import contextmanager
+
+
+class GroqLLM:
+    """Compatible LLM interface for smolagents CodeAgent"""
+
+    def __init__(self, model_name="llama-3.1-8b-instant"):  # Groq model ids are lowercase
+        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
         self.model_name = model_name
 
+    def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
+        """Make the class callable as required by smolagents"""
         try:
+            # str, dict, and list-of-messages prompts are all coerced to a string
+            prompt_str = str(prompt)
+
+            # Create a properly formatted message
+            completion = self.client.chat.completions.create(
                 model=self.model_name,
+                messages=[{"role": "user", "content": prompt_str}],
+                temperature=0.7,
+                max_tokens=1024,
+                stream=True,  # Enable streaming
             )
+
+            full_response = ""
+            for chunk in completion:
+                if chunk.choices[0].delta.content is not None:
+                    full_response += chunk.choices[0].delta.content
+            return full_response
         except Exception as e:
+            error_msg = f"Error generating response: {str(e)}"
+            print(error_msg)
+            return error_msg
+
+
+class DataAnalysisAgent(CodeAgent):
+    """Extended CodeAgent with dataset awareness"""
+
+    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._dataset = dataset
+
+    @property
+    def dataset(self) -> pd.DataFrame:
+        """Access the stored dataset"""
+        return self._dataset
+
+    def run(self, prompt: str, **kwargs) -> str:
+        """Override run method to include dataset context"""
+        dataset_info = f"""
+        Dataset Shape: {self.dataset.shape}
+        Columns: {', '.join(self.dataset.columns)}
+        Data Types: {self.dataset.dtypes.to_dict()}
+        """
+        enhanced_prompt = f"""
+        Analyze the following dataset:
+        {dataset_info}

+        Task: {prompt}
+
+        Use the provided tools to analyze this specific dataset and return detailed results.
+        """
+        return super().run(enhanced_prompt, data=self.dataset, **kwargs)  # Pass data as argument
+
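The two classes above can be smoke-tested before any UI is attached: GroqLLM is a plain callable, and DataAnalysisAgent.run() prepends the dataset's shape, columns, and dtypes to whatever task it receives. A minimal sketch, assuming GROQ_API_KEY is exported and that the installed smolagents version accepts a bare callable as the model, as this commit does (the prompt text and DataFrame are illustrative only):

    import pandas as pd

    llm = GroqLLM()  # defaults to Groq's llama-3.1-8b-instant
    print(llm("Reply with the single word: ready"))

    df = pd.DataFrame({"price": [10, 12, 9], "units": [100, 80, 120]})
    agent = DataAnalysisAgent(dataset=df, tools=[], model=llm)
    # run() embeds the shape, columns, and dtypes above into the task prompt.
    print(agent.run("Which column looks like a quantity sold?"))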
 
 @tool
+def analyze_basic_stats(data: pd.DataFrame) -> str:
+    """Calculate basic statistical measures for numerical columns in the dataset.
+
+    This function computes fundamental statistical metrics including mean, median,
+    standard deviation, skewness, and counts of missing values for all numerical
+    columns in the provided DataFrame.
+
     Args:
+        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
+            should contain at least one numerical column for meaningful analysis.
+
     Returns:
+        str: A string containing formatted basic statistics for each numerical column,
+            including mean, median, standard deviation, skewness, and missing value counts.
     """
+    stats = {}
+    numeric_cols = data.select_dtypes(include=[np.number]).columns
+
+    for col in numeric_cols:
+        stats[col] = {
+            "mean": float(data[col].mean()),
+            "median": float(data[col].median()),
+            "std": float(data[col].std()),
+            "skew": float(data[col].skew()),
+            "missing": int(data[col].isnull().sum()),
         }
+
+    return str(stats)
+
 
 @tool
+def generate_correlation_matrix(data: pd.DataFrame) -> str:
+    """Generate a visual correlation matrix for numerical columns in the dataset.
+
+    This function creates a heatmap visualization showing the correlations between
+    all numerical columns in the dataset. The correlation values are displayed
+    using a color-coded matrix for easy interpretation.
+
     Args:
+        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
+            should contain at least two numerical columns for correlation analysis.
+
     Returns:
+        str: A base64 encoded string representing the correlation matrix plot image,
+            which can be displayed in a web interface or saved as an image file.
     """
+    numeric_data = data.select_dtypes(include=[np.number])
+
+    plt.figure(figsize=(10, 8))
+    sns.heatmap(numeric_data.corr(), annot=True, cmap="coolwarm")
+    plt.title("Correlation Matrix")
+
     buf = io.BytesIO()
+    plt.savefig(buf, format="png")
     plt.close()
     return base64.b64encode(buf.getvalue()).decode()
 
+
 @tool
+def analyze_categorical_columns(data: pd.DataFrame) -> str:
+    """Analyze categorical columns in the dataset for distribution and frequencies.
+
+    This function examines categorical columns to identify unique values, top categories,
+    and missing value counts, providing insights into the categorical data distribution.
+
     Args:
+        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
+            should contain at least one categorical column for meaningful analysis.
+
     Returns:
+        str: A string containing formatted analysis results for each categorical column,
+            including unique value counts, top categories, and missing value counts.
     """
+    categorical_cols = data.select_dtypes(include=["object", "category"]).columns
+    analysis = {}
+
+    for col in categorical_cols:
+        analysis[col] = {
+            "unique_values": int(data[col].nunique()),
+            "top_categories": data[col].value_counts().head(5).to_dict(),
+            "missing": int(data[col].isnull().sum()),
+        }
+
+    return str(analysis)
+
 
 @tool
+def suggest_features(data: pd.DataFrame) -> str:
+    """Suggest potential feature engineering steps based on data characteristics.
+
+    This function analyzes the dataset's structure and statistical properties to
+    recommend possible feature engineering steps that could improve model performance.
+
     Args:
+        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
+            can contain both numerical and categorical columns.
+
     Returns:
+        str: A string containing suggestions for feature engineering based on
+            the characteristics of the input data.
     """
+    suggestions = []
+    numeric_cols = data.select_dtypes(include=[np.number]).columns
+    categorical_cols = data.select_dtypes(include=["object", "category"]).columns
+
+    if len(numeric_cols) >= 2:
+        suggestions.append("Consider creating interaction terms between numerical features")
+
+    if len(categorical_cols) > 0:
+        suggestions.append("Consider one-hot encoding for categorical variables")
+
+    for col in numeric_cols:
+        if data[col].skew() > 1 or data[col].skew() < -1:
+            suggestions.append(f"Consider log transformation for {col} due to skewness")
+
+    return "\n".join(suggestions)
+
+
+@tool
+def describe_data(data: pd.DataFrame) -> str:
+    """Generates a comprehensive descriptive statistics report for the entire DataFrame.
+
+    Args:
+        data: A pandas DataFrame containing the dataset to analyze.
+
+    Returns:
+        str: String representation of the descriptive statistics.
+    """
+    return data.describe(include='all').to_string()
+
+
+@tool
+def execute_code(code_string: str, data: pd.DataFrame) -> str:
+    """Executes Python code and returns results as a string.
+
+    Args:
+        code_string (str): Python code to execute.
+        data (pd.DataFrame): The dataframe to use in the code.
+
+    Returns:
+        str: The result of executing the code or an error message.
+    """
+    try:
+        # This dictionary will be available to the executed code
+        local_vars = {"data": data, "pd": pd, "np": np, "plt": plt, "sns": sns}
+
+        # Execute the code with the passed variables
+        exec(code_string, local_vars)
+
+        if "result" in local_vars:
+            if isinstance(local_vars["result"], (pd.DataFrame, pd.Series)):
+                return local_vars["result"].to_string()
+            elif isinstance(local_vars["result"], plt.Figure):
+                buf = io.BytesIO()
+                local_vars["result"].savefig(buf, format='png')
+                plt.close(local_vars["result"])
+                return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
+            else:
+                return str(local_vars["result"])
+        else:
+            return "Code executed successfully, but no variable called 'result' was assigned."
+
+    except Exception as e:
+        return f"Error executing code: {str(e)}"
+
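execute_code relies on a result-variable convention: the snippet it runs must assign its output to a variable named result, which is then rendered as text, a DataFrame dump, or a data URI for figures. A sketch of exercising the tool directly on a toy DataFrame, assuming smolagents tool wrappers remain directly callable; note that exec() runs arbitrary code, so this tool is only appropriate for trusted input:

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # data, pd, np, plt, and sns are injected into the snippet's namespace.
    print(execute_code("result = data['a'].sum()", df))  # -> "6"
    print(execute_code("x = data['a'].sum()", df))       # -> error/no-'result' message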
+
+@st.cache_data
+def load_data(uploaded_file):
+    """Loads data from an uploaded file with caching."""
+    try:
+        if uploaded_file.name.endswith(".csv"):
+            return pd.read_csv(uploaded_file)
+        elif uploaded_file.name.endswith((".xls", ".xlsx")):
+            return pd.read_excel(uploaded_file)
+        elif uploaded_file.name.endswith(".json"):
+            return pd.read_json(uploaded_file)
+        else:
+            raise ValueError(
+                "Unsupported file format. Please upload a CSV, Excel, or JSON file."
+            )
+    except Exception as e:
+        st.error(f"Error loading data: {e}")
+        return None
+
 
 def main():
+    st.title("Data Analysis Assistant")
+    st.write("Upload your dataset and get automated analysis with natural language interaction.")
+
     # Initialize session state
+    if "data" not in st.session_state:
+        st.session_state["data"] = None
+    if "agent" not in st.session_state:
+        st.session_state["agent"] = None
+    if "custom_code" not in st.session_state:
+        st.session_state["custom_code"] = ""
+
+    uploaded_file = st.file_uploader("Choose a CSV, Excel, or JSON file", type=["csv", "xlsx", "xls", "json"])
 
+    if uploaded_file:
+        with st.spinner("Loading and processing your data..."):
+            data = load_data(uploaded_file)
+            if data is not None:
+                st.session_state["data"] = data
+
+                st.session_state["agent"] = DataAnalysisAgent(
+                    dataset=data,
+                    tools=[
+                        analyze_basic_stats,
+                        generate_correlation_matrix,
+                        analyze_categorical_columns,
+                        suggest_features,
+                        describe_data,
+                        execute_code,
+                    ],
+                    model=GroqLLM(),
+                    additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"],
+                )
+                st.success(
+                    f"Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns"
+                )
+                st.subheader("Data Preview")
+                st.dataframe(data.head())
+
+    if st.session_state["data"] is not None:
+        analysis_type = st.selectbox(
+            "Choose analysis type",
+            [
+                "Basic Statistics",
+                "Correlation Analysis",
+                "Categorical Analysis",
+                "Feature Engineering",
+                "Data Description",
+                "Custom Code",
+                "Custom Question",
+            ],
+        )
+
+        if analysis_type == "Basic Statistics":
+            with st.spinner("Analyzing basic statistics..."):
+                result = st.session_state["agent"].run(
+                    "Use the analyze_basic_stats tool to analyze this dataset and "
+                    "provide insights about the numerical distributions."
+                )
+                st.write(result)
+
+        elif analysis_type == "Correlation Analysis":
+            with st.spinner("Generating correlation matrix..."):
+                result = st.session_state["agent"].run(
+                    "Use the generate_correlation_matrix tool to analyze correlations "
+                    "and explain any strong relationships found."
+                )
+                # The result may be a bare base64 payload or a full data URI
+                if isinstance(result, str) and (result.startswith("data:image") or "," in result):
+                    st.image(f"data:image/png;base64,{result.split(',')[-1]}")
+                else:
+                    st.write(result)
+
+        elif analysis_type == "Categorical Analysis":
+            with st.spinner("Analyzing categorical columns..."):
+                result = st.session_state["agent"].run(
+                    "Use the analyze_categorical_columns tool to examine the "
+                    "categorical variables and explain the distributions."
+                )
+                st.write(result)
+
+        elif analysis_type == "Feature Engineering":
+            with st.spinner("Generating feature suggestions..."):
+                result = st.session_state["agent"].run(
+                    "Use the suggest_features tool to recommend potential "
+                    "feature engineering steps for this dataset."
+                )
+                st.write(result)
+
+        elif analysis_type == "Data Description":
+            with st.spinner("Generating data description..."):
+                result = st.session_state["agent"].run(
+                    "Use the describe_data tool to generate a comprehensive description "
+                    "of the data."
+                )
+                st.write(result)
 
+        elif analysis_type == "Custom Code":
+            st.session_state["custom_code"] = st_ace(
+                placeholder="Enter your Python code here...",
+                language="python",
+                theme="github",
+                key="code_editor",
+                value=st.session_state["custom_code"],
+            )
+            if st.button("Run Code"):
+                with st.spinner("Executing custom code..."):
+                    result = st.session_state["agent"].run(
+                        "Execute the following code and return any 'result' variable:\n"
+                        f"```python\n{st.session_state['custom_code']}\n```"
+                    )
+                    if isinstance(result, str) and result.startswith("data:image"):
+                        st.image(result)
+                    else:
+                        st.write(result)
+
+        elif analysis_type == "Custom Question":
+            question = st.text_input("What would you like to know about your data?")
+            if question:
+                with st.spinner("Analyzing..."):
+                    result = st.session_state["agent"].run(question, stream=True)  # stream=True is forwarded via run()'s **kwargs
+                    st.write(result)
+
 
 if __name__ == "__main__":
     main()
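The wiring that main() performs can also be reproduced headlessly, which is convenient for testing the agent without Streamlit. A minimal sketch; the module import path and the CSV filename are hypothetical, and GROQ_API_KEY must be exported:

    import pandas as pd
    from app import (GroqLLM, DataAnalysisAgent, analyze_basic_stats,
                     generate_correlation_matrix, analyze_categorical_columns,
                     suggest_features, describe_data, execute_code)

    df = pd.read_csv("sales.csv")  # hypothetical input file
    agent = DataAnalysisAgent(
        dataset=df,
        tools=[analyze_basic_stats, generate_correlation_matrix,
               analyze_categorical_columns, suggest_features,
               describe_data, execute_code],
        model=GroqLLM(),
        additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"],
    )
    print(agent.run("Use the suggest_features tool and justify each suggestion."))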