mgbam committed on
Commit c3ab38e · 1 Parent(s): 914b42a

Add application file

Files changed (1): app.py +354 -0
app.py ADDED
@@ -0,0 +1,354 @@
+ import streamlit as st
+ import pandas as pd
+ from typing import Dict, List, Optional, Any
+ from pydantic import BaseModel, Field
+ import base64
+ import io
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from abc import ABC, abstractmethod  # For abstract base classes
+ from sklearn.model_selection import train_test_split  # Machine learning modules
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.metrics import accuracy_score
+ from statsmodels.tsa.seasonal import seasonal_decompose
+ from statsmodels.tsa.stattools import adfuller
+ from langchain.prompts import PromptTemplate
+ from groq import Groq
+ import os
+ import numpy as np
+ from scipy.stats import ttest_ind, f_oneway
+
+ # Initialize Groq Client
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
+
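The client above is constructed at import time, so a missing GROQ_API_KEY only surfaces once the first completion request fails. A minimal fail-fast sketch, not part of the committed file (the error message is illustrative):

api_key = os.environ.get("GROQ_API_KEY")
if not api_key:
    # Abort at startup instead of erroring deep inside the UI on the first request
    raise RuntimeError("GROQ_API_KEY is not set")
client = Groq(api_key=api_key)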
+ # ---------------------- Base Classes and Schemas ---------------------------
+ class ResearchInput(BaseModel):
+     """Base schema for research tool inputs"""
+     data_key: str = Field(..., description="Session state key containing DataFrame")
+     columns: Optional[List[str]] = Field(None, description="List of columns to analyze")
+
+ class TemporalAnalysisInput(ResearchInput):
+     """Schema for temporal analysis"""
+     time_col: str = Field(..., description="Name of timestamp column")
+     value_col: str = Field(..., description="Name of value column to analyze")
+
+ class HypothesisInput(ResearchInput):
+     """Schema for hypothesis testing"""
+     group_col: str = Field(..., description="Categorical column defining groups")
+     value_col: str = Field(..., description="Numerical column to compare")
+
+ class ModelTrainingInput(ResearchInput):
+     """Schema for model training"""
+     target_col: str = Field(..., description="Name of target column")
+
+ class DataAnalyzer(ABC):
+     """Abstract base class for data analysis modules"""
+     @abstractmethod
+     def invoke(self, **kwargs) -> Dict[str, Any]:
+         pass
+
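Every analyzer that follows implements the same contract: invoke() looks the DataFrame up in st.session_state under data_key and returns a plain dict (or a base64-encoded image for pure visualizations). As a hypothetical illustration of how a new module would slot in (CorrelationAnalyzer is not in the commit):

class CorrelationAnalyzer(DataAnalyzer):
    """Pairwise Pearson correlations between numeric columns"""
    def invoke(self, data_key: str, **kwargs) -> Dict[str, Any]:
        try:
            data = st.session_state[data_key]
            return {"correlations": data.corr(numeric_only=True).to_dict()}
        except Exception as e:
            return {"error": f"Correlation Analysis Failed: {str(e)}"}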
+ # ---------------------- Concrete Analyzer Implementations ---------------------------
+ class AdvancedEDA(DataAnalyzer):
+     """Comprehensive Exploratory Data Analysis"""
+     def invoke(self, data_key: str, **kwargs) -> Dict[str, Any]:
+         try:
+             data = st.session_state[data_key]
+             analysis = {
+                 "dimensionality": {
+                     "rows": len(data),
+                     "columns": list(data.columns),
+                     "memory_usage": f"{data.memory_usage().sum() / 1e6:.2f} MB"
+                 },
+                 "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
+                 "temporal_analysis": {
+                     "date_ranges": {
+                         col: {
+                             "min": data[col].min(),
+                             "max": data[col].max()
+                         } for col in data.select_dtypes(include='datetime').columns
+                     }
+                 },
+                 "data_quality": {
+                     "missing_values": data.isnull().sum().to_dict(),
+                     "duplicates": data.duplicated().sum(),
+                     "cardinality": {
+                         col: data[col].nunique() for col in data.columns
+                     }
+                 }
+             }
+             return analysis
+         except Exception as e:
+             return {"error": f"EDA Failed: {str(e)}"}
+
+ class DistributionVisualizer(DataAnalyzer):
+     """Distribution visualizations"""
+     def invoke(self, data_key: str, columns: List[str], **kwargs) -> str:
+         try:
+             data = st.session_state[data_key]
+             plt.figure(figsize=(12, 6))
+             for i, col in enumerate(columns, 1):
+                 plt.subplot(1, len(columns), i)
+                 sns.histplot(data[col], kde=True, stat="density")
+                 plt.title(f'Distribution of {col}', fontsize=10)
+                 plt.xticks(fontsize=8)
+                 plt.yticks(fontsize=8)
+             plt.tight_layout()
+
+             # Serialize the figure to a base64-encoded PNG for st.image
+             buf = io.BytesIO()
+             plt.savefig(buf, format='png', dpi=300, bbox_inches='tight')
+             plt.close()
+             return base64.b64encode(buf.getvalue()).decode()
+         except Exception as e:
+             return f"Visualization Error: {str(e)}"
+
+ class TemporalAnalyzer(DataAnalyzer):
+     """Time series analysis"""
+     def invoke(self, data_key: str, time_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
+         try:
+             data = st.session_state[data_key]
+             ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col]
+             # period=365 assumes daily observations with yearly seasonality
+             decomposition = seasonal_decompose(ts_data, period=365)
+
+             # decomposition.plot() creates its own figure; resize that figure
+             # instead of opening a separate empty one that would leak
+             fig = decomposition.plot()
+             fig.set_size_inches(12, 8)
+             plt.tight_layout()
+
+             buf = io.BytesIO()
+             plt.savefig(buf, format='png')
+             plt.close()
+             plot_data = base64.b64encode(buf.getvalue()).decode()
+
+             return {
+                 "trend_statistics": {
+                     # p-value of the augmented Dickey-Fuller test; < 0.05 suggests stationarity
+                     "stationarity": adfuller(ts_data)[1],
+                     "seasonality_strength": decomposition.seasonal.max()
+                 },
+                 "visualization": plot_data
+             }
+         except Exception as e:
+             return {"error": f"Temporal Analysis Failed: {str(e)}"}
+
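For context on the "stationarity" field: adfuller returns a tuple whose second element is the p-value of the augmented Dickey-Fuller test, where small values argue for stationarity. A quick sketch on a toy series (not in the commit):

rng = np.random.default_rng(0)
random_walk = pd.Series(rng.normal(size=500).cumsum())  # non-stationary by construction
adf_stat, p_value, used_lag, n_obs, critical_values, _ = adfuller(random_walk)
print(f"ADF={adf_stat:.3f}, p={p_value:.3f}, 5% critical value={critical_values['5%']:.3f}")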
+ class HypothesisTester(DataAnalyzer):
+     """Statistical hypothesis testing"""
+     def invoke(self, data_key: str, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
+         try:
+             data = st.session_state[data_key]
+             groups = data[group_col].unique()
+
+             if len(groups) < 2:
+                 return {"error": "Insufficient groups for comparison"}
+
+             group_data = [data[data[group_col] == g][value_col] for g in groups]
+             if len(groups) == 2:
+                 stat, p = ttest_ind(*group_data)
+                 test_type = "Independent t-test"
+             else:
+                 stat, p = f_oneway(*group_data)
+                 test_type = "ANOVA"
+
+             return {
+                 "test_type": test_type,
+                 "test_statistic": stat,
+                 "p_value": p,
+                 "effect_size": {
+                     # Cohen's d with a pooled-variance denominator (two groups only)
+                     "cohens_d": abs(group_data[0].mean() - group_data[1].mean()) / np.sqrt(
+                         (group_data[0].var() + group_data[1].var()) / 2
+                     ) if len(groups) == 2 else None
+                 },
+                 "interpretation": self.interpret_p_value(p)
+             }
+         except Exception as e:
+             return {"error": f"Hypothesis Testing Failed: {str(e)}"}
+
+     def interpret_p_value(self, p: float) -> str:
+         if p < 0.001:
+             return "Very strong evidence against H0"
+         elif p < 0.01:
+             return "Strong evidence against H0"
+         elif p < 0.05:
+             return "Evidence against H0"
+         elif p < 0.1:
+             return "Weak evidence against H0"
+         else:
+             return "No significant evidence against H0"
+
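The "cohens_d" field uses a pooled-variance denominator; a toy check of the same arithmetic (illustrative values, and note that conventions differ on the exact denominator):

a = pd.Series([1.0, 2.0, 3.0, 4.0])
b = pd.Series([2.0, 3.0, 4.0, 5.0])
d = abs(a.mean() - b.mean()) / np.sqrt((a.var() + b.var()) / 2)
print(round(d, 3))  # 0.775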
+ class LogisticRegressionTrainer(DataAnalyzer):
+     """Logistic Regression Model Trainer"""
+     def invoke(self, data_key: str, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
+         try:
+             data = st.session_state[data_key]
+             X = data[columns]
+             y = data[target_col]
+             X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+             model = LogisticRegression(max_iter=1000)
+             model.fit(X_train, y_train)
+             y_pred = model.predict(X_test)
+             accuracy = accuracy_score(y_test, y_pred)
+             return {
+                 "model_type": "Logistic Regression",
+                 "accuracy": accuracy,
+                 "model_params": model.get_params()
+             }
+         except Exception as e:
+             return {"error": f"Logistic Regression Model Error: {str(e)}"}
+
+ # ---------------------- Groq Research Agent ---------------------------
+
+ class GroqResearcher:
+     """Advanced AI Research Engine using Groq"""
+     def __init__(self, model_name="mixtral-8x7b-32768"):
+         self.model_name = model_name
+         self.system_template = """You are a senior data scientist at a research institution.
+ Analyze this dataset with rigorous statistical methods and provide academic-quality insights:
+ {dataset_info}
+
+ User Question: {query}
+
+ Required Format:
+ - Executive Summary (1 paragraph)
+ - Methodology (bullet points)
+ - Key Findings (numbered list)
+ - Limitations
+ - Recommended Next Steps"""
+
+     def research(self, query: str, data: pd.DataFrame) -> str:
+         """Conduct academic-level analysis using Groq"""
+         try:
+             dataset_info = f"""
+ Dataset Dimensions: {data.shape}
+ Variables: {', '.join(data.columns)}
+ Temporal Coverage: {data.select_dtypes(include='datetime').columns.tolist()}
+ Missing Values: {data.isnull().sum().to_dict()}
+ """
+
+             prompt = PromptTemplate.from_template(self.system_template).format(
+                 dataset_info=dataset_info,
+                 query=query
+             )
+
+             completion = client.chat.completions.create(
+                 messages=[
+                     {"role": "system", "content": "You are a research AI assistant"},
+                     {"role": "user", "content": prompt}
+                 ],
+                 model=self.model_name,
+                 temperature=0.2,
+                 max_tokens=4096,
+                 stream=False
+             )
+
+             return completion.choices[0].message.content
+
+         except Exception as e:
+             return f"Research Error: {str(e)}"
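GroqResearcher has no Streamlit dependency, so it can be smoke-tested from a plain script. A hypothetical example, not in the commit (assumes GROQ_API_KEY is set; the toy DataFrame is illustrative):

df = pd.DataFrame({"age": [34, 51, 29], "outcome": [0, 1, 0]})
researcher = GroqResearcher()
print(researcher.research("Which variables look predictive of outcome?", df))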
+ # ---------------------- Main Streamlit Application ---------------------------
+ def main():
+     st.set_page_config(page_title="AI Data Analysis Lab", layout="wide")
+     st.title("🧪 Advanced AI Data Analysis Laboratory")
+
+     # Session State
+     if 'data' not in st.session_state:
+         st.session_state.data = None
+     if 'researcher' not in st.session_state:
+         st.session_state.researcher = GroqResearcher()
+
+     # Data Upload
+     with st.sidebar:
+         st.header("🔬 Data Management")
+         uploaded_file = st.file_uploader("Upload research dataset", type=["csv", "parquet"])
+         if uploaded_file:
+             with st.spinner("Initializing dataset..."):
+                 try:
+                     # Dispatch on file extension so parquet uploads are not parsed as CSV
+                     if uploaded_file.name.endswith(".parquet"):
+                         st.session_state.data = pd.read_parquet(uploaded_file)
+                     else:
+                         st.session_state.data = pd.read_csv(uploaded_file)
+                     st.success(f"Loaded {len(st.session_state.data):,} research observations")
+                 except Exception as e:
+                     st.error(f"Error loading dataset: {e}")
+
+     if st.session_state.data is not None:
+         col1, col2 = st.columns([1, 3])
+         with col1:
+             st.subheader("Dataset Metadata")
+             st.json({
+                 "Variables": list(st.session_state.data.columns),
+                 "Time Range": {
+                     col: {
+                         # str() keeps timestamps JSON-serializable for st.json
+                         "min": str(st.session_state.data[col].min()),
+                         "max": str(st.session_state.data[col].max())
+                     } for col in st.session_state.data.select_dtypes(include='datetime').columns
+                 },
+                 "Size": f"{st.session_state.data.memory_usage().sum() / 1e6:.2f} MB"
+             })
+
+         with col2:
+             analysis_tab, research_tab = st.tabs(["Automated Analysis", "Custom Research"])
+             with analysis_tab:
+                 analysis_type = st.selectbox("Select Analysis Mode", [
+                     "Exploratory Data Analysis",
+                     "Temporal Pattern Analysis",
+                     "Comparative Statistics",
+                     "Distribution Analysis",
+                     "Train Logistic Regression Model"
+                 ])
+
+                 if analysis_type == "Exploratory Data Analysis":
+                     analyzer = AdvancedEDA()
+                     eda_result = analyzer.invoke(data_key="data")
+                     st.subheader("Data Quality Report")
+                     st.json(eda_result)
+
+                 elif analysis_type == "Temporal Pattern Analysis":
+                     time_col = st.selectbox("Temporal Variable",
+                         st.session_state.data.select_dtypes(include='datetime').columns)
+                     value_col = st.selectbox("Analysis Variable",
+                         st.session_state.data.select_dtypes(include=np.number).columns)
+
+                     if time_col and value_col:
+                         analyzer = TemporalAnalyzer()
+                         result = analyzer.invoke(data_key="data", time_col=time_col, value_col=value_col)
+                         if "visualization" in result:
+                             st.image(f"data:image/png;base64,{result['visualization']}")
+                         st.json(result)
+
+                 elif analysis_type == "Comparative Statistics":
+                     # CSV text columns load as object dtype, so offer those alongside true categoricals
+                     group_col = st.selectbox("Grouping Variable",
+                         st.session_state.data.select_dtypes(include=['category', 'object']).columns)
+                     value_col = st.selectbox("Metric Variable",
+                         st.session_state.data.select_dtypes(include=np.number).columns)
+
+                     if group_col and value_col:
+                         analyzer = HypothesisTester()
+                         result = analyzer.invoke(data_key="data", group_col=group_col, value_col=value_col)
+                         st.subheader("Statistical Test Results")
+                         st.json(result)
+
+                 elif analysis_type == "Distribution Analysis":
+                     num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
+                     selected_cols = st.multiselect("Select Variables", num_cols)
+                     if selected_cols:
+                         analyzer = DistributionVisualizer()
+                         img_data = analyzer.invoke(data_key="data", columns=selected_cols)
+                         st.image(f"data:image/png;base64,{img_data}")
+
+                 elif analysis_type == "Train Logistic Regression Model":
+                     num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
+                     target_col = st.selectbox("Select Target Variable",
+                         st.session_state.data.columns.tolist())
+                     selected_cols = st.multiselect("Select Feature Variables", num_cols)
+                     if selected_cols and target_col:
+                         analyzer = LogisticRegressionTrainer()
+                         result = analyzer.invoke(data_key="data", target_col=target_col, columns=selected_cols)
+                         st.subheader("Logistic Regression Model Results")
+                         st.json(result)
+
+             with research_tab:
+                 research_query = st.text_area("Enter Research Question:", height=150,
+                     placeholder="E.g., 'What factors are most predictive of X outcome?'")
+
+                 if st.button("Execute Research"):
+                     with st.spinner("Conducting rigorous analysis..."):
+                         result = st.session_state.researcher.research(
+                             research_query, st.session_state.data
+                         )
+                     st.markdown("## Research Findings")
+                     st.markdown(result)
+
+ if __name__ == "__main__":
+     main()
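To try the commit locally, install the imported packages (streamlit, pandas, numpy, scipy, scikit-learn, statsmodels, matplotlib, seaborn, langchain, groq, pydantic), export GROQ_API_KEY, and launch with streamlit run app.py. One caveat: the default model name mixtral-8x7b-32768 should be checked against the models Groq currently serves, and the Pydantic input schemas (ResearchInput and friends) are declared but not yet wired into invoke() validation.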