mgbam committed
Commit 33c4308 · verified · 1 Parent(s): a50cd69

Update app.py

Files changed (1)
  1. app.py +307 -48
app.py CHANGED
@@ -1,51 +1,310 @@
-from smolagents._function_type_hints_utils import FunctionSchema
-
-@tool(schema=FunctionSchema(
-    name="execute_code",
-    description="Executes Python code and returns the result as a string.",
-    parameters=[
-        {
-            "name": "code_string",
-            "type": "string",
-            "description": "Python code to execute."
-        },
-        {
-            "name": "data",
-            "type": "object",
-            "description": "The pandas DataFrame to use.",
-            "properties": {},
         }
-    ],
-    returns={
-        "type": "string",
-        "description": "The result of executing the code."
-    }
-))
-def execute_code(code_string: str, data: pd.DataFrame) -> str:
-    """Executes python code and returns results as a string.
-
-    Args:
-        code_string (str): Python code to execute.
-        data (pd.DataFrame): The dataframe to use in the code
-    Returns:
-        str: The result of executing the code or an error message
-    """
-    try:
-        local_vars = {"data": data, "pd": pd, "np": np, "plt": plt, "sns": sns}
-        exec(code_string, local_vars)
-
-        if "result" in local_vars:
-            if isinstance(local_vars["result"], (pd.DataFrame, pd.Series)):
-                return local_vars["result"].to_string()
-            elif isinstance(local_vars["result"], plt.Figure):
-                buf = io.BytesIO()
-                local_vars["result"].savefig(buf, format="png")
-                plt.close(local_vars["result"])
-                return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
-            else:
-                return str(local_vars["result"])
         else:
-            return "Code executed successfully, but no variable called 'result' was assigned."
 
-    except Exception as e:
-        return f"Error executing code: {str(e)}"
+import streamlit as st
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import os
+import base64
+import io
+from groq import Groq
+from pydantic import BaseModel, Field
+from typing import Dict, List, Optional
+from langchain.tools import tool
+from langchain.agents import initialize_agent, AgentType
+from langchain.prompts import PromptTemplate  # needed by GroqResearcher.research below
+from scipy.stats import ttest_ind, f_oneway
+from statsmodels.tsa.seasonal import seasonal_decompose  # needed by temporal_analysis
+from statsmodels.tsa.stattools import adfuller  # needed by temporal_analysis
+
+# Initialize Groq Client
+client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
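+# The Pydantic models below are the args_schema contracts that LangChain
+# validates before each @tool call; tools receive a session-state key
+# (data_key) and look the DataFrame up in st.session_state rather than
+# receiving the DataFrame itself.
+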
+class ResearchInput(BaseModel):
+    """Base schema for research tool inputs"""
+    data_key: str = Field(..., description="Session state key containing DataFrame")
+    columns: Optional[List[str]] = Field(None, description="List of columns to analyze")
+
+class TemporalAnalysisInput(ResearchInput):
+    """Schema for temporal analysis"""
+    time_col: str = Field(..., description="Name of timestamp column")
+    value_col: str = Field(..., description="Name of value column to analyze")
+
+class HypothesisInput(ResearchInput):
+    """Schema for hypothesis testing"""
+    group_col: str = Field(..., description="Categorical column defining groups")
+    value_col: str = Field(..., description="Numerical column to compare")
+
+class GroqResearcher:
+    """Advanced AI Research Engine using Groq"""
+    def __init__(self, model_name="mixtral-8x7b-32768"):
+        self.model_name = model_name
+        self.system_template = """You are a senior data scientist at a research institution.
+        Analyze this dataset with rigorous statistical methods and provide academic-quality insights:
+        {dataset_info}
+
+        User Question: {query}
+
+        Required Format:
+        - Executive Summary (1 paragraph)
+        - Methodology (bullet points)
+        - Key Findings (numbered list)
+        - Limitations
+        - Recommended Next Steps"""
+
+    def research(self, query: str, data: pd.DataFrame) -> str:
+        """Conduct academic-level analysis using Groq"""
+        try:
+            dataset_info = f"""
+            Dataset Dimensions: {data.shape}
+            Variables: {', '.join(data.columns)}
+            Temporal Coverage: {data.select_dtypes(include='datetime').columns.tolist()}
+            Missing Values: {data.isnull().sum().to_dict()}
+            """
+
+            prompt = PromptTemplate.from_template(self.system_template).format(
+                dataset_info=dataset_info,
+                query=query
+            )
+
+            completion = client.chat.completions.create(
+                messages=[
+                    {"role": "system", "content": "You are a research AI assistant"},
+                    {"role": "user", "content": prompt}
+                ],
+                model=self.model_name,
+                temperature=0.2,
+                max_tokens=4096,
+                stream=False
+            )
+
+            return completion.choices[0].message.content
+
+        except Exception as e:
+            return f"Research Error: {str(e)}"
+
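+# Example usage (illustrative question only): GroqResearcher().research("What
+# drives outcome X?", df) returns Markdown following the template above, or a
+# "Research Error: ..." string on failure.
+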
+@tool(args_schema=ResearchInput)
+def advanced_eda(data_key: str) -> Dict:
+    """Comprehensive Exploratory Data Analysis with Statistical Profiling"""
+    try:
+        data = st.session_state[data_key]
+        analysis = {
+            "dimensionality": {
+                "rows": len(data),
+                "columns": list(data.columns),
+                "memory_usage": f"{data.memory_usage().sum() / 1e6:.2f} MB"
+            },
+            "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
+            "temporal_analysis": {
+                "date_ranges": {
+                    col: {
+                        "min": data[col].min(),
+                        "max": data[col].max()
+                    } for col in data.select_dtypes(include='datetime').columns
+                }
+            },
+            "data_quality": {
+                "missing_values": data.isnull().sum().to_dict(),
+                "duplicates": data.duplicated().sum(),
+                "cardinality": {
+                    col: data[col].nunique() for col in data.columns
+                }
+            }
         }
+        return analysis
+    except Exception as e:
+        return {"error": f"EDA Failed: {str(e)}"}
+
+@tool(args_schema=ResearchInput)
+def visualize_distributions(data_key: str, columns: List[str]) -> str:
+    """Generate publication-quality distribution visualizations"""
+    try:
+        data = st.session_state[data_key]
+        plt.figure(figsize=(12, 6))
+        for i, col in enumerate(columns, 1):
+            plt.subplot(1, len(columns), i)
+            sns.histplot(data[col], kde=True, stat="density")
+            plt.title(f'Distribution of {col}', fontsize=10)
+            plt.xticks(fontsize=8)
+            plt.yticks(fontsize=8)
+        plt.tight_layout()
+
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png', dpi=300, bbox_inches='tight')
+        plt.close()
+        return base64.b64encode(buf.getvalue()).decode()
+    except Exception as e:
+        return f"Visualization Error: {str(e)}"
+
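+# Note: visualize_distributions returns a bare base64 string; main() prefixes
+# it with "data:image/png;base64," before passing it to st.image.
+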
+@tool(args_schema=TemporalAnalysisInput)
+def temporal_analysis(data_key: str, time_col: str, value_col: str) -> Dict:
+    """Time Series Decomposition and Trend Analysis"""
+    try:
+        data = st.session_state[data_key]
+        ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col]
+
+        decomposition = seasonal_decompose(ts_data, period=365)
+
+        # decomposition.plot() creates its own figure; resize that figure
+        # instead of opening a second, unused one beforehand
+        fig = decomposition.plot()
+        fig.set_size_inches(12, 8)
+        fig.tight_layout()
+
+        buf = io.BytesIO()
+        fig.savefig(buf, format='png')
+        plt.close(fig)
+        plot_data = base64.b64encode(buf.getvalue()).decode()
+
+        return {
+            "trend_statistics": {
+                "stationarity": adfuller(ts_data)[1],
+                "seasonality_strength": max(decomposition.seasonal)
+            },
+            "visualization": plot_data
+        }
+    except Exception as e:
+        return {"error": f"Temporal Analysis Failed: {str(e)}"}
+
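+# Assumption: period=365 treats the series as daily data with yearly
+# seasonality; adjust the period for other sampling frequencies. The
+# "stationarity" field is the ADF test p-value (p < 0.05 suggests stationarity).
+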
+@tool(args_schema=HypothesisInput)
+def hypothesis_testing(data_key: str, group_col: str, value_col: str) -> Dict:
+    """Statistical Hypothesis Testing with Automated Assumption Checking"""
+    try:
+        data = st.session_state[data_key]
+        groups = data[group_col].unique()
+
+        if len(groups) < 2:
+            return {"error": "Insufficient groups for comparison"}
+
+        if len(groups) == 2:
+            group_data = [data[data[group_col] == g][value_col] for g in groups]
+            stat, p = ttest_ind(*group_data)
+            test_type = "Independent t-test"
         else:
+            group_data = [data[data[group_col] == g][value_col] for g in groups]
+            stat, p = f_oneway(*group_data)
+            test_type = "ANOVA"
+
+        return {
+            "test_type": test_type,
+            "test_statistic": stat,
+            "p_value": p,
+            "effect_size": {
+                "cohens_d": abs(group_data[0].mean() - group_data[1].mean()) / np.sqrt(
+                    (group_data[0].var() + group_data[1].var()) / 2
+                ) if len(groups) == 2 else None
+            },
+            "interpretation": interpret_p_value(p)
+        }
+    except Exception as e:
+        return {"error": f"Hypothesis Testing Failed: {str(e)}"}
+
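+# Note: scipy's ttest_ind above defaults to Student's t-test (equal_var=True);
+# pass equal_var=False for Welch's t-test if group variances differ.
+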
+def interpret_p_value(p: float) -> str:
+    """Scientific interpretation of p-values"""
+    if p < 0.001: return "Very strong evidence against H0"
+    elif p < 0.01: return "Strong evidence against H0"
+    elif p < 0.05: return "Evidence against H0"
+    elif p < 0.1: return "Weak evidence against H0"
+    else: return "No significant evidence against H0"
+
+def main():
+    st.set_page_config(page_title="AI Research Lab", layout="wide")
+    st.title("🧪 Advanced AI Research Laboratory")
+
+    # Session state initialization
+    if 'data' not in st.session_state:
+        st.session_state.data = None
+    if 'researcher' not in st.session_state:
+        st.session_state.researcher = GroqResearcher()
+
+    # Data upload and management
+    with st.sidebar:
+        st.header("🔬 Data Management")
+        uploaded_file = st.file_uploader("Upload research dataset", type=["csv", "parquet"])
+        if uploaded_file:
+            with st.spinner("Initializing dataset..."):
+                # Honor both accepted upload types instead of assuming CSV
+                if uploaded_file.name.endswith(".parquet"):
+                    st.session_state.data = pd.read_parquet(uploaded_file)
+                else:
+                    st.session_state.data = pd.read_csv(uploaded_file)
+                st.success(f"Loaded {len(st.session_state.data):,} research observations")
+
+    # Main research interface
+    if st.session_state.data is not None:
+        col1, col2 = st.columns([1, 3])
+
+        with col1:
+            st.subheader("Dataset Metadata")
+            st.json({
+                "Variables": list(st.session_state.data.columns),
+                "Time Range": {
+                    col: {
+                        "min": st.session_state.data[col].min(),
+                        "max": st.session_state.data[col].max()
+                    } for col in st.session_state.data.select_dtypes(include='datetime').columns
+                },
+                "Size": f"{st.session_state.data.memory_usage().sum() / 1e6:.2f} MB"
+            })
+
+        with col2:
+            analysis_tab, research_tab = st.tabs(["Automated Analysis", "Custom Research"])
+
+            with analysis_tab:
+                analysis_type = st.selectbox("Select Analysis Mode", [
+                    "Exploratory Data Analysis",
+                    "Temporal Pattern Analysis",
+                    "Comparative Statistics",
+                    "Distribution Analysis"
+                ])
+
+                if analysis_type == "Exploratory Data Analysis":
+                    eda_result = advanced_eda.invoke({"data_key": "data"})
+                    st.subheader("Data Quality Report")
+                    st.json(eda_result)
+
+                elif analysis_type == "Temporal Pattern Analysis":
+                    time_col = st.selectbox("Temporal Variable",
+                        st.session_state.data.select_dtypes(include='datetime').columns)
+                    value_col = st.selectbox("Analysis Variable",
+                        st.session_state.data.select_dtypes(include=np.number).columns)
+
+                    if time_col and value_col:
+                        result = temporal_analysis.invoke({
+                            "data_key": "data",
+                            "time_col": time_col,
+                            "value_col": value_col
+                        })
+                        if "visualization" in result:
+                            st.image(f"data:image/png;base64,{result['visualization']}")
+                        st.json(result)
+
+                elif analysis_type == "Comparative Statistics":
+                    # CSV uploads rarely carry a categorical dtype, so offer
+                    # object columns as grouping variables as well
+                    group_col = st.selectbox("Grouping Variable",
+                        st.session_state.data.select_dtypes(include=['category', 'object']).columns)
+                    value_col = st.selectbox("Metric Variable",
+                        st.session_state.data.select_dtypes(include=np.number).columns)
+
+                    if group_col and value_col:
+                        result = hypothesis_testing.invoke({
+                            "data_key": "data",
+                            "group_col": group_col,
+                            "value_col": value_col
+                        })
+                        st.subheader("Statistical Test Results")
+                        st.json(result)
+
+                elif analysis_type == "Distribution Analysis":
+                    num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
+                    selected_cols = st.multiselect("Select Variables", num_cols)
+                    if selected_cols:
+                        img_data = visualize_distributions.invoke({
+                            "data_key": "data",
+                            "columns": selected_cols
+                        })
+                        st.image(f"data:image/png;base64,{img_data}")
+
+            with research_tab:
+                research_query = st.text_area("Enter Research Question:", height=150,
+                    placeholder="E.g., 'What factors are most predictive of X outcome?'")
+
+                if st.button("Execute Research"):
+                    with st.spinner("Conducting rigorous analysis..."):
+                        result = st.session_state.researcher.research(
+                            research_query, st.session_state.data
+                        )
+                    st.markdown("## Research Findings")
+                    st.markdown(result)
 
+if __name__ == "__main__":
+    main()
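+# Usage sketch (assumes a valid Groq key): set GROQ_API_KEY in the
+# environment, then run `streamlit run app.py`.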