import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import base64
import io

from groq import Groq
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
from langchain.tools import tool
from scipy.stats import ttest_ind, f_oneway
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.stattools import adfuller
from jinja2 import Template


client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

class ResearchInput(BaseModel):
    """Base schema for research tool inputs, ensuring type and description integrity."""
    data_key: str = Field(..., description="Session state key containing the DataFrame.")
    columns: Optional[List[str]] = Field(None, description="List of column names to analyze.")


class TemporalAnalysisInput(ResearchInput):
    """Schema for temporal analysis inputs, focusing on specific time-series requirements."""
    time_col: str = Field(..., description="Name of the column containing timestamp data.")
    value_col: str = Field(..., description="Name of the column containing numerical values to analyze.")


class HypothesisInput(ResearchInput):
    """Schema for hypothesis testing, demanding group and value specification for statistical rigor."""
    group_col: str = Field(..., description="Categorical column defining the groups for comparison.")
    value_col: str = Field(..., description="Numerical column for comparing means across groups.")

class GroqResearcher:
    """
    A sophisticated AI research engine powered by Groq, designed for rigorous academic-style analysis.
    This class handles complex data queries and delivers structured research outputs.
    """

    def __init__(self, model_name="mixtral-8x7b-32768"):
        self.model_name = model_name
        self.system_template = """
You are a senior data scientist at a prestigious research institution. Your analysis must
adhere to rigorous scientific standards. Consider the dataset properties and the user's query.

Dataset Context:
- Dimensions: {{ dataset_shape }}
- Variables: {{ dataset_variables }}
- Temporal Coverage: {{ temporal_coverage }}
- Missing Value Counts: {{ missing_values }}

User Inquiry: {{ query }}

Response Structure (Critical for all analyses):
1. **Executive Summary:** Provide a 1-2 paragraph overview of the findings, contextualized within the dataset's characteristics.
2. **Methodology:** Detail the exact analysis techniques used, including statistical tests or model types, and their justification.
3. **Key Findings:** Present the most significant observations and statistical results (p-values, effect sizes) with proper interpretation.
4. **Limitations:** Acknowledge and describe the constraints of the dataset or analytical methods that might affect the results' interpretation or generalizability.
5. **Recommended Next Steps:** Suggest future studies, experiments, or analyses that could extend the current investigation and address the noted limitations.
"""

    def research(self, query: str, data: pd.DataFrame) -> str:
        """Executes in-depth research using the Groq API to produce academic-quality analyses."""
        try:
            dataset_info = {
                "dataset_shape": str(data.shape),
                "dataset_variables": ", ".join(data.columns),
                "temporal_coverage": str(data.select_dtypes(include='datetime').columns.tolist()),
                "missing_values": str(data.isnull().sum().to_dict()),
            }

            prompt = Template(self.system_template).render(**dataset_info, query=query)

            completion = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a research AI assistant."},
                    {"role": "user", "content": prompt}
                ],
                model=self.model_name,
                temperature=0.2,
                max_tokens=4096,
                stream=False
            )
            return completion.choices[0].message.content
        except Exception as e:
            return f"Research Error Encountered: {str(e)}"

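# Illustrative only: a minimal sketch of using GroqResearcher outside the
# Streamlit app (assumes GROQ_API_KEY is set in the environment; the DataFrame
# and function name here are made-up examples). Not called anywhere in this module.
def _example_standalone_research() -> str:
    df = pd.DataFrame({
        "x": [1, 2, 3, 4, 5],
        "y": [2.1, 3.9, 6.2, 8.1, 9.8],
    })
    researcher = GroqResearcher()
    return researcher.research("Is y approximately linear in x?", df)
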
@tool(args_schema=ResearchInput)
def advanced_eda(data_key: str) -> Dict:
    """
    Performs a comprehensive Exploratory Data Analysis, including statistical profiling,
    temporal analysis of datetime columns, and detailed quality checks.
    """
    try:
        data = st.session_state[data_key]
        analysis = {
            "dimensionality": {
                "rows": int(len(data)),
                "columns": list(data.columns),
                "memory_usage": f"{data.memory_usage().sum() / 1e6:.2f} MB"
            },
            "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
            "temporal_analysis": {
                "date_ranges": {
                    col: {
                        "min": str(data[col].min()),
                        "max": str(data[col].max())
                    } for col in data.select_dtypes(include='datetime').columns
                }
            },
            "data_quality": {
                "missing_values": data.isnull().sum().to_dict(),
                "duplicates": int(data.duplicated().sum()),
                "cardinality": {
                    col: int(data[col].nunique()) for col in data.columns
                }
            }
        }
        return analysis
    except Exception as e:
        return {"error": f"Advanced EDA Failed: {str(e)}"}

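# Note: functions decorated with @tool become LangChain tool objects rather than
# plain callables, so they are invoked with a dict of arguments, as main() does
# below, e.g.:
#   advanced_eda.invoke({"data_key": "data"})
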
@tool(args_schema=ResearchInput)
def visualize_distributions(data_key: str, columns: List[str]) -> str:
    """
    Generates high-quality, publication-ready distribution visualizations (histograms with KDE)
    for selected numerical columns, and returns the image as a base64-encoded PNG string.
    """
    try:
        data = st.session_state[data_key]
        if not columns:
            return "Distribution Visualization Error: no columns selected."
        plt.figure(figsize=(15, 7))
        palette = sns.color_palette()
        for i, col in enumerate(columns, 1):
            plt.subplot(1, len(columns), i)
            sns.histplot(data[col], kde=True, stat="density", color=palette[i % len(palette)])
            plt.title(f'Distribution of {col}', fontsize=14, fontweight='bold')
            plt.xlabel(col, fontsize=12)
            plt.ylabel('Density', fontsize=12)
            plt.xticks(fontsize=10)
            plt.yticks(fontsize=10)
            plt.grid(axis='y', linestyle='--')
            sns.despine(top=True, right=True)
        plt.tight_layout(pad=2)

        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=300, bbox_inches='tight')
        plt.close()
        return base64.b64encode(buf.getvalue()).decode()
    except Exception as e:
        return f"Distribution Visualization Error: {str(e)}"

@tool(args_schema=TemporalAnalysisInput)
def temporal_analysis(data_key: str, time_col: str, value_col: str) -> Dict:
    """
    Performs a sophisticated time series analysis, including decomposition and trend assessment,
    providing both statistical insights and a visual representation.
    """
    try:
        data = st.session_state[data_key]
        ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col].dropna().sort_index()

        if ts_data.empty:
            return {"error": "No valid time series data found for analysis after NaN removal."}
        if len(ts_data) < 4:
            return {"error": "At least 4 observations are required for seasonal decomposition."}

        # Heuristic period: at most one year of observations, and never more than
        # half the series length, since seasonal_decompose needs two complete cycles.
        period = min(365, max(2, len(ts_data) // 2))
        decomposition = seasonal_decompose(ts_data, model='additive', period=period)

        # decomposition.plot() creates its own figure; resize that figure before
        # saving (a prior bare plt.figure() call would be saved empty instead).
        fig = decomposition.plot()
        fig.set_size_inches(16, 10)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=300)
        plt.close(fig)
        plot_data = base64.b64encode(buf.getvalue()).decode()

        # The ADF null hypothesis is non-stationarity (a unit root), so a small
        # p-value here is evidence that the series IS stationary.
        adf_result = adfuller(ts_data)
        stationarity_p_value = adf_result[1]

        return {
            "trend_statistics": {
                "stationarity": stationarity_p_value,
                "stationarity_interpretation": interpret_p_value(stationarity_p_value),
                "seasonality_strength": float(decomposition.seasonal.max())
            },
            "visualization": plot_data,
            "decomposition_data": {
                # Stringify the timestamp keys so the result stays JSON-serializable.
                "trend": {str(k): float(v) for k, v in decomposition.trend.dropna().items()},
                "seasonal": {str(k): float(v) for k, v in decomposition.seasonal.dropna().items()},
                "residual": {str(k): float(v) for k, v in decomposition.resid.dropna().items()},
            }
        }
    except Exception as e:
        return {"error": f"Temporal Analysis Failure: {str(e)}"}

@tool(args_schema=HypothesisInput)
def hypothesis_testing(data_key: str, group_col: str, value_col: str) -> Dict:
    """
    Conducts statistical hypothesis testing, providing detailed test results, effect size measures,
    and interpretations for both t-tests and ANOVAs.
    """
    try:
        data = st.session_state[data_key]
        groups = data[group_col].dropna().unique()

        if len(groups) < 2:
            return {"error": "Insufficient groups for comparison. Must have at least two groups."}

        group_data = [data[data[group_col] == g][value_col].dropna() for g in groups]

        if any(len(group) < 2 for group in group_data):
            return {"error": "Each group must have at least two data points for testing."}

        if len(groups) == 2:
            stat, p = ttest_ind(*group_data)
            test_type = "Independent t-test"
        else:
            stat, p = f_oneway(*group_data)
            test_type = "ANOVA"

        effect_size = None
        if len(groups) == 2:
            # Simplified pooled spread: the square root of the mean of the two
            # group variances (matches the classic pooled SD only for equal n;
            # see the df-weighted reference sketch below).
            pooled_std = np.sqrt((group_data[0].var() + group_data[1].var()) / 2)
            if pooled_std != 0:
                cohens_d = abs(group_data[0].mean() - group_data[1].mean()) / pooled_std
                effect_size = {"cohens_d": cohens_d}
            else:
                effect_size = {"cohens_d": None, "error": "Cannot compute effect size due to zero pooled variance."}

        return {
            "test_type": test_type,
            "test_statistic": float(stat),
            "p_value": float(p),
            "effect_size": effect_size,
            "interpretation": interpret_p_value(p),
            "group_means": {str(g): float(data[data[group_col] == g][value_col].mean()) for g in groups}
        }
    except Exception as e:
        return {"error": f"Hypothesis Testing Failed: {str(e)}"}

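# Reference sketch (not used by the app): the df-weighted pooled-SD form of
# Cohen's d, which generalizes the simplified formula above to unequal group
# sizes. The function name is an illustrative assumption.
def _cohens_d_pooled(a: pd.Series, b: pd.Series) -> float:
    n1, n2 = len(a), len(b)
    # Pool the variances weighted by degrees of freedom, then take the root.
    pooled_var = ((n1 - 1) * a.var(ddof=1) + (n2 - 1) * b.var(ddof=1)) / (n1 + n2 - 2)
    return abs(a.mean() - b.mean()) / np.sqrt(pooled_var)
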
def interpret_p_value(p: float) -> str:
    """Provides nuanced interpretations of p-values, including qualitative descriptors."""
    if p < 0.001:
        return "Highly significant evidence against the null hypothesis (p < 0.001)."
    elif p < 0.01:
        return "Strong evidence against the null hypothesis (0.001 ≤ p < 0.01)."
    elif p < 0.05:
        return "Moderate evidence against the null hypothesis (0.01 ≤ p < 0.05)."
    elif p < 0.1:
        return "Weak evidence against the null hypothesis (0.05 ≤ p < 0.1)."
    else:
        return "No significant evidence against the null hypothesis (p ≥ 0.1)."

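# Illustration of the bands above (hypothetical inputs):
#   interpret_p_value(0.0004) -> "Highly significant evidence ..."
#   interpret_p_value(0.03)   -> "Moderate evidence ..."
#   interpret_p_value(0.42)   -> "No significant evidence ..."
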
def main():
    st.set_page_config(page_title="AI Research Lab", layout="wide")
    st.title("🧪 Advanced AI Research Laboratory")

    if 'data' not in st.session_state:
        st.session_state.data = None
    if 'researcher' not in st.session_state:
        st.session_state.researcher = GroqResearcher()

    with st.sidebar:
        st.header("🔬 Data Management")
        uploaded_file = st.file_uploader("Upload research dataset", type=["csv", "parquet"])
        if uploaded_file:
            with st.spinner("Initializing dataset..."):
                try:
                    # Dispatch on extension so the advertised Parquet support works.
                    if uploaded_file.name.endswith(".parquet"):
                        st.session_state.data = pd.read_parquet(uploaded_file)
                    else:
                        st.session_state.data = pd.read_csv(uploaded_file)
                    st.success(f"Loaded {len(st.session_state.data):,} research observations")
                except Exception as e:
                    st.error(f"Error loading the dataset. Please ensure it's a valid CSV or Parquet file. Error details: {e}")

    if st.session_state.data is not None:
        col1, col2 = st.columns([1, 3])

        with col1:
            st.subheader("Dataset Metadata")
            datetime_cols = st.session_state.data.select_dtypes(include='datetime').columns.tolist()
            st.json({
                "Variables": list(st.session_state.data.columns),
                "Time Range": {
                    col: {
                        "min": str(st.session_state.data[col].min()),
                        "max": str(st.session_state.data[col].max())
                    } for col in datetime_cols
                } if datetime_cols else "No Temporal Data",
                "Size": f"{st.session_state.data.memory_usage().sum() / 1e6:.2f} MB"
            })

        with col2:
            analysis_tab, research_tab = st.tabs(["Automated Analysis", "Custom Research"])

            with analysis_tab:
                analysis_type = st.selectbox("Select Analysis Mode", [
                    "Exploratory Data Analysis",
                    "Temporal Pattern Analysis",
                    "Comparative Statistics",
                    "Distribution Analysis"
                ])

                if analysis_type == "Exploratory Data Analysis":
                    eda_result = advanced_eda.invoke({"data_key": "data"})
                    st.subheader("Data Quality Report")
                    st.json(eda_result)

                elif analysis_type == "Temporal Pattern Analysis":
                    time_cols = st.session_state.data.select_dtypes(include='datetime').columns.tolist()
                    if not time_cols:
                        st.warning("No datetime columns detected. Please ensure you have a datetime column for this analysis.")
                    else:
                        time_col = st.selectbox("Temporal Variable", time_cols)
                        value_col = st.selectbox("Analysis Variable",
                                                 st.session_state.data.select_dtypes(include=np.number).columns)

                        if time_col and value_col:
                            result = temporal_analysis.invoke({
                                "data_key": "data",
                                "time_col": time_col,
                                "value_col": value_col
                            })
                            if "visualization" in result:
                                st.image(f"data:image/png;base64,{result['visualization']}",
                                         use_column_width=True)
                            st.json(result)

                elif analysis_type == "Comparative Statistics":
                    cat_cols = (st.session_state.data.select_dtypes(include='category').columns.tolist()
                                + st.session_state.data.select_dtypes(include='object').columns.tolist())
                    if not cat_cols:
                        st.warning("No categorical columns detected. Please ensure you have a categorical column for this analysis.")
                    else:
                        group_col = st.selectbox("Grouping Variable", cat_cols)
                        value_col = st.selectbox("Metric Variable",
                                                 st.session_state.data.select_dtypes(include=np.number).columns)

                        if group_col and value_col:
                            result = hypothesis_testing.invoke({
                                "data_key": "data",
                                "group_col": group_col,
                                "value_col": value_col
                            })
                            st.subheader("Statistical Test Results")
                            st.json(result)

                elif analysis_type == "Distribution Analysis":
                    num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
                    selected_cols = st.multiselect("Select Variables", num_cols)
                    if selected_cols:
                        img_data = visualize_distributions.invoke({
                            "data_key": "data",
                            "columns": selected_cols
                        })
                        st.image(f"data:image/png;base64,{img_data}",
                                 use_column_width=True)

            with research_tab:
                research_query = st.text_area("Enter Research Question:", height=150,
                                              placeholder="E.g., 'What factors are most predictive of X outcome?'")

                if st.button("Execute Research"):
                    if not research_query.strip():
                        st.warning("Please enter a research question before executing.")
                    else:
                        with st.spinner("Conducting rigorous analysis..."):
                            result = st.session_state.researcher.research(
                                research_query, st.session_state.data
                            )
                        st.markdown("## Research Findings")
                        st.markdown(result)


if __name__ == "__main__":
    main()