Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
from typing import Dict, List, Optional, Any | |
from pydantic import BaseModel, Field | |
import base64 | |
import io | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from abc import ABC, abstractmethod # For abstract base classes | |
from sklearn.model_selection import train_test_split # Machine learning modules | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.metrics import accuracy_score | |
from statsmodels.tsa.seasonal import seasonal_decompose | |
from statsmodels.tsa.stattools import adfuller | |
from langchain.prompts import PromptTemplate | |
from groq import Groq | |
import os | |
import numpy as np | |
from scipy.stats import ttest_ind, f_oneway | |
# Initialize Groq Client
# One module-level client; GroqResearcher.research() sends its requests through it.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# ---------------------- Base Classes and Schemas --------------------------- | |
class ResearchInput(BaseModel):
    """Base schema for research tool inputs"""
    # Required: the st.session_state key under which the DataFrame lives.
    data_key: str = Field(default=..., description="Session state key containing DataFrame")
    # Optional column subset; defaults to None when the caller gives none.
    columns: Optional[List[str]] = Field(default=None, description="List of columns to analyze")
class TemporalAnalysisInput(ResearchInput):
    """Schema for temporal analysis"""
    # Both fields are required (default is the Ellipsis sentinel).
    time_col: str = Field(default=..., description="Name of timestamp column")
    value_col: str = Field(default=..., description="Name of value column to analyze")
class HypothesisInput(ResearchInput):
    """Schema for hypothesis testing"""
    # Required: the column whose distinct values split rows into groups.
    group_col: str = Field(default=..., description="Categorical column defining groups")
    # Required: the measurement compared between those groups.
    value_col: str = Field(default=..., description="Numerical column to compare")
class ModelTrainingInput(ResearchInput):
    """Schema for model training"""
    # Required: label column the model is trained to predict.
    target_col: str = Field(default=..., description="Name of target column")
class DataAnalyzer(ABC):
    """Abstract base class for data analysis modules.

    Concrete analyzers must override :meth:`invoke`. Marking it abstract means
    ``DataAnalyzer()`` itself (or a subclass that forgets ``invoke``) raises
    ``TypeError`` at construction instead of silently returning ``None``.
    """

    @abstractmethod
    def invoke(self, **kwargs) -> Dict[str, Any]:
        """Run the analysis with keyword arguments and return its result.

        The analyzers in this module return a dict (or, for the
        distribution visualizer, a base64-encoded PNG string).
        """
        raise NotImplementedError
# ---------------------- Concrete Analyzer Implementations --------------------------- | |
class AdvancedEDA(DataAnalyzer):
    """Comprehensive Exploratory Data Analysis"""

    def invoke(self, data_key: str, **kwargs) -> Dict[str, Any]:
        """Profile the DataFrame stored at ``st.session_state[data_key]``.

        Args:
            data_key: Session-state key holding the DataFrame.
            **kwargs: Ignored; accepted so every analyzer shares one call shape.

        Returns:
            A dict with ``dimensionality``, ``statistical_profile``,
            ``temporal_analysis`` and ``data_quality`` sections, or
            ``{"error": "EDA Failed: ..."}`` if anything raises.
            Counts are plain ``int`` and timestamps are ``str`` so the
            result renders cleanly with ``st.json``.
        """
        try:
            data = st.session_state[data_key]
            datetime_cols = data.select_dtypes(include='datetime').columns
            analysis = {
                "dimensionality": {
                    "rows": len(data),
                    "columns": list(data.columns),
                    "memory_usage": f"{data.memory_usage().sum() / 1e6:.2f} MB"
                },
                "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
                "temporal_analysis": {
                    "date_ranges": {
                        # Timestamps are not JSON types; str() keeps them readable.
                        col: {
                            "min": str(data[col].min()),
                            "max": str(data[col].max())
                        } for col in datetime_cols
                    }
                },
                "data_quality": {
                    # pandas returns NumPy int64 here, which json cannot encode;
                    # cast to builtin int.
                    "missing_values": {
                        col: int(count) for col, count in data.isnull().sum().items()
                    },
                    "duplicates": int(data.duplicated().sum()),
                    "cardinality": {
                        col: int(data[col].nunique()) for col in data.columns
                    }
                }
            }
            return analysis
        except Exception as e:
            return {"error": f"EDA Failed: {str(e)}"}
class DistributionVisualizer(DataAnalyzer):
    """Distribution visualizations"""

    def invoke(self, data_key: str, columns: List[str], **kwargs) -> str:
        """Draw a density histogram with KDE for each column, side by side.

        Args:
            data_key: Session-state key holding the DataFrame.
            columns: Columns to plot, one subplot each, in order.
            **kwargs: Ignored.

        Returns:
            The PNG (300 dpi) encoded as a base64 string, or a string starting
            with ``"Visualization Error: "`` on failure (including an empty
            ``columns`` list).
        """
        fig = None
        try:
            data = st.session_state[data_key]
            if not columns:
                raise ValueError("no columns selected")
            # Explicit figure/axes instead of pyplot's implicit "current" state,
            # so exactly this figure can be closed below.
            fig, axes = plt.subplots(1, len(columns), figsize=(12, 6), squeeze=False)
            for ax, col in zip(axes[0], columns):
                sns.histplot(data[col], kde=True, stat="density", ax=ax)
                ax.set_title(f'Distribution of {col}', fontsize=10)
                ax.tick_params(axis='both', labelsize=8)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
            return base64.b64encode(buf.getvalue()).decode()
        except Exception as e:
            return f"Visualization Error: {str(e)}"
        finally:
            # Close even when plotting fails, otherwise pyplot keeps the
            # figure alive.
            if fig is not None:
                plt.close(fig)
class TemporalAnalyzer(DataAnalyzer):
    """Time series analysis"""

    def invoke(self, data_key: str, time_col: str, value_col: str,
               period: int = 365, **kwargs) -> Dict[str, Any]:
        """Seasonally decompose ``value_col`` over ``time_col`` and run an ADF test.

        Args:
            data_key: Session-state key holding the DataFrame.
            time_col: Column parsed with ``pd.to_datetime`` and used as the index.
            value_col: Column whose values are analysed.
            period: Seasonal period (in observations) for ``seasonal_decompose``.
                Defaults to 365, the previously hard-coded value.
            **kwargs: Ignored.

        Returns:
            ``{"trend_statistics": {"stationarity": <ADF p-value>,
            "seasonality_strength": <0..1>}, "visualization": <base64 PNG>}``,
            or ``{"error": "Temporal Analysis Failed: ..."}``.
        """
        fig = None
        try:
            data = st.session_state[data_key]
            ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col]
            # Decomposition and the ADF test assume a chronologically ordered
            # series without gaps in the values: drop NaT timestamps and NaN
            # values, then sort by time.
            ts_data = ts_data[ts_data.index.notna()].dropna().sort_index()
            decomposition = seasonal_decompose(ts_data, period=period)
            # decomposition.plot() creates its own Figure. Size and close that
            # figure rather than opening a separate, never-closed empty one.
            fig = decomposition.plot()
            fig.set_size_inches(12, 8)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            plot_data = base64.b64encode(buf.getvalue()).decode()
            return {
                "trend_statistics": {
                    # p-value of the augmented Dickey-Fuller test
                    # (small values suggest stationarity).
                    "stationarity": adfuller(ts_data)[1],
                    "seasonality_strength": self._seasonal_strength(decomposition)
                },
                "visualization": plot_data
            }
        except Exception as e:
            return {"error": f"Temporal Analysis Failed: {str(e)}"}
        finally:
            if fig is not None:
                plt.close(fig)

    @staticmethod
    def _seasonal_strength(decomposition) -> float:
        """Seasonal strength F_S = max(0, 1 - Var(R) / Var(S + R)).

        Returns a unitless score in [0, 1] (0 = no seasonality). NaN residuals
        at the series edges are skipped by pandas' ``var``.
        """
        resid = pd.Series(decomposition.resid)
        seasonal = pd.Series(decomposition.seasonal)
        denominator = (seasonal + resid).var()
        if denominator == 0:
            return 0.0
        return float(max(0.0, 1.0 - resid.var() / denominator))
class HypothesisTester(DataAnalyzer):
    """Statistical hypothesis testing"""

    def invoke(self, data_key: str, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
        """Compare ``value_col`` across the groups in ``group_col``.

        Two groups are compared with ``scipy.stats.ttest_ind`` (default
        arguments) plus Cohen's d. Three or more groups use one-way ANOVA
        (``f_oneway``) with no effect size.

        Args:
            data_key: Session-state key holding the DataFrame.
            group_col: Column whose distinct values define the groups.
            value_col: Column compared across groups.
            **kwargs: Ignored.

        Returns:
            A dict with ``test_type``, ``test_statistic``, ``p_value``,
            ``effect_size`` and ``interpretation``, or ``{"error": ...}``.
        """
        try:
            data = st.session_state[data_key]
            # Drop incomplete rows first. A NaN label would otherwise appear in
            # unique() as a group that never matches ==, and NaN values make
            # scipy return NaN statistics.
            complete = data[group_col].notna() & data[value_col].notna()
            subset = data.loc[complete]
            groups = subset[group_col].unique()
            if len(groups) < 2:
                return {"error": "Insufficient groups for comparison"}
            group_data = [subset.loc[subset[group_col] == g, value_col] for g in groups]
            if len(groups) == 2:
                stat, p = ttest_ind(*group_data)
                test_type = "Independent t-test"
                cohens_d = self._cohens_d(group_data[0], group_data[1])
            else:
                stat, p = f_oneway(*group_data)
                test_type = "ANOVA"
                cohens_d = None
            return {
                "test_type": test_type,
                "test_statistic": stat,
                "p_value": p,
                "effect_size": {
                    "cohens_d": cohens_d
                },
                "interpretation": self.interpret_p_value(p)
            }
        except Exception as e:
            return {"error": f"Hypothesis Testing Failed: {str(e)}"}

    @staticmethod
    def _cohens_d(a: pd.Series, b: pd.Series) -> float:
        """Cohen's d using the degrees-of-freedom-weighted pooled SD."""
        dof = len(a) + len(b) - 2
        if dof <= 0:
            return float("nan")
        # A one-observation group has NaN sample variance and contributes no
        # within-group variation; skip it instead of poisoning the sum.
        sum_sq = sum((len(g) - 1) * g.var() for g in (a, b) if len(g) > 1)
        pooled_sd = np.sqrt(sum_sq / dof)
        return float(abs(a.mean() - b.mean()) / pooled_sd)

    def interpret_p_value(self, p: float) -> str:
        """Translate a p-value into a verbal evidence-against-H0 label."""
        if np.isnan(p):
            return "Test could not be computed (p-value is NaN)"
        if p < 0.001: return "Very strong evidence against H0"
        elif p < 0.01: return "Strong evidence against H0"
        elif p < 0.05: return "Evidence against H0"
        elif p < 0.1: return "Weak evidence against H0"
        else: return "No significant evidence against H0"
class LogisticRegressionTrainer(DataAnalyzer):
    """Logistic Regression Model Trainer"""

    def invoke(self, data_key: str, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
        """Fit a logistic regression on an 80/20 split and report test accuracy.

        Args:
            data_key: Session-state key holding the DataFrame.
            target_col: Label column to predict.
            columns: Feature columns. ``target_col`` is removed if present,
                so the label cannot leak into the features.
            **kwargs: Ignored.

        Returns:
            ``{"model_type", "accuracy", "model_params", "rows_used"}``, or
            ``{"error": "Logistic Regression Model Error: ..."}``.
        """
        try:
            data = st.session_state[data_key]
            features = [col for col in columns if col != target_col]
            if not features:
                return {"error": "Logistic Regression Model Error: no feature columns "
                                 "remain after excluding the target column"}
            # LogisticRegression rejects NaN inputs, so train on complete rows only.
            model_data = data[features + [target_col]].dropna()
            X = model_data[features]
            y = model_data[target_col]
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
            model = LogisticRegression(max_iter=1000)
            model.fit(X_train, y_train)
            accuracy = accuracy_score(y_test, model.predict(X_test))
            return {
                "model_type": "Logistic Regression",
                "accuracy": accuracy,
                "model_params": model.get_params(),
                # Number of complete rows actually used (before the split).
                "rows_used": len(model_data)
            }
        except Exception as e:
            return {"error": f"Logistic Regression Model Error: {str(e)}"}
# ---------------------- Groq Research Agent --------------------------- | |
class GroqResearcher:
    """Advanced AI Research Engine using Groq"""

    def __init__(self, model_name="mixtral-8x7b-32768"):
        # NOTE(review): Groq retires model IDs over time; confirm this default
        # is still served.
        self.model_name = model_name
        # Filled with {dataset_info} and {query} in research().
        self.system_template = """You are a senior data scientist at a research institution.
Analyze this dataset with rigorous statistical methods and provide academic-quality insights:
{dataset_info}
User Question: {query}
Required Format:
- Executive Summary (1 paragraph)
- Methodology (bullet points)
- Key Findings (numbered list)
- Limitations
- Recommended Next Steps"""

    def research(self, query: str, data: pd.DataFrame) -> str:
        """Conduct academic-level analysis using Groq.

        Only a metadata summary of ``data`` is sent: shape, column names,
        datetime column names and per-column missing-value counts. The rows
        themselves are not included.

        Args:
            query: The user's research question.
            data: DataFrame to summarise for the model.

        Returns:
            The model's reply text ("" if it returned no content), or
            ``"Research Error: ..."`` if any step raises.
        """
        try:
            # str() every label: join() raises on non-string column names.
            variables = ', '.join(map(str, data.columns))
            # Plain ints so the prompt does not contain NumPy reprs.
            missing = {str(col): int(count) for col, count in data.isnull().sum().items()}
            dataset_info = f"""
Dataset Dimensions: {data.shape}
Variables: {variables}
Temporal Coverage: {data.select_dtypes(include='datetime').columns.tolist()}
Missing Values: {missing}
"""
            prompt = PromptTemplate.from_template(self.system_template).format(
                dataset_info=dataset_info,
                query=query
            )
            completion = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a research AI assistant"},
                    {"role": "user", "content": prompt}
                ],
                model=self.model_name,
                temperature=0.2,
                max_tokens=4096,
                stream=False
            )
            # content may be None; keep the declared str return type.
            return completion.choices[0].message.content or ""
        except Exception as e:
            return f"Research Error: {str(e)}"
# ---------------------- Main Streamlit Application --------------------------- | |
def main():
    """Streamlit entry point: dataset upload, automated analyses, LLM research.

    The loaded DataFrame (``data``) and the ``GroqResearcher`` are stored in
    ``st.session_state``. The analyzers read the DataFrame back via
    ``data_key="data"``.
    """
    st.set_page_config(page_title="AI Data Analysis Lab", layout="wide")
    st.title("🧪 Advanced AI Data Analysis Laboratory")

    # Session State
    if 'data' not in st.session_state:
        st.session_state.data = None
    if 'researcher' not in st.session_state:
        st.session_state.researcher = GroqResearcher()

    # Data Upload
    with st.sidebar:
        st.header("🔬 Data Management")
        uploaded_file = st.file_uploader("Upload research dataset", type=["csv", "parquet"])
        if uploaded_file:
            with st.spinner("Initializing dataset..."):
                try:
                    # The uploader accepts .parquet, which read_csv cannot parse.
                    if uploaded_file.name.lower().endswith(".parquet"):
                        st.session_state.data = pd.read_parquet(uploaded_file)
                    else:
                        st.session_state.data = pd.read_csv(uploaded_file)
                    st.success(f"Loaded {len(st.session_state.data):,} research observations")
                except Exception as e:
                    st.error(f"Error loading dataset: {e}")

    data = st.session_state.data
    if data is None:
        return

    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
    datetime_cols = data.select_dtypes(include=['datetime', 'datetimetz']).columns

    col1, col2 = st.columns([1, 3])
    with col1:
        st.subheader("Dataset Metadata")
        st.json({
            "Variables": list(data.columns),
            "Time Range": {
                col: {
                    "min": str(data[col].min()),
                    "max": str(data[col].max())
                } for col in datetime_cols
            },
            "Size": f"{data.memory_usage().sum() / 1e6:.2f} MB"
        })

    with col2:
        analysis_tab, research_tab = st.tabs(["Automated Analysis", "Custom Research"])
        with analysis_tab:
            analysis_type = st.selectbox("Select Analysis Mode", [
                "Exploratory Data Analysis",
                "Temporal Pattern Analysis",
                "Comparative Statistics",
                "Distribution Analysis",
                "Train Logistic Regression Model"
            ])
            if analysis_type == "Exploratory Data Analysis":
                analyzer = AdvancedEDA()
                eda_result = analyzer.invoke(data_key="data")
                st.subheader("Data Quality Report")
                st.json(eda_result)
            elif analysis_type == "Temporal Pattern Analysis":
                # read_csv loads dates as text, so text columns are offered too.
                # TemporalAnalyzer parses the chosen column with pd.to_datetime.
                time_options = data.select_dtypes(
                    include=['datetime', 'datetimetz', 'object', 'string']).columns
                time_col = st.selectbox("Temporal Variable", time_options)
                value_col = st.selectbox("Analysis Variable", numeric_cols)
                if time_col is not None and value_col is not None:
                    analyzer = TemporalAnalyzer()
                    result = analyzer.invoke(data_key="data", time_col=time_col, value_col=value_col)
                    if "visualization" in result:
                        st.image(f"data:image/png;base64,{result['visualization']}")
                    # Keep the large base64 payload out of the JSON viewer.
                    st.json({key: val for key, val in result.items() if key != "visualization"})
            elif analysis_type == "Comparative Statistics":
                # read_csv never produces 'category'; text and bool columns
                # are valid groupings as well.
                group_options = data.select_dtypes(
                    include=['category', 'object', 'string', 'bool']).columns
                group_col = st.selectbox("Grouping Variable", group_options)
                value_col = st.selectbox("Metric Variable", numeric_cols)
                if group_col is not None and value_col is not None:
                    analyzer = HypothesisTester()
                    result = analyzer.invoke(data_key="data", group_col=group_col, value_col=value_col)
                    st.subheader("Statistical Test Results")
                    st.json(result)
            elif analysis_type == "Distribution Analysis":
                selected_cols = st.multiselect("Select Variables", numeric_cols)
                if selected_cols:
                    analyzer = DistributionVisualizer()
                    img_data = analyzer.invoke(data_key="data", columns=selected_cols)
                    # On failure invoke() returns an error message, not base64.
                    if img_data.startswith("Visualization Error"):
                        st.error(img_data)
                    else:
                        st.image(f"data:image/png;base64,{img_data}")
            elif analysis_type == "Train Logistic Regression Model":
                target_col = st.selectbox("Select Target Variable", data.columns.tolist())
                # Offering the target as a feature would leak the label into X.
                feature_options = [col for col in numeric_cols if col != target_col]
                selected_cols = st.multiselect("Select Feature Variables", feature_options)
                if selected_cols and target_col is not None:
                    analyzer = LogisticRegressionTrainer()
                    result = analyzer.invoke(data_key="data", target_col=target_col, columns=selected_cols)
                    st.subheader("Logistic Regression Model Results")
                    st.json(result)
        with research_tab:
            research_query = st.text_area("Enter Research Question:", height=150,
                                          placeholder="E.g., 'What factors are most predictive of X outcome?'")
            if st.button("Execute Research"):
                if not (research_query or "").strip():
                    st.warning("Please enter a research question first.")
                else:
                    with st.spinner("Conducting rigorous analysis..."):
                        result = st.session_state.researcher.research(research_query, data)
                    st.markdown("## Research Findings")
                    st.markdown(result)


if __name__ == "__main__":
    main()