|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import os |
|
import base64 |
|
import io |
|
from groq import Groq |
|
from pydantic import BaseModel, Field |
|
from typing import Dict, List, Optional |
|
from langchain.tools import tool |
|
from langchain.agents import initialize_agent, AgentType |
|
from scipy.stats import ttest_ind, f_oneway |
|
from statsmodels.tsa.seasonal import seasonal_decompose |
|
from statsmodels.tsa.stattools import adfuller |
|
from langchain.prompts import PromptTemplate |
|
|
|
|
|
# Module-level Groq client; GroqResearcher.research sends its chat completions
# through this object. os.getenv yields None when GROQ_API_KEY is not set.
# NOTE(review): this runs at import time, so a missing/invalid key may surface
# before any UI is drawn — confirm how the groq SDK handles api_key=None.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
|
|
class ResearchInput(BaseModel):
    """Shared argument schema for the research tools.

    Each tool looks its DataFrame up in ``st.session_state`` using
    ``data_key``; ``columns`` is an optional list of column names.
    """

    # Required: key into st.session_state holding the DataFrame.
    data_key: str = Field(default=..., description="Session state key containing DataFrame")
    # Optional list of column names; defaults to None.
    columns: Optional[List[str]] = Field(default=None, description="List of columns to analyze")
|
|
|
class TemporalAnalysisInput(ResearchInput):
    """Argument schema for time-series analysis: one time column, one value column."""

    # Required: name of the column holding timestamps.
    time_col: str = Field(default=..., description="Name of timestamp column")
    # Required: name of the column whose values form the series.
    value_col: str = Field(default=..., description="Name of value column to analyze")
|
|
|
class HypothesisInput(ResearchInput):
    """Argument schema for group comparisons: a grouping column and a metric column."""

    # Required: column whose distinct values define the groups.
    group_col: str = Field(default=..., description="Categorical column defining groups")
    # Required: numeric column compared across those groups.
    value_col: str = Field(default=..., description="Numerical column to compare")
|
|
|
class GroqResearcher:
    """LLM research assistant that asks a Groq chat model questions about a DataFrame.

    The model receives a short textual profile of the dataset (shape, column
    names, datetime column names, missing-value counts) plus the user's
    question, and is asked to answer using a fixed report structure.
    """

    def __init__(self, model_name="mixtral-8x7b-32768"):
        """Store the model identifier and the prompt template.

        Args:
            model_name: Groq model id passed to ``client.chat.completions.create``.
        """
        # NOTE(review): Groq retires models over time; confirm this default is
        # still served before relying on it.
        self.model_name = model_name
        # Filled with ``dataset_info`` and ``query`` via PromptTemplate.
        self.system_template = (
            "You are a senior data scientist at a research institution.\n"
            "Analyze this dataset with rigorous statistical methods and provide "
            "academic-quality insights:\n"
            "{dataset_info}\n"
            "\n"
            "User Question: {query}\n"
            "\n"
            "Required Format:\n"
            "- Executive Summary (1 paragraph)\n"
            "- Methodology (bullet points)\n"
            "- Key Findings (numbered list)\n"
            "- Limitations\n"
            "- Recommended Next Steps"
        )

    @staticmethod
    def _describe_dataset(data: pd.DataFrame) -> str:
        """Summarise ``data`` as plain text (metadata only, no row values)."""
        # map(str, ...) so integer or other non-string column labels do not make
        # str.join raise TypeError.
        variables = ", ".join(map(str, data.columns))
        datetime_cols = data.select_dtypes(include="datetime").columns.tolist()
        return (
            f"Dataset Dimensions: {data.shape}\n"
            f"Variables: {variables}\n"
            f"Temporal Coverage: {datetime_cols}\n"
            f"Missing Values: {data.isnull().sum().to_dict()}\n"
        )

    def research(self, query: str, data: pd.DataFrame) -> str:
        """Ask the Groq model ``query`` about ``data`` and return its answer.

        Args:
            query: Free-text research question.
            data: DataFrame whose metadata (not its rows) is included in the prompt.

        Returns:
            The model's reply, or a string starting with ``"Research Error:"`` if
            building the prompt or calling the API raised.
        """
        try:
            prompt = PromptTemplate.from_template(self.system_template).format(
                dataset_info=self._describe_dataset(data),
                query=query,
            )
            completion = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a research AI assistant"},
                    {"role": "user", "content": prompt},
                ],
                model=self.model_name,
                temperature=0.2,
                max_tokens=4096,
                stream=False,
            )
            # content can be None (e.g. an empty completion); keep the str contract.
            return completion.choices[0].message.content or ""
        except Exception as e:
            return f"Research Error: {str(e)}"
|
|
|
@tool(args_schema=ResearchInput)
def advanced_eda(data_key: str) -> Dict:
    """Comprehensive Exploratory Data Analysis with Statistical Profiling

    Reports dimensions, memory footprint, ``describe()`` statistics, min/max of
    datetime columns, missing-value counts, duplicate-row count and per-column
    cardinality for the DataFrame stored at ``st.session_state[data_key]``.
    Returns ``{"error": ...}`` instead of raising.
    """
    try:
        data = st.session_state[data_key]
        datetime_cols = data.select_dtypes(include="datetime").columns
        return {
            "dimensionality": {
                "rows": len(data),
                "columns": list(data.columns),
                # deep=True counts the actual payload of object (e.g. string)
                # columns, which the shallow estimate leaves out.
                "memory_usage": f"{data.memory_usage(deep=True).sum() / 1e6:.2f} MB",
            },
            "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
            "temporal_analysis": {
                "date_ranges": {
                    # str(): Timestamps are not JSON-native.
                    col: {"min": str(data[col].min()), "max": str(data[col].max())}
                    for col in datetime_cols
                }
            },
            "data_quality": {
                "missing_values": data.isnull().sum().to_dict(),
                # int(): .sum() returns a numpy integer, which is not JSON-native.
                "duplicates": int(data.duplicated().sum()),
                "cardinality": {col: data[col].nunique() for col in data.columns},
            },
        }
    except Exception as e:
        return {"error": f"EDA Failed: {str(e)}"}
|
|
|
@tool(args_schema=ResearchInput)
def visualize_distributions(data_key: str, columns: Optional[List[str]] = None) -> str:
    """Generate publication-quality distribution visualizations

    Draws one histogram + KDE panel per entry in ``columns`` (side by side) for
    the DataFrame at ``st.session_state[data_key]`` and returns the PNG as a
    base64 string. On failure returns a string starting with
    ``"Visualization Error:"``.
    """
    # The schema makes `columns` optional; without any there is nothing to plot.
    if not columns:
        return "Visualization Error: no columns selected"

    fig = None
    try:
        data = st.session_state[data_key]
        # squeeze=False keeps `axes` 2-D even for a single column.
        fig, axes = plt.subplots(1, len(columns), figsize=(12, 6), squeeze=False)
        for ax, col in zip(axes[0], columns):
            sns.histplot(data[col], kde=True, stat="density", ax=ax)
            ax.set_title(f'Distribution of {col}', fontsize=10)
            ax.tick_params(axis="both", labelsize=8)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
        return base64.b64encode(buf.getvalue()).decode()
    except Exception as e:
        return f"Visualization Error: {str(e)}"
    finally:
        # Always release the figure; pyplot otherwise keeps it alive across reruns.
        if fig is not None:
            plt.close(fig)
|
|
|
@tool(args_schema=TemporalAnalysisInput)
def temporal_analysis(data_key: str, time_col: str, value_col: str) -> Dict:
    """Time Series Decomposition and Trend Analysis

    Builds a series of ``value_col`` indexed by ``pd.to_datetime(time_col)``,
    runs an additive ``seasonal_decompose`` with period 365 and an ADF test, and
    returns the statistics plus a base64 PNG of the decomposition plot.
    Returns ``{"error": ...}`` instead of raising.
    """
    # Fixed seasonal period in observations; presumably assumes daily data with
    # a yearly cycle — TODO confirm against the expected inputs.
    period = 365
    fig = None
    try:
        data = st.session_state[data_key]
        # seasonal_decompose rejects NaNs and works positionally, so drop gaps and
        # put the observations in chronological order first.
        ts_data = (
            data.set_index(pd.to_datetime(data[time_col]))[value_col]
            .dropna()
            .sort_index()
        )
        if len(ts_data) < 2 * period:
            return {
                "error": (
                    "Temporal Analysis Failed: need at least "
                    f"{2 * period} non-missing observations for period={period}, "
                    f"got {len(ts_data)}"
                )
            }

        decomposition = seasonal_decompose(ts_data, period=period)

        # DecomposeResult.plot() creates and returns its own figure.
        fig = decomposition.plot()
        fig.set_size_inches(12, 8)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        plot_data = base64.b64encode(buf.getvalue()).decode()

        # Strength of seasonality F_S = max(0, 1 - Var(R) / Var(S + R)), in [0, 1].
        # pandas .var() skips the NaN residuals at the series edges.
        resid = decomposition.resid
        detrended_var = (decomposition.seasonal + resid).var()
        strength = max(0.0, 1.0 - resid.var() / detrended_var) if detrended_var > 0 else 0.0

        return {
            "trend_statistics": {
                # ADF p-value: small values reject a unit root (evidence of stationarity).
                "stationarity": float(adfuller(ts_data)[1]),
                "seasonality_strength": float(strength),
            },
            "visualization": plot_data,
        }
    except Exception as e:
        return {"error": f"Temporal Analysis Failed: {str(e)}"}
    finally:
        if fig is not None:
            plt.close(fig)
|
|
|
@tool(args_schema=HypothesisInput)
def hypothesis_testing(data_key: str, group_col: str, value_col: str) -> Dict:
    """Compare a numeric column across the groups of a categorical column.

    Uses Student's independent t-test for exactly two groups (with Cohen's d)
    and one-way ANOVA for three or more. Rows missing either column are
    dropped. Normality and equal-variance assumptions are NOT checked.
    Returns ``{"error": ...}`` instead of raising.
    """
    try:
        data = st.session_state[data_key]
        # NaN labels would otherwise form an (empty) group under `==`, and NaN
        # values make scipy return a NaN statistic.
        subset = data[[group_col, value_col]].dropna()
        groups = subset[group_col].unique()

        if len(groups) < 2:
            return {"error": "Insufficient groups for comparison"}

        group_data = [subset.loc[subset[group_col] == g, value_col] for g in groups]
        if any(len(values) < 2 for values in group_data):
            return {"error": "Each group needs at least two non-missing observations"}

        cohens_d = None
        if len(groups) == 2:
            stat, p = ttest_ind(*group_data)
            test_type = "Independent t-test"
            a, b = group_data
            n_a, n_b = len(a), len(b)
            # Pooled SD weights each group's variance by its degrees of freedom.
            pooled_sd = np.sqrt(((n_a - 1) * a.var() + (n_b - 1) * b.var()) / (n_a + n_b - 2))
            # Undefined (None) when both groups are constant.
            if pooled_sd > 0:
                cohens_d = float(abs(a.mean() - b.mean()) / pooled_sd)
        else:
            stat, p = f_oneway(*group_data)
            test_type = "ANOVA"

        return {
            "test_type": test_type,
            # float(): scipy returns numpy scalars, which are not JSON-native.
            "test_statistic": float(stat),
            "p_value": float(p),
            "effect_size": {"cohens_d": cohens_d},
            "interpretation": interpret_p_value(p),
        }
    except Exception as e:
        return {"error": f"Hypothesis Testing Failed: {str(e)}"}
|
|
|
def interpret_p_value(p: float) -> str:
    """Describe in words how strongly a p-value argues against H0.

    Cut-offs are strict: p < 0.001, < 0.01, < 0.05, < 0.1, otherwise "no
    significant evidence". A NaN p-value (a test that could not be computed)
    is reported as undefined. Without this check it would fall through every
    comparison and be misread as "no evidence".
    """
    if np.isnan(p):
        return "Undefined p-value (test could not be computed)"
    thresholds = (
        (0.001, "Very strong evidence against H0"),
        (0.01, "Strong evidence against H0"),
        (0.05, "Evidence against H0"),
        (0.1, "Weak evidence against H0"),
    )
    for cutoff, label in thresholds:
        if p < cutoff:
            return label
    return "No significant evidence against H0"
|
|
|
def _load_dataset(uploaded_file) -> pd.DataFrame:
    """Parse an uploaded CSV or Parquet file into a DataFrame."""
    # Defensive rewind in case the buffer was already read on an earlier pass.
    uploaded_file.seek(0)
    # The uploader accepts both formats, so dispatch on the file name rather
    # than always parsing as CSV.
    if uploaded_file.name.lower().endswith(".parquet"):
        return pd.read_parquet(uploaded_file)
    return pd.read_csv(uploaded_file)


def _render_metadata(data: pd.DataFrame) -> None:
    """Show column names, datetime ranges and memory footprint of ``data``."""
    st.subheader("Dataset Metadata")
    st.json({
        "Variables": list(data.columns),
        "Time Range": {
            # str(): Timestamps are not JSON-native.
            col: {"min": str(data[col].min()), "max": str(data[col].max())}
            for col in data.select_dtypes(include='datetime').columns
        },
        # deep=True includes the payload of object (e.g. string) columns.
        "Size": f"{data.memory_usage(deep=True).sum() / 1e6:.2f} MB",
    })


def _render_analysis_tab(data: pd.DataFrame) -> None:
    """Let the user pick and run one of the built-in analysis tools."""
    analysis_type = st.selectbox("Select Analysis Mode", [
        "Exploratory Data Analysis",
        "Temporal Pattern Analysis",
        "Comparative Statistics",
        "Distribution Analysis",
    ])
    numeric_cols = data.select_dtypes(include=np.number).columns
    # CSV input never yields datetime64/category dtypes, so offer every
    # non-numeric column; temporal_analysis parses its time column itself.
    non_numeric_cols = data.select_dtypes(exclude=np.number).columns

    if analysis_type == "Exploratory Data Analysis":
        eda_result = advanced_eda.invoke({"data_key": "data"})
        st.subheader("Data Quality Report")
        st.json(eda_result)

    elif analysis_type == "Temporal Pattern Analysis":
        time_col = st.selectbox("Temporal Variable", non_numeric_cols)
        value_col = st.selectbox("Analysis Variable", numeric_cols)
        # `is not None`: a column literally named 0 is a valid selection.
        if time_col is not None and value_col is not None:
            result = temporal_analysis.invoke({
                "data_key": "data",
                "time_col": time_col,
                "value_col": value_col,
            })
            if "error" in result:
                st.error(result["error"])
            else:
                st.image(base64.b64decode(result["visualization"]))
                st.json(result["trend_statistics"])

    elif analysis_type == "Comparative Statistics":
        group_col = st.selectbox("Grouping Variable", non_numeric_cols)
        value_col = st.selectbox("Metric Variable", numeric_cols)
        if group_col is not None and value_col is not None:
            result = hypothesis_testing.invoke({
                "data_key": "data",
                "group_col": group_col,
                "value_col": value_col,
            })
            st.subheader("Statistical Test Results")
            st.json(result)

    elif analysis_type == "Distribution Analysis":
        selected_cols = st.multiselect("Select Variables", numeric_cols.tolist())
        if selected_cols:
            img_data = visualize_distributions.invoke({
                "data_key": "data",
                "columns": selected_cols,
            })
            # The tool reports failures as a plain string, not image data.
            if img_data.startswith("Visualization Error"):
                st.error(img_data)
            else:
                st.image(base64.b64decode(img_data))


def _render_research_tab(data: pd.DataFrame) -> None:
    """Send a free-text research question about ``data`` to the LLM researcher."""
    research_query = st.text_area(
        "Enter Research Question:",
        height=150,
        placeholder="E.g., 'What factors are most predictive of X outcome?'",
    )
    if not st.button("Execute Research"):
        return
    if not research_query.strip():
        st.warning("Please enter a research question first.")
        return
    with st.spinner("Conducting rigorous analysis..."):
        result = st.session_state.researcher.research(research_query, data)
    st.markdown("## Research Findings")
    st.markdown(result)


def main():
    """Streamlit entry point: upload a dataset, then run tools or ask the LLM."""
    st.set_page_config(page_title="AI Research Lab", layout="wide")
    st.title("🧪 Advanced AI Research Laboratory")

    # Session state persists across Streamlit reruns; initialise once.
    if 'data' not in st.session_state:
        st.session_state.data = None
    if 'researcher' not in st.session_state:
        st.session_state.researcher = GroqResearcher()

    with st.sidebar:
        st.header("🔬 Data Management")
        uploaded_file = st.file_uploader("Upload research dataset", type=["csv", "parquet"])
        if uploaded_file is not None:
            with st.spinner("Initializing dataset..."):
                try:
                    st.session_state.data = _load_dataset(uploaded_file)
                    st.success(f"Loaded {len(st.session_state.data):,} research observations")
                except Exception as e:
                    st.error(f"Error loading dataset: {e}")

    data = st.session_state.data
    if data is None:
        return

    col1, col2 = st.columns([1, 3])
    with col1:
        _render_metadata(data)
    with col2:
        analysis_tab, research_tab = st.tabs(["Automated Analysis", "Custom Research"])
        with analysis_tab:
            _render_analysis_tab(data)
        with research_tab:
            _render_research_tab(data)


if __name__ == "__main__":
    main()