diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..b4adeeb5e105e13a0bb0211536d60b18551d0219
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,37 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..0e74110c4db07af2c4d2a83d0d2adc844bee3dd6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,217 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be added to the global gitignore or merged into this project gitignore. For a PyCharm
+# project, it is recommended to include directory-based project settings:
+.idea/
+
+# VS Code
+.vscode/
+
+# Data files
+*.csv
+*.xlsx
+*.xls
+*.json
+*.parquet
+*.pickle
+*.pkl
+*.h5
+*.hdf5
+
+# Model files
+*.model
+*.joblib
+*.sav
+
+# Output directories
+outputs/
+results/
+logs/
+checkpoints/
+wandb/
+
+# Temporary files
+*.tmp
+*.temp
+*~
+
+# OS generated files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+# LLM API keys and secrets
+.env.local
+.env.production
+secrets.json
+api_keys.txt
+
+# Experiment tracking
+mlruns/
+.mlflow/
+
+# Large files (adjust sizes as needed)
+*.zip
+*.tar.gz
+*.rar
+
+# Project specific
+tests/output/
\ No newline at end of file
diff --git a/README copy.md b/README copy.md
new file mode 100644
index 0000000000000000000000000000000000000000..05685b5ab0e610da3b04c92dc67ba02a89f07aec
--- /dev/null
+++ b/README copy.md
@@ -0,0 +1,198 @@
+
+
+
+Causal AI Scientist: Facilitating Causal Data Science with
+Large Language Models
+
+
+
+**Causal AI Scientist (CAIS)** is an LLM-powered tool for generating data-driven answers to natural language causal queries. It takes a natural language query (for example, "Does participating in a job training program lead to higher income?"), an accompanying dataset, and the corresponding description as inputs. CAIS then frames a suitable causal estimation problem by selecting appropriate treatment and outcome variables. It finds the suitable method for causal effect estimation, implements it, runs diagnostic tests, and finally interprets the numerical results in the context of the original query.
+
+This repo includes instructions on both using the tool to perform causal analysis on a dataset of interest and reproducing results from our paper.
+
+**Note** : This repository is a work in progress and will be updated with additional instructions and files.
+
+
+
+## Getting Started
+
+#### 🔧 Environment Installation
+
+
+**Prerequisites:**
+- **Python 3.10** (create a new conda environment first)
+- Required Python libraries (specified in `requirements.txt`)
+
+
+**Step 1: Copy the example configuration**
+```bash
+cp .env.example .env
+```
+
+**Step 2: Create Python 3.10 environment**
+```bash
+# Create a new conda environment with Python 3.10
+conda create -n auto_causal python=3.10
+conda activate auto_causal
+pip install -r requirement.txt
+```
+
+**Step3: Setup auto_causal library**
+```bash
+pip install -e .
+```
+
+## Dataset Information
+
+All datasets used to evaluate CAIs and the baseline models are available in the data/ directory. Specifically:
+
+* `all_data`: Folder containing all CSV files from the QRData and real-world study collections.
+* `synthetic_data`: Folder containing all CSV files corresponding to synthetic datasets.
+* `qr_info.csv`: Metadata for QRData files. For each file, this includes the filename, description, causal query, reference causal effect, intended inference method, and additional remarks.
+* `real_info.csv`: Metadata for the real-world datasets.
+* `synthetic_info.csv`: Metadata for the synthetic datasets.
+
+## Run
+To execute CAIS, run
+```python
+python main/run_cais.py \
+ --metadata_path {path_to_metadata} \
+ --data_dir {path_to_data_folder} \
+ --output_dir {output_folder} \
+ --output_name {output_filename} \
+ --llm_name {llm_name}
+```
+Args:
+
+* metadata_path (str): Path to the CSV file containing the queries, dataset descriptions, and data file names
+* data_dir (str): Path to the folder containing the data in CSV format
+* output_dir (str): Path to the folder where the output JSON results will be saved
+* output_name (str): Name of the JSON file where the outputs will be saved
+* llm_name (str): Name of the LLM to be used (e.g., 'gpt-4', 'claude-3', etc.)
+
+A specific example,
+```python
+python main/run_cais.py \
+ --metadata_path "data/qr_info.csv" \
+ --data_dir "data/all_data" \
+ --output_dir "output" \
+ --output_name "results_qr_4o" \
+ --llm_name "gpt-4o-mini"
+```
+
+
+## Reproducing paper results
+**Will be updated soon**
+
+**⚠️ Important Notes:**
+- Keep your `.env` file secure and never commit it to version control
+
+## License
+
+Distributed under the MIT License. See `LICENSE` for more information.
+
+
+
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..750300d277fa9d22d58e95326938d09be535250e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+---
+title: Causal AI Scientist
+emoji: 🌍
+colorFrom: green
+colorTo: pink
+sdk: gradio
+sdk_version: 5.41.1
+app_file: app.py
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..11d9f48d1bbbfb483ce565b2fa2692bd7f806092
--- /dev/null
+++ b/app.py
@@ -0,0 +1,339 @@
+import os
+import sys
+import json
+from pathlib import Path
+import gradio as gr
+import time
+
+# Make your repo importable (expecting a folder named causal-agent at repo root)
+sys.path.append(str(Path(__file__).parent / "causal-agent"))
+
+from auto_causal.agent import run_causal_analysis # uses env for provider/model
+
+# -------- LLM config (OpenAI only; key via HF Secrets) --------
+os.environ.setdefault("LLM_PROVIDER", "openai")
+os.environ.setdefault("LLM_MODEL", "gpt-4o")
+
+# Lazy import to avoid import-time errors if key missing
+def _get_openai_client():
+ if os.getenv("LLM_PROVIDER", "openai") != "openai":
+ raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
+ if not os.getenv("OPENAI_API_KEY"):
+ raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
+ try:
+ # OpenAI SDK v1+
+ from openai import OpenAI
+ return OpenAI()
+ except Exception as e:
+ raise RuntimeError(f"OpenAI SDK not available: {e}")
+
+# -------- System prompt you asked for (verbatim) --------
+SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
+You will be given:
+1) The original research question.
+2) The analysis method used.
+3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
+4) A brief dataset description.
+
+Your task is to produce a clear, concise, and non-technical summary that:
+- Directly answers the research question.
+- States whether the effect is statistically significant.
+- Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
+- Mentions the method used in one sentence.
+- Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
+
+Formatting rules:
+- Use bullet points or short paragraphs.
+- Report effect sizes to two decimal places.
+- Clearly state the interpretation in plain English without technical jargon.
+
+Example Output Structure:
+- **Method:** [Name of method + 1-line rationale]
+- **Key Finding:** [Main answer to the research question]
+- **Details:**
+ - [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 — [Significance comment]
+ - …
+- **Rank Order of Effects:** [Largest → Smallest]
+"""
+
+def _extract_minimal_payload(agent_result: dict) -> dict:
+ """
+ Extract the minimal, LLM-friendly payload from run_causal_analysis output.
+ Falls back gracefully if any fields are missing.
+ """
+ # Try both top-level and nested (your JSON showed both patterns)
+ res = agent_result or {}
+ results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
+ inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
+ vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {}
+ dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}
+
+ # Pull best-available fields
+ question = (
+ results.get("original_query")
+ or dataset_analysis.get("original_query")
+ or res.get("query")
+ or "N/A"
+ )
+ method = (
+ inner.get("method_used")
+ or res.get("method_used")
+ or results.get("method_used")
+ or "N/A"
+ )
+
+ effect_estimate = (
+ inner.get("effect_estimate")
+ or res.get("effect_estimate")
+ or {}
+ )
+ confidence_interval = (
+ inner.get("confidence_interval")
+ or res.get("confidence_interval")
+ or {}
+ )
+ standard_error = (
+ inner.get("standard_error")
+ or res.get("standard_error")
+ or {}
+ )
+ p_value = (
+ inner.get("p_value")
+ or res.get("p_value")
+ or {}
+ )
+
+ dataset_desc = (
+ results.get("dataset_description")
+ or res.get("dataset_description")
+ or "N/A"
+ )
+
+ return {
+ "original_question": question,
+ "method_used": method,
+ "estimates": {
+ "effect_estimate": effect_estimate,
+ "confidence_interval": confidence_interval,
+ "standard_error": standard_error,
+ "p_value": p_value,
+ },
+ "dataset_description": dataset_desc,
+ }
+
+def _format_effects_md(effect_estimate: dict) -> str:
+ """
+ Minimal human-readable view of effect estimates for display.
+ """
+ if not effect_estimate or not isinstance(effect_estimate, dict):
+ return "_No effect estimates found._"
+ # Render as bullet list
+ lines = []
+ for k, v in effect_estimate.items():
+ try:
+ lines.append(f"- **{k}**: {float(v):+.4f}")
+ except Exception:
+ lines.append(f"- **{k}**: {v}")
+ return "\n".join(lines)
+
+def _summarize_with_llm(payload: dict) -> str:
+ """
+ Calls OpenAI with the provided SYSTEM_PROMPT and the JSON payload as the user message.
+ Returns the model's text, or raises on error.
+ """
+ client = _get_openai_client()
+ model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")
+
+ user_content = (
+ "Summarize the following causal analysis results:\n\n"
+ + json.dumps(payload, indent=2, ensure_ascii=False)
+ )
+
+ # Use Chat Completions for broad compatibility
+ resp = client.chat.completions.create(
+ model=model_name,
+ messages=[
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": user_content},
+ ],
+ temperature=0
+ )
+ text = resp.choices[0].message.content.strip()
+ return text
+
+def run_agent(query: str, csv_path: str, dataset_description: str):
+ """
+ Modified to use yield for progressive updates and immediate feedback
+ """
+ # Immediate feedback - show processing has started
+ processing_html = """
+
+
🔄 Analysis in Progress...
+
This may take 1-2 minutes depending on dataset size
+
+ """
+
+ yield (
+ processing_html, # method_out
+ processing_html, # effects_out
+ processing_html, # explanation_out
+ {"status": "Processing started..."} # raw_results
+ )
+
+ # Input validation
+ if not os.getenv("OPENAI_API_KEY"):
+ error_html = "⚠️ Set a Space Secret named OPENAI_API_KEY
"
+ yield (error_html, "", "", {})
+ return
+
+ if not csv_path:
+ error_html = "Please upload a CSV dataset.
"
+ yield (error_html, "", "", {})
+ return
+
+ try:
+ # Update status to show causal analysis is running
+ analysis_html = """
+
+
📊 Running Causal Analysis...
+
Analyzing dataset and selecting optimal method
+
+ """
+
+ yield (
+ analysis_html,
+ analysis_html,
+ analysis_html,
+ {"status": "Running causal analysis..."}
+ )
+
+ result = run_causal_analysis(
+ query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
+ dataset_path=csv_path,
+ dataset_description=(dataset_description or "").strip(),
+ )
+
+ # Update to show LLM summarization step
+ llm_html = """
+
+
🤖 Generating Summary...
+
Creating human-readable interpretation
+
+ """
+
+ yield (
+ llm_html,
+ llm_html,
+ llm_html,
+ {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}
+ )
+
+ except Exception as e:
+ error_html = f"❌ Error: {e}
"
+ yield (error_html, "", "", {})
+ return
+
+ try:
+ payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
+ method = payload.get("method_used", "N/A")
+
+ # Format method output with simple styling
+ method_html = f"""
+
+
Selected Method
+
{method}
+
+ """
+
+ # Format effects with simple styling
+ effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
+ if effect_estimate:
+ effects_html = ""
+ effects_html += "
Effect Estimates
"
+ # for k, v in effect_estimate.items():
+ # try:
+ # value = f"{float(v):+.4f}"
+ # effects_html += f"
{k}: {value}
"
+ # except:
+ effects_html += f"
{effect_estimate}
"
+ effects_html += "
"
+ else:
+ effects_html = "No effect estimates found
"
+
+ # Generate explanation and format it
+ try:
+ explanation = _summarize_with_llm(payload)
+ explanation_html = f"""
+
+
Detailed Explanation
+
{explanation}
+
+ """
+ except Exception as e:
+ explanation_html = f"⚠️ LLM summary failed: {e}
"
+
+ except Exception as e:
+ error_html = f"❌ Failed to parse results: {e}
"
+ yield (error_html, "", "", {})
+ return
+
+ # Final result
+ yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
+
+with gr.Blocks() as demo:
+ gr.Markdown("# Causal Agent")
+ gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.")
+
+ with gr.Row():
+ query = gr.Textbox(
+ label="Your causal question (natural language)",
+ placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
+ lines=2,
+ )
+
+ with gr.Row():
+ csv_file = gr.File(
+ label="Dataset (CSV)",
+ file_types=[".csv"],
+ type="filepath"
+ )
+
+ dataset_description = gr.Textbox(
+ label="Dataset description (optional)",
+ placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
+ lines=4,
+ )
+
+ run_btn = gr.Button("Run analysis", variant="primary")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ method_out = gr.HTML(label="Selected Method")
+ with gr.Column(scale=1):
+ effects_out = gr.HTML(label="Effect Estimates")
+
+ with gr.Row():
+ explanation_out = gr.HTML(label="Detailed Explanation")
+
+ # Add the collapsible raw results section
+ with gr.Accordion("Raw Results (Advanced)", open=False):
+ raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
+
+ run_btn.click(
+ fn=run_agent,
+ inputs=[query, csv_file, dataset_description],
+ outputs=[method_out, effects_out, explanation_out, raw_results],
+ show_progress=True
+ )
+
+ gr.Markdown(
+ """
+ **Tips:**
+ - Be specific about your treatment, outcome, and control variables
+ - Include relevant context in the dataset description
+ - The analysis may take 1-2 minutes for complex datasets
+ """
+ )
+
+if __name__ == "__main__":
+ demo.queue().launch()
\ No newline at end of file
diff --git a/auto_causal/__init__.py b/auto_causal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb430ad36e2fbb6583e3bffd06d4c50fd82ca3b
--- /dev/null
+++ b/auto_causal/__init__.py
@@ -0,0 +1,50 @@
+"""
+Auto Causal module for causal inference.
+
+This module provides automated causal inference capabilities
+through a pipeline that selects and applies appropriate causal methods.
+"""
+
+__version__ = "0.1.0"
+
+# Import components
+from auto_causal.components import (
+ parse_input,
+ analyze_dataset,
+ interpret_query,
+ validate_method,
+ generate_explanation,
+ format_output,
+ create_workflow_state_update
+)
+
+# Import tools
+from auto_causal.tools import (
+ input_parser_tool,
+ dataset_analyzer_tool,
+ query_interpreter_tool,
+ method_selector_tool,
+ method_validator_tool,
+ method_executor_tool,
+ explanation_generator_tool,
+ output_formatter_tool
+)
+
+# Import the main agent function
+from .agent import run_causal_analysis
+
+# Remove backward compatibility for old pipeline
+# try:
+# from .pipeline import CausalInferencePipeline
+# except ImportError:
+# # Define a placeholder class if the old pipeline doesn't exist
+# class CausalInferencePipeline:
+# """Placeholder for CausalInferencePipeline."""
+#
+# def __init__(self, *args, **kwargs):
+# pass
+
+# Update __all__ to export the main function
+__all__ = [
+ 'run_causal_analysis'
+]
diff --git a/auto_causal/agent.py b/auto_causal/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c92a56629c7454dd8f4a6b4188d1f9d26f10dc0
--- /dev/null
+++ b/auto_causal/agent.py
@@ -0,0 +1,394 @@
+"""
+LangChain agent for the auto_causal module.
+
+This module configures a LangChain agent with specialized tools for causal inference,
+allowing for an interactive approach to analyzing datasets and applying appropriate
+causal inference methods.
+"""
+
+import logging
+from typing import Dict, List, Any, Optional
+from langchain.agents.react.agent import create_react_agent
+from langchain.agents import AgentExecutor, create_structured_chat_agent, create_tool_calling_agent
+from langchain.chains.conversation.memory import ConversationBufferMemory
+from langchain_core.messages import SystemMessage, HumanMessage
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
+from langchain.tools import tool
+# Import the callback handler
+from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
+# Import tool rendering utility
+from langchain.tools.render import render_text_description
+# Import LCEL components
+from langchain.agents.format_scratchpad.tools import format_to_tool_messages
+from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.language_models import BaseChatModel
+from langchain_anthropic.chat_models import convert_to_anthropic_tool
+import os
+# Import actual tools from the tools directory
+from auto_causal.tools.input_parser_tool import input_parser_tool
+from auto_causal.tools.dataset_analyzer_tool import dataset_analyzer_tool
+from auto_causal.tools.query_interpreter_tool import query_interpreter_tool
+from auto_causal.tools.method_selector_tool import method_selector_tool
+from auto_causal.tools.method_validator_tool import method_validator_tool
+from auto_causal.tools.method_executor_tool import method_executor_tool
+from auto_causal.tools.explanation_generator_tool import explanation_generator_tool
+from auto_causal.tools.output_formatter_tool import output_formatter_tool
+#from auto_causal.prompts import SYSTEM_PROMPT # Assuming SYSTEM_PROMPT is defined here or imported
+from langchain_core.output_parsers import StrOutputParser
+# Import the centralized factory function
+from .config import get_llm_client
+#from .prompts import SYSTEM_PROMPT
+from langchain_core.messages import AIMessage, AIMessageChunk
+import re
+import json
+from typing import Union
+from langchain_core.output_parsers import BaseOutputParser
+from langchain.schema import AgentAction, AgentFinish
+from langchain_anthropic.output_parsers import ToolsOutputParser
+from langchain.agents.react.output_parser import ReActOutputParser
+from langchain.agents import AgentOutputParser
+from langchain.agents.agent import AgentAction, AgentFinish, OutputParserException
+import re
+from typing import Union, List
+from auto_causal.models import *
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.exceptions import OutputParserException
+
+from langchain.agents.agent import AgentOutputParser
+from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
+
+FINAL_ANSWER_ACTION = "Final Answer:"
+MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action:' after 'Thought:'"
+)
+MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action Input:' after 'Action:'"
+)
+FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
+ "Parsing LLM output produced both a final answer and parse-able actions"
+)
+
+
+class ReActMultiInputOutputParser(AgentOutputParser):
+ """Parses ReAct-style output that may contain multiple tool calls."""
+
+ def get_format_instructions(self) -> str:
+ # You can reuse the original FORMAT_INSTRUCTIONS,
+ # but let the model know it may emit multiple actions.
+ return FORMAT_INSTRUCTIONS + (
+ "\n\nIf you need to call more than one tool, simply repeat:\n"
+ "Action: \n"
+ "Action Input: \n"
+ "…for each tool in sequence."
+ )
+
+ @property
+ def _type(self) -> str:
+ return "react-multi-input"
+
+ def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+ includes_answer = FINAL_ANSWER_ACTION in text
+ print('-------------------')
+ print(text)
+ print('-------------------')
+ # Grab every Action / Action Input block
+ pattern = (
+ r"Action\s*\d*\s*:[\s]*(.*?)\s*"
+ r"Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*?)(?=(?:Action\s*\d*\s*:|$))"
+ )
+ matches = list(re.finditer(pattern, text, re.DOTALL))
+
+ # If we found tool calls…
+ if matches:
+ if includes_answer:
+ # both a final answer *and* tool calls is ambiguous
+ raise OutputParserException(
+ f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
+ )
+
+ actions: List[AgentAction] = []
+ for m in matches:
+ tool_name = m.group(1).strip()
+ tool_input = m.group(2).strip().strip('"')
+ print('\n--------------------------')
+ print(tool_input)
+ print('--------------------------')
+ actions.append(AgentAction(tool_name, json.loads(tool_input), text))
+
+ return actions
+
+ # Otherwise, if there's a final answer, finish
+ if includes_answer:
+ answer = text.split(FINAL_ANSWER_ACTION, 1)[1].strip()
+ return AgentFinish({"output": answer}, text)
+
+ # No calls and no final answer → figure out which error to throw
+ if not re.search(r"Action\s*\d*\s*Input\s*\d*:", text):
+ raise OutputParserException(
+ f"Could not parse LLM output: `{text}`",
+ observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
+ llm_output=text,
+ send_to_llm=True,
+ )
+
+ # Fallback
+ raise OutputParserException(f"Could not parse LLM output: `{text}`")
+
+# Set up basic logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+def create_agent_prompt(tools: List[tool]) -> ChatPromptTemplate:
+ """Create the prompt template for the causal inference agent, emphasizing workflow and data handoff.
+ (This is the version required by the LCEL agent structure below)
+ """
+ # Get the tool descriptions
+ tool_description = render_text_description(tools)
+ tool_names = ", ".join([t.name for t in tools])
+
+ # Define the system prompt template string
+ system_template = """
+You are a causal inference expert helping users answer causal questions by following a strict workflow using specialized tools.
+
+Remember you always have to always generate the Thought, Action and Action Input block.
+TOOLS:
+------
+You have access to the following tools:
+
+{tools}
+
+To use a tool, please use the following format:
+
+Thought: Do I need to use a tool? Yes
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action, as a single, valid JSON object string. Check the tool definition for required arguments and structure.
+Observation: the result of the action, often containing structured data like 'variables', 'dataset_analysis', 'method_info', etc.
+
+When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
+
+Thought: Do I need to use a tool? No
+Final Answer: [your response here]
+
+DO NOT UNDER ANY CIRCUMSTANCE CALL MORE THAN ONE TOOL IN A STEP
+
+**IMPORTANT TOOL USAGE:**
+1. **Action Input Format:** The value for 'Action Input' MUST be a single, valid JSON object string. Do NOT include any other text or formatting around the JSON string.
+2. **Argument Gathering:** You MUST gather ALL required arguments for the Action Input JSON from the initial Human input AND the 'Observation' outputs of PREVIOUS steps. Look carefully at the required arguments for the tool you are calling.
+3. **Data Handoff:** The 'Observation' from a previous step often contains structured data needed by the next tool. For example, the 'variables' output from `query_interpreter_tool` contains fields like `treatment_variable`, `outcome_variable`, `covariates`, `time_variable`, `instrument_variable`, `running_variable`, `cutoff_value`, and `is_rct`. When calling `method_selector_tool`, you MUST construct its required `variables` input argument by including **ALL** these relevant fields identified by the `query_interpreter_tool` in the previous Observation. Similarly, pass the full `dataset_analysis`, `dataset_description`, and `original_query` when required by the next tool.
+
+IMPORTANT WORKFLOW:
+-------------------
+You must follow this exact workflow, selecting the appropriate tool for each step:
+
+1. ALWAYS start with `input_parser_tool` to understand the query
+2. THEN use `dataset_analyzer_tool` to analyze the dataset
+3. THEN use `query_interpreter_tool` to identify variables (output includes `variables` and `dataset_analysis`)
+4. THEN use `method_selector_tool` (input requires `variables` and `dataset_analysis` from previous step)
+5. THEN use `method_validator_tool` (input requires `method_info` and `variables` from previous step)
+6. THEN use `method_executor_tool` (input requires `method`, `variables`, `dataset_path`)
+7. THEN use `explanation_generator_tool` (input requires results, method_info, variables, etc.)
+8. FINALLY use `output_formatter_tool` to return the results
+
+REASONING PROCESS:
+------------------
+EXPLICITLY REASON about:
+1. What step you're currently on (based on previous tool's Observation)
+2. Why you're selecting a particular tool (should follow the workflow)
+3. How the output of the previous tool (especially structured data like `variables`, `dataset_analysis`, `method_info`) informs the inputs required for the current tool.
+
+IMPORTANT RULES:
+1. Do not make more than one tool call in a single step.
+2. Do not include ``` in your output at all.
+3. Don't use action names like default_api.dataset_analyzer_tool, instead use tool names like dataset_analyzer_tool.
+4. Always start, action, and observation with a new line.
+5. Don't use '\\' before double quotes
+6. Don't include ```json for Action Input. Also ensure that Action Input is a valid json. DO no add any text after Action Iput.
+7. You have to always choose one of the tools unless it's the final answer.
+Begin!
+"""
+
+ # Create the prompt template
+ prompt = ChatPromptTemplate.from_messages([
+ ("system", system_template),
+ MessagesPlaceholder("chat_history", optional=True), # Use MessagesPlaceholder
+ # MessagesPlaceholder("agent_scratchpad"),
+
+ ("human", "{input}\n Thought:{agent_scratchpad}"),
+ # ("ai", "{agent_scratchpad}"),
+ # MessagesPlaceholder("agent_scratchpad" ), # Use MessagesPlaceholder
+ # "agent_scratchpad"
+ ])
+ return prompt
+
+def create_causal_agent(llm: BaseChatModel) -> AgentExecutor:
+ """
+ Create and configure the LangChain agent with causal inference tools.
+ (Using explicit LCEL construction, compatible with shared LLM client)
+ """
+ # Define tools available to the agent
+ agent_tools = [
+ input_parser_tool,
+ dataset_analyzer_tool,
+ query_interpreter_tool,
+ method_selector_tool,
+ method_validator_tool,
+ method_executor_tool,
+ explanation_generator_tool,
+ output_formatter_tool
+ ]
+ # anthropic_agent_tools = [ convert_to_anthropic_tool(anthropic_tool) for anthropic_tool in agent_tools]
+ # Create the prompt using the helper
+ prompt = create_agent_prompt(agent_tools)
+ # Bind tools to the LLM (using the passed shared instance)
+
+
+ # Create memory
+ # Consider if memory needs to be passed in or created here
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+
+ # Manually construct the agent runnable using LCEL
+ from langchain_anthropic.output_parsers import ToolsOutputParser
+ from langchain.agents.output_parsers.json import JSONAgentOutputParser
+ # from langchain.agents.react.output_parser import MultiActionAgentOutputParsers ReActMultiInputOutputParser
+ provider = os.getenv("LLM_PROVIDER", "openai")
+ if provider == "gemini":
+ base_parser=ReActMultiInputOutputParser()
+ llm_with_tools = llm.bind_tools(agent_tools)
+ else:
+ base_parser=ToolsAgentOutputParser()
+ llm_with_tools = llm.bind_tools(agent_tools, tool_choice="any")
+ agent = create_react_agent(llm_with_tools, agent_tools, prompt, output_parser=base_parser)
+
+
+ # Create executor (should now work with the manually constructed agent)
+ executor = AgentExecutor(
+ agent=agent,
+ tools=agent_tools,
+ memory=memory, # Pass the memory object
+ verbose=True,
+ callbacks=[ConsoleCallbackHandler()], # Optional: for console debugging
+ handle_parsing_errors=True, # Let AE handle parsing errors
+ max_retries = 100
+ )
+
+ return executor
+
+def run_causal_analysis(query: str, dataset_path: str,
+ dataset_description: Optional[str] = None,
+ api_key: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Run causal analysis on a dataset based on a user query.
+
+ Args:
+ query: User's causal question
+ dataset_path: Path to the dataset
+ dataset_description: Optional textual description of the dataset
+ api_key: Optional OpenAI API key (DEPRECATED - will be ignored)
+
+ Returns:
+ Dictionary containing the final formatted analysis results from the agent's last step.
+ """
+ # Log the start of the analysis
+ logger.info("Starting causal analysis run...")
+
+ try:
+ # --- Instantiate the shared LLM client ---
+ model_name = os.getenv("LLM_MODEL", "gpt-4")
+ if model_name in ['o3', 'o4-mini', 'o3-mini']:
+ print('-------------------------')
+ shared_llm = get_llm_client()
+ else:
+ shared_llm = get_llm_client(temperature=0) # Or read provider/model from env
+
+ # --- Dependency Injection Note (REMAINS RELEVANT) ---
+ # If tools need the LLM, they must be adapted. Example using partial:
+ # from functools import partial
+ # from .components import input_parser
+ # # Assume input_parser.parse_input needs llm
+ # input_parser_tool_with_llm = tool(partial(input_parser.parse_input, llm=shared_llm))
+ # Use input_parser_tool_with_llm in the tools list passed to the agent below.
+ # Similar adjustments needed for decision_tree._recommend_ps_method if used.
+ # --- End Note ---
+
+ # --- Create agent using the shared LLM ---
+ # agent_executor = create_causal_agent(shared_llm)
+
+ # Construct input, including description if available
+ # IMPORTANT: Agent now expects 'input' and potentially 'chat_history'
+ # The input needs to contain all initial info the first tool might need.
+ input_text = f"My question is: {query}\n"
+ input_text += f"The dataset is located at: {dataset_path}\n"
+ if dataset_description:
+ input_text += f"Dataset Description: {dataset_description}\n"
+ input_text += "Please perform the causal analysis following the workflow."
+
+ # Log the constructed input text
+ logger.info(f"Constructed input for agent: \n{input_text}")
+
+ input_parsing_result = input_parser_tool(input_text)
+ dataset_analysis_result = dataset_analyzer_tool.func(dataset_path=input_parsing_result["dataset_path"], dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).analysis_results
+ query_info = QueryInfo(
+ query_text=input_parsing_result["original_query"],
+ potential_treatments=input_parsing_result["extracted_variables"].get("treatment"),
+ potential_outcomes=input_parsing_result["extracted_variables"].get("outcome"),
+ covariates_hints=input_parsing_result["extracted_variables"].get("covariates_mentioned"),
+ instrument_hints=input_parsing_result["extracted_variables"].get("instruments_mentioned")
+ )
+
+ query_interpreter_output = query_interpreter_tool.func(query_info=query_info, dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query = input_parsing_result["original_query"]).variables
+ method_selector_output = method_selector_tool.func(variables=query_interpreter_output,
+ dataset_analysis=dataset_analysis_result,
+ dataset_description=input_parsing_result["dataset_description"],
+ original_query = input_parsing_result["original_query"],
+ excluded_methods=None)
+ method_info = MethodInfo(
+ **method_selector_output['method_info']
+ )
+ method_validator_input = MethodValidatorInput(
+ method_info=method_info,
+ variables=query_interpreter_output,
+ dataset_analysis=dataset_analysis_result,
+ dataset_description=input_parsing_result["dataset_description"],
+ original_query = input_parsing_result["original_query"]
+ )
+ method_validator_output = method_validator_tool.func(method_validator_input)
+ method_executor_input = MethodExecutorInput(
+ **method_validator_output
+ )
+ method_executor_output = method_executor_tool.func(method_executor_input, original_query = input_parsing_result["original_query"])
+
+ explainer_output = explanation_generator_tool.func( method_info=method_info,
+ validation_info=method_validator_output,
+ variables=query_interpreter_output,
+ results=method_executor_output,
+ dataset_analysis=dataset_analysis_result,
+ dataset_description=input_parsing_result["dataset_description"],
+ original_query = input_parsing_result["original_query"])
+ result = explainer_output
+ result['results']['results']["method_used"] = method_validator_output['method']
+ logger.info(result)
+ logger.info("Causal analysis run finished.")
+
+ # Ensure result is a dict and extract the 'output' part
+ if isinstance(result, dict):
+ final_output = result
+ if isinstance(final_output, dict):
+ return final_output # Return only the dictionary from the final tool
+ else:
+ logger.error(f"Agent result['output'] was not a dictionary: {type(final_output)}. Returning error dict.")
+ return {"error": "Agent did not produce the expected dictionary output in the 'output' key.", "raw_agent_result": result}
+ else:
+ logger.error(f"Agent returned non-dict type: {type(result)}. Returning error dict.")
+ return {"error": "Agent did not return expected dictionary output.", "raw_output": str(result)}
+
+ except ValueError as e:
+ logger.error(f"Configuration Error: {e}")
+ # Return an error dictionary in case of exception too
+ return {"error": f"Error: Configuration issue - {e}"} # Ensure consistent error return type
+ except Exception as e:
+ logger.error(f"An unexpected error occurred during causal analysis: {e}", exc_info=True)
+ # Return an error dictionary in case of exception too
+ return {"error": f"An unexpected error occurred: {e}"}
\ No newline at end of file
diff --git a/auto_causal/components/__init__.py b/auto_causal/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f802001c6f0831cca6a2254b30588691669eecf4
--- /dev/null
+++ b/auto_causal/components/__init__.py
@@ -0,0 +1,28 @@
+"""
+Auto Causal components package.
+
+This package contains the core components for the auto_causal module,
+each handling a specific part of the causal inference workflow.
+"""
+
+from auto_causal.components.input_parser import parse_input
+from auto_causal.components.dataset_analyzer import analyze_dataset
+from auto_causal.components.query_interpreter import interpret_query
+from auto_causal.components.decision_tree import select_method
+from auto_causal.components.method_validator import validate_method
+from auto_causal.components.explanation_generator import generate_explanation
+from auto_causal.components.output_formatter import format_output
+from auto_causal.components.state_manager import create_workflow_state_update
+
+__all__ = [
+ "parse_input",
+ "analyze_dataset",
+ "interpret_query",
+ "select_method",
+ "validate_method",
+ "generate_explanation",
+ "format_output",
+ "create_workflow_state_update"
+]
+
+# This file makes Python treat the directory as a package.
diff --git a/auto_causal/components/dataset_analyzer.py b/auto_causal/components/dataset_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..875177855dd8379e308f7aa0f5232948590e22fa
--- /dev/null
+++ b/auto_causal/components/dataset_analyzer.py
@@ -0,0 +1,853 @@
+"""
+Dataset analyzer component for causal inference.
+
+This module provides functionality to analyze datasets to detect characteristics
+relevant for causal inference methods, including temporal structure, potential
+instrumental variables, discontinuities, and variable relationships.
+"""
+
+import os
+import pandas as pd
+import numpy as np
+from typing import Dict, List, Any, Optional, Tuple
+from scipy import stats
+import logging
+import json
+from langchain_core.language_models import BaseChatModel
+from auto_causal.utils.llm_helpers import llm_identify_temporal_and_unit_vars
+
+logger = logging.getLogger(__name__)
+
+def _calculate_per_group_stats(df: pd.DataFrame, potential_treatments: List[str]) -> Dict[str, Dict]:
+ """Calculates summary stats for numeric covariates grouped by potential binary treatments."""
+ stats_dict = {}
+ numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
+
+ for treat_var in potential_treatments:
+ if treat_var not in df.columns:
+ logger.warning(f"Potential treatment '{treat_var}' not found in DataFrame columns.")
+ continue
+
+ # Ensure treatment is binary (0/1 or similar)
+ unique_vals = df[treat_var].dropna().unique()
+ if len(unique_vals) != 2:
+ logger.info(f"Skipping stats for potential treatment '{treat_var}' as it is not binary ({len(unique_vals)} unique values).")
+ continue
+
+ # Attempt to map values to 0 and 1 if possible
+ try:
+ # Ensure boolean is converted to int
+ if df[treat_var].dtype == 'bool':
+ df[treat_var] = df[treat_var].astype(int)
+ unique_vals = df[treat_var].dropna().unique()
+
+ # Basic check if values are interpretable as 0/1
+ if not set(unique_vals).issubset({0, 1}):
+ # Attempt conversion if possible (e.g., True/False strings?)
+ logger.warning(f"Potential treatment '{treat_var}' has values {unique_vals}, not {0, 1}. Cannot calculate group stats reliably.")
+ continue
+ except Exception as e:
+ logger.warning(f"Could not process potential treatment '{treat_var}' values ({unique_vals}): {e}")
+ continue
+
+ logger.info(f"Calculating group stats for treatment: '{treat_var}'")
+ treat_stats = {'group_sizes': {}, 'covariate_stats': {}}
+
+ try:
+ grouped = df.groupby(treat_var)
+ sizes = grouped.size()
+ treat_stats['group_sizes']['treated'] = int(sizes.get(1, 0))
+ treat_stats['group_sizes']['control'] = int(sizes.get(0, 0))
+
+ if treat_stats['group_sizes']['treated'] == 0 or treat_stats['group_sizes']['control'] == 0:
+ logger.warning(f"Treatment '{treat_var}' has zero samples in one group. Skipping covariate stats.")
+ stats_dict[treat_var] = treat_stats
+ continue
+
+ # Calculate mean and std for numeric covariates
+ cov_stats = grouped[numeric_cols].agg(['mean', 'std']).unstack()
+
+ for cov in numeric_cols:
+ if cov == treat_var: continue # Skip treatment variable itself
+
+ mean_control = cov_stats.get(('mean', 0, cov), np.nan)
+ std_control = cov_stats.get(('std', 0, cov), np.nan)
+ mean_treated = cov_stats.get(('mean', 1, cov), np.nan)
+ std_treated = cov_stats.get(('std', 1, cov), np.nan)
+
+ treat_stats['covariate_stats'][cov] = {
+ 'mean_control': float(mean_control) if pd.notna(mean_control) else None,
+ 'std_control': float(std_control) if pd.notna(std_control) else None,
+ 'mean_treat': float(mean_treated) if pd.notna(mean_treated) else None,
+ 'std_treat': float(std_treated) if pd.notna(std_treated) else None,
+ }
+ stats_dict[treat_var] = treat_stats
+ except Exception as e:
+ logger.error(f"Error calculating stats for treatment '{treat_var}': {e}", exc_info=True)
+ # Store partial info if possible
+ if treat_var not in stats_dict:
+ stats_dict[treat_var] = {'error': str(e)}
+ elif 'error' not in stats_dict[treat_var]:
+ stats_dict[treat_var]['error'] = str(e)
+
+ return stats_dict
+
+def analyze_dataset(
+ dataset_path: str,
+ llm_client: Optional[BaseChatModel] = None,
+ dataset_description: Optional[str] = None,
+ original_query: Optional[str] = None
+) -> Dict[str, Any]:
+ """
+ Analyze a dataset to identify important characteristics for causal inference.
+
+ Args:
+ dataset_path: Path to the dataset file
+ llm_client: Optional LLM client for enhanced analysis
+ dataset_description: Optional description of the dataset for context
+
+ Returns:
+ Dict containing dataset analysis results:
+ - dataset_info: Basic information about the dataset
+ - columns: List of column names
+ - potential_treatments: List of potential treatment variables (possibly LLM augmented)
+ - potential_outcomes: List of potential outcome variables (possibly LLM augmented)
+ - temporal_structure_detected: Whether temporal structure was detected
+ - panel_data_detected: Whether panel data structure was detected
+ - potential_instruments_detected: Whether potential instruments were detected
+ - discontinuities_detected: Whether discontinuities were detected
+ - llm_augmentation: Status of LLM augmentation if used
+ """
+ llm_augmentation = "Not used" if not llm_client else "Initialized"
+
+ # Check if file exists
+ if not os.path.exists(dataset_path):
+ logger.error(f"Dataset file not found at {dataset_path}")
+ return {"error": f"Dataset file not found at {dataset_path}"}
+
+ try:
+ # Load the dataset
+ df = pd.read_csv(dataset_path)
+
+ # Basic dataset information
+ sample_size = len(df)
+ columns_list = df.columns.tolist()
+ num_covariates = len(columns_list) - 2 # Rough estimate (total - T - Y)
+ dataset_info = {
+ "num_rows": sample_size,
+ "num_columns": len(columns_list),
+ "file_path": dataset_path,
+ "file_name": os.path.basename(dataset_path)
+ }
+
+ # --- Detailed Analysis (Keep internal) ---
+ column_types_detailed = {col: str(df[col].dtype) for col in df.columns}
+ missing_values_detailed = df.isnull().sum().to_dict()
+ column_categories_detailed = _categorize_columns(df)
+ column_nunique_counts_detailed = {col: df[col].nunique() for col in df.columns} # Calculate nunique
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
+ correlations_detailed = df[numeric_cols].corr() if numeric_cols else pd.DataFrame()
+ temporal_structure_detailed = detect_temporal_structure(df, llm_client, dataset_description, original_query)
+
+ # First, identify potential treatment and outcome variables
+ potential_variables = _identify_potential_variables(
+ df,
+ column_categories_detailed,
+ llm_client=llm_client,
+ dataset_description=dataset_description
+ )
+
+ if llm_client:
+ llm_augmentation = "Used for variable identification"
+
+ # Then use that info to help find potential instrumental variables
+ potential_instruments_detailed = find_potential_instruments(
+ df,
+ llm_client=llm_client,
+ potential_treatments=potential_variables.get("potential_treatments", []),
+ potential_outcomes=potential_variables.get("potential_outcomes", []),
+ dataset_description=dataset_description
+ )
+
+ # Other analyses
+ discontinuities_detailed = detect_discontinuities(df)
+ variable_relationships_detailed = assess_variable_relationships(df, correlations_detailed)
+
+ # Calculate per-group stats for potential binary treatments
+ potential_binary_treatments = [
+ t for t in potential_variables["potential_treatments"]
+ if column_categories_detailed.get(t) == 'binary'
+ or column_categories_detailed.get(t) == 'binary_categorical'
+ ]
+ per_group_stats = _calculate_per_group_stats(df.copy(), potential_binary_treatments)
+
+ # --- Summarized Analysis (For Output) ---
+
+ # Get boolean flags and essential lists
+ has_temporal = temporal_structure_detailed.get("has_temporal_structure", False)
+ is_panel = temporal_structure_detailed.get("is_panel_data", False)
+ logger.info(f"iv is {potential_instruments_detailed}")
+ has_instruments = len(potential_instruments_detailed) > 0
+ has_discontinuities = discontinuities_detailed.get("has_discontinuities", False)
+
+ # --- Extract only instrument names for the final output ---
+ potential_instrument_names = [
+ inst_dict.get('variable')
+ for inst_dict in potential_instruments_detailed
+ if isinstance(inst_dict, dict) and 'variable' in inst_dict
+ ]
+ logger.info(f"iv is {potential_instrument_names}")
+ # --- Final Output Dictionary (Highly Summarized) ---
+ return {
+ "dataset_info": dataset_info, # Keep basic info
+ "columns": columns_list,
+ "potential_treatments": potential_variables["potential_treatments"],
+ "potential_outcomes": potential_variables["potential_outcomes"],
+ # Return concise flags instead of detailed dicts/lists
+ "temporal_structure_detected": has_temporal,
+ "panel_data_detected": is_panel,
+ "potential_instruments_detected": has_instruments,
+ "discontinuities_detected": has_discontinuities,
+ # Use the extracted list of names here
+ "potential_instruments": potential_instrument_names,
+ "discontinuities": discontinuities_detailed,
+ "temporal_structure": temporal_structure_detailed,
+ "column_categories": column_categories_detailed,
+ "column_nunique_counts": column_nunique_counts_detailed, # Add nunique counts to output
+ "sample_size": sample_size,
+ "num_covariates_estimate": num_covariates,
+ "llm_augmentation": llm_augmentation
+ }
+
+ except Exception as e:
+ logger.error(f"Error analyzing dataset '{dataset_path}': {e}", exc_info=True)
+ return {
+ "error": f"Error analyzing dataset: {str(e)}",
+ "llm_augmentation": llm_augmentation
+ }
+
+
+def _categorize_columns(df: pd.DataFrame) -> Dict[str, str]:
+ """
+ Categorize columns into types relevant for causal inference.
+
+ Args:
+ df: DataFrame to analyze
+
+ Returns:
+ Dict mapping column names to their types
+ """
+ result = {}
+
+ for col in df.columns:
+ # Check if column is numeric
+ if pd.api.types.is_numeric_dtype(df[col]):
+ # Count number of unique values
+ n_unique = df[col].nunique()
+
+ # Binary numeric variable
+ if n_unique == 2:
+ result[col] = "binary"
+ # Likely categorical represented as numeric
+ elif n_unique < 10:
+ result[col] = "categorical_numeric"
+ # Discrete numeric (integers)
+ elif pd.api.types.is_integer_dtype(df[col]):
+ result[col] = "discrete_numeric"
+ # Continuous numeric
+ else:
+ result[col] = "continuous_numeric"
+
+ # Check for datetime
+ elif pd.api.types.is_datetime64_any_dtype(df[col]) or _is_date_string(df, col):
+ result[col] = "datetime"
+
+ # Check for categorical
+ elif pd.api.types.is_categorical_dtype(df[col]) or df[col].nunique() < 20:
+ if df[col].nunique() == 2:
+ result[col] = "binary_categorical"
+ else:
+ result[col] = "categorical"
+
+ # Must be text or other
+ else:
+ result[col] = "text_or_other"
+
+ return result
+
+
+def _is_date_string(df: pd.DataFrame, col: str) -> bool:
+ """
+ Check if a column contains date strings.
+
+ Args:
+ df: DataFrame to check
+ col: Column name to check
+
+ Returns:
+ True if the column appears to contain date strings
+ """
+ # Try to convert to datetime
+ if not pd.api.types.is_string_dtype(df[col]):
+ return False
+
+ # Check sample of values
+ sample = df[col].dropna().sample(min(10, len(df[col].dropna()))).tolist()
+
+ try:
+ for val in sample:
+ pd.to_datetime(val)
+ return True
+ except:
+ return False
+
+
+def _identify_potential_variables(
+ df: pd.DataFrame,
+ column_categories: Dict[str, str],
+ llm_client: Optional[BaseChatModel] = None,
+ dataset_description: Optional[str] = None
+) -> Dict[str, List[str]]:
+ """
+ Identify potential treatment and outcome variables in the dataset, using LLM if available.
+ Falls back to heuristic method if LLM fails or is not available.
+
+ Args:
+ df: DataFrame to analyze
+ column_categories: Dictionary mapping column names to their types
+ llm_client: Optional LLM client for enhanced identification
+ dataset_description: Optional description of the dataset for context
+
+ Returns:
+ Dict with potential treatment and outcome variables
+ """
+ # Try LLM approach if client is provided
+ if llm_client:
+ try:
+ logger.info("Using LLM to identify potential treatment and outcome variables")
+
+ # Create a concise prompt with just column information
+ columns_list = df.columns.tolist()
+ column_types = {col: str(df[col].dtype) for col in columns_list}
+
+ # Get binary columns for extra context
+ binary_cols = [col for col in columns_list
+ if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
+
+ # Add dataset description if available
+ description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
+
+ prompt = f"""
+You are an expert causal inference data scientist. Identify potential treatment and outcome variables from this dataset.{description_text}
+
+Dataset columns:
+{columns_list}
+
+Column types:
+{column_types}
+
+Binary columns (good treatment candidates):
+{binary_cols}
+
+Instructions:
+1. Identify TREATMENT variables: interventions, treatments, programs, policies, or binary state changes.
+ Look for binary variables or names with 'treatment', 'intervention', 'program', 'policy', etc.
+
+2. Identify OUTCOME variables: results, effects, or responses to treatments.
+ Look for numeric variables (especially non-binary) or names with 'outcome', 'result', 'effect', 'score', etc.
+
+Return ONLY a valid JSON object with two lists: "potential_treatments" and "potential_outcomes".
+Example: {{"potential_treatments": ["treatment_a", "program_b"], "potential_outcomes": ["result_score", "outcome_measure"]}}
+"""
+
+ # Call the LLM and parse the response
+ response = llm_client.invoke(prompt)
+ response_text = response.content if hasattr(response, 'content') else str(response)
+
+ # Extract JSON from the response text
+ import re
+ json_match = re.search(r'{.*}', response_text, re.DOTALL)
+
+ if json_match:
+ result = json.loads(json_match.group(0))
+
+ # Validate the response
+ if (isinstance(result, dict) and
+ "potential_treatments" in result and
+ "potential_outcomes" in result and
+ isinstance(result["potential_treatments"], list) and
+ isinstance(result["potential_outcomes"], list)):
+
+ # Ensure all suggestions are valid columns
+ valid_treatments = [col for col in result["potential_treatments"] if col in df.columns]
+ valid_outcomes = [col for col in result["potential_outcomes"] if col in df.columns]
+
+ if valid_treatments and valid_outcomes:
+ logger.info(f"LLM identified {len(valid_treatments)} treatments and {len(valid_outcomes)} outcomes")
+ return {
+ "potential_treatments": valid_treatments,
+ "potential_outcomes": valid_outcomes
+ }
+ else:
+ logger.warning("LLM suggested invalid columns, falling back to heuristic method")
+ else:
+ logger.warning("Invalid LLM response format, falling back to heuristic method")
+ else:
+ logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
+
+ except Exception as e:
+ logger.error(f"Error in LLM identification: {e}", exc_info=True)
+ logger.info("Falling back to heuristic method")
+
+ # Fallback to heuristic method
+ logger.info("Using heuristic method to identify potential treatment and outcome variables")
+
+ # Identify potential treatment variables
+ potential_treatments = []
+
+ # Look for binary variables (good treatment candidates)
+ binary_cols = [col for col in df.columns
+ if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
+
+ # Look for variables with names suggesting treatment
+ treatment_keywords = ['treatment', 'treat', 'intervention', 'program', 'policy',
+ 'exposed', 'assigned', 'received', 'participated']
+
+ for col in df.columns:
+ col_lower = col.lower()
+ if any(keyword in col_lower for keyword in treatment_keywords):
+ potential_treatments.append(col)
+
+ # Add binary variables if we don't have enough candidates
+ if len(potential_treatments) < 3:
+ for col in binary_cols:
+ if col not in potential_treatments:
+ potential_treatments.append(col)
+ if len(potential_treatments) >= 3:
+ break
+
+ # Identify potential outcome variables
+ potential_outcomes = []
+
+ # Look for numeric variables that aren't binary
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
+ non_binary_numeric = [col for col in numeric_cols if col not in binary_cols]
+
+ # Look for variables with names suggesting outcomes
+ outcome_keywords = ['outcome', 'result', 'effect', 'response', 'score', 'performance',
+ 'achievement', 'success', 'failure', 'improvement']
+
+ for col in df.columns:
+ col_lower = col.lower()
+ if any(keyword in col_lower for keyword in outcome_keywords):
+ potential_outcomes.append(col)
+
+ # Add numeric non-binary variables if we don't have enough candidates
+ if len(potential_outcomes) < 3:
+ for col in non_binary_numeric:
+ if col not in potential_outcomes and col not in potential_treatments:
+ potential_outcomes.append(col)
+ if len(potential_outcomes) >= 3:
+ break
+
+ return {
+ "potential_treatments": potential_treatments,
+ "potential_outcomes": potential_outcomes
+ }
+
+
+def detect_temporal_structure(
+ df: pd.DataFrame,
+ llm_client: Optional[BaseChatModel] = None,
+ dataset_description: Optional[str] = None,
+ original_query: Optional[str] = None
+) -> Dict[str, Any]:
+ """
+ Detect temporal structure in the dataset, using LLM for enhanced identification.
+
+ Args:
+ df: DataFrame to analyze
+ llm_client: Optional LLM client for enhanced identification
+ dataset_description: Optional description of the dataset for context
+
+ Returns:
+ Dict with information about temporal structure:
+ - has_temporal_structure: Whether temporal structure exists
+ - temporal_columns: Primary time column identified (or list if multiple from heuristic)
+ - is_panel_data: Whether data is in panel format
+ - time_column: Primary time column identified for panel data
+ - id_column: Primary unit ID column identified for panel data
+ - time_periods: Number of time periods (if panel data)
+ - units: Number of unique units (if panel data)
+ - identification_method: How time/unit vars were identified ('LLM', 'Heuristic', 'None')
+ """
+ result = {
+ "has_temporal_structure": False,
+ "temporal_columns": [], # Will store primary time column or heuristic list
+ "is_panel_data": False,
+ "time_column": None,
+ "id_column": None,
+ "time_periods": None,
+ "units": None,
+ "identification_method": "None"
+ }
+
+ # --- Step 1: Heuristic identification (as before) ---
+ #heuristic_datetime_cols = []
+ #for col in df.columns:
+ # if pd.api.types.is_datetime64_any_dtype(df[col]):
+ # heuristic_datetime_cols.append(col)
+ # elif pd.api.types.is_string_dtype(df[col]):
+ # try:
+ # if pd.to_datetime(df[col], errors='coerce').notna().any():
+ # heuristic_datetime_cols.append(col)
+ # except:
+ # pass # Ignore conversion errors
+
+ #time_keywords = ['year', 'month', 'day', 'date', 'time', 'period', 'quarter', 'week']
+ #for col in df.columns:
+ # col_lower = col.lower()
+ # if any(keyword in col_lower for keyword in time_keywords) and col not in heuristic_datetime_cols:
+ # heuristic_datetime_cols.append(col)
+
+ #id_keywords = ['id', 'individual', 'person', 'unit', 'entity', 'firm', 'company', 'state', 'country']
+ #heuristic_potential_id_cols = []
+ #for col in df.columns:
+ # col_lower = col.lower()
+ # # Exclude columns already identified as time-related by heuristics
+ # if any(keyword in col_lower for keyword in id_keywords) and col not in heuristic_datetime_cols:
+ # heuristic_potential_id_cols.append(col)
+
+ # --- Step 2: LLM-assisted identification ---
+ llm_identified_time_var = None
+ llm_identified_unit_var = None
+ heuristic_datetime_cols = []
+ heuristic_potential_id_cols = []
+ dataset_summary = df.describe(include='all')
+
+ if llm_client:
+ logger.info("Attempting LLM-assisted identification of temporal/unit variables.")
+ column_names = df.columns.tolist()
+ column_dtypes_dict = {col: str(df[col].dtype) for col in column_names}
+
+ try:
+ llm_suggestions = llm_identify_temporal_and_unit_vars(
+ column_names=column_names,
+ column_dtypes=column_dtypes_dict,
+ dataset_description=dataset_description if dataset_description else "No dataset description provided.",
+ dataset_summary=dataset_summary,
+ heuristic_time_candidates=heuristic_datetime_cols,
+ heuristic_id_candidates=heuristic_potential_id_cols,
+ query=original_query if original_query else "No query provided.",
+ llm=llm_client
+ )
+ llm_identified_time_var = llm_suggestions.get("time_variable")
+ llm_identified_unit_var = llm_suggestions.get("unit_variable")
+ result["identification_method"] = "LLM"
+
+ if not llm_identified_time_var and not llm_identified_unit_var:
+ result["identification_method"] = "LLM_NoIdentification"
+ except Exception as e:
+ logger.warning(f"LLM call for temporal/unit vars failed: {e}. Falling back to heuristics.")
+ result["identification_method"] = "Heuristic_LLM_Error"
+ else:
+ result["identification_method"] = "Heuristic_NoLLM"
+
+ # --- Step 3: Combine LLM and Heuristic Results ---
+ final_time_var = None
+ final_unit_var = None
+
+ if llm_identified_time_var:
+ final_time_var = llm_identified_time_var
+ logger.info(f"Prioritizing LLM identified time variable: {final_time_var}")
+ elif heuristic_datetime_cols:
+ final_time_var = heuristic_datetime_cols[0] # Fallback to first heuristic time col
+ logger.info(f"Using heuristic time variable: {final_time_var}")
+
+ if llm_identified_unit_var:
+ final_unit_var = llm_identified_unit_var
+ logger.info(f"Prioritizing LLM identified unit variable: {final_unit_var}")
+ elif heuristic_potential_id_cols:
+ final_unit_var = heuristic_potential_id_cols[0] # Fallback to first heuristic ID col
+ logger.info(f"Using heuristic unit variable: {final_unit_var}")
+
+ # Update results based on final selections
+ if final_time_var:
+ result["has_temporal_structure"] = True
+ result["temporal_columns"] = [final_time_var] # Store as a list with the primary time var
+ result["time_column"] = final_time_var
+ else: # If no time var found by LLM or heuristic, use original heuristic list for temporal_columns
+ if heuristic_datetime_cols:
+ result["has_temporal_structure"] = True
+ result["temporal_columns"] = heuristic_datetime_cols
+ # time_column remains None
+
+ if final_unit_var:
+ result["id_column"] = final_unit_var
+
+ # --- Step 4: Update Panel Data Logic (based on final_time_var and final_unit_var) ---
+ if final_time_var and final_unit_var:
+ # Check if there are multiple time periods per unit using the identified variables
+ try:
+ # Ensure columns exist before groupby
+ if final_time_var in df.columns and final_unit_var in df.columns:
+ if df.groupby(final_unit_var)[final_time_var].nunique().mean() > 1.0:
+ result["is_panel_data"] = True
+ result["time_periods"] = df[final_time_var].nunique()
+ result["units"] = df[final_unit_var].nunique()
+ logger.info(f"Panel data detected: Time='{final_time_var}', Unit='{final_unit_var}', Periods={result['time_periods']}, Units={result['units']}")
+ else:
+ logger.info("Not panel data: Each unit does not have multiple time periods.")
+ else:
+ logger.warning(f"Final time ('{final_time_var}') or unit ('{final_unit_var}') var not in DataFrame. Cannot confirm panel structure.")
+ except Exception as e:
+ logger.error(f"Error checking panel data structure with time='{final_time_var}', unit='{final_unit_var}': {e}")
+ result["is_panel_data"] = False # Default to false on error
+ else:
+ logger.info("Not panel data: Missing either time or unit variable for panel structure.")
+
+ logger.debug(f"Final temporal structure detection result: {result}")
+ return result
+
+
+def find_potential_instruments(
+ df: pd.DataFrame,
+ llm_client: Optional[BaseChatModel] = None,
+ potential_treatments: List[str] = None,
+ potential_outcomes: List[str] = None,
+ dataset_description: Optional[str] = None
+) -> List[Dict[str, Any]]:
+ """
+ Find potential instrumental variables in the dataset, using LLM if available.
+ Falls back to heuristic method if LLM fails or is not available.
+
+ Args:
+ df: DataFrame to analyze
+ llm_client: Optional LLM client for enhanced identification
+ potential_treatments: Optional list of potential treatment variables
+ potential_outcomes: Optional list of potential outcome variables
+ dataset_description: Optional description of the dataset for context
+
+ Returns:
+ List of potential instrumental variables with their properties
+ """
+ # Try LLM approach if client is provided
+ if llm_client:
+ try:
+ logger.info("Using LLM to identify potential instrumental variables")
+
+ # Create a concise prompt with just column information
+ columns_list = df.columns.tolist()
+
+ # Exclude known treatment and outcome variables from consideration
+ excluded_columns = []
+ if potential_treatments:
+ excluded_columns.extend(potential_treatments)
+ if potential_outcomes:
+ excluded_columns.extend(potential_outcomes)
+
+ # Filter columns to exclude treatments and outcomes
+ candidate_columns = [col for col in columns_list if col not in excluded_columns]
+
+ if not candidate_columns:
+ logger.warning("No eligible columns for instrumental variables after filtering treatments and outcomes")
+ return []
+
+ # Get column types for context
+ column_types = {col: str(df[col].dtype) for col in candidate_columns}
+
+ # Add dataset description if available
+ description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
+
+ prompt = f"""
+You are an expert causal inference data scientist. Identify potential instrumental variables from this dataset.{description_text}
+
+DEFINITION: Instrumental variables must:
+1. Be correlated with the treatment variable (relevance)
+2. Only affect the outcome through the treatment (exclusion restriction)
+3. Not be correlated with unmeasured confounders (exogeneity)
+
+Treatment variables: {potential_treatments if potential_treatments else "Unknown"}
+Outcome variables: {potential_outcomes if potential_outcomes else "Unknown"}
+
+Available columns (excluding treatments and outcomes):
+{candidate_columns}
+
+Column types:
+{column_types}
+
+Look for variables likely to be:
+- Random assignments
+- Policy changes
+- Geographic or temporal variations
+- Variables with names containing: 'instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous'
+
+Return ONLY a JSON array of objects, each with "variable", "reason", and "data_type" fields.
+Example:
+[
+ {{"variable": "random_assignment", "reason": "Random assignment variable", "data_type": "int64"}},
+ {{"variable": "distance_to_facility", "reason": "Geographic variation", "data_type": "float64"}}
+]
+"""
+
+ # Call the LLM and parse the response
+ response = llm_client.invoke(prompt)
+ response_text = response.content if hasattr(response, 'content') else str(response)
+
+ # Extract JSON from the response text
+ import re
+ json_match = re.search(r'\[\s*{.*}\s*\]', response_text, re.DOTALL)
+
+ if json_match:
+ result = json.loads(json_match.group(0))
+
+ # Validate the response
+ if isinstance(result, list) and len(result) > 0:
+ # Filter for valid entries
+ valid_instruments = []
+ for item in result:
+ if not isinstance(item, dict) or "variable" not in item:
+ continue
+
+ if item["variable"] not in df.columns:
+ continue
+
+ # Ensure all required fields are present
+ if "reason" not in item:
+ item["reason"] = "Identified by LLM"
+ if "data_type" not in item:
+ item["data_type"] = str(df[item["variable"]].dtype)
+
+ valid_instruments.append(item)
+
+ if valid_instruments:
+ logger.info(f"LLM identified {len(valid_instruments)} potential instrumental variables {valid_instruments}")
+ return valid_instruments
+ else:
+ logger.warning("No valid instruments found by LLM, falling back to heuristic method")
+ else:
+ logger.warning("Invalid LLM response format, falling back to heuristic method")
+ else:
+ logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
+
+ except Exception as e:
+ logger.error(f"Error in LLM identification of instruments: {e}", exc_info=True)
+ logger.info("Falling back to heuristic method")
+
+ # Fallback to heuristic method
+ logger.info("Using heuristic method to identify potential instrumental variables")
+ potential_instruments = []
+
+ # Look for variables with instrumental-related names
+ instrument_keywords = ['instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous']
+
+ for col in df.columns:
+ # Skip treatment and outcome variables
+ if potential_treatments and col in potential_treatments:
+ continue
+ if potential_outcomes and col in potential_outcomes:
+ continue
+
+ col_lower = col.lower()
+ if any(keyword in col_lower for keyword in instrument_keywords):
+ instrument_info = {
+ "variable": col,
+ "reason": f"Name contains instrument-related keyword",
+ "data_type": str(df[col].dtype)
+ }
+ potential_instruments.append(instrument_info)
+
+ return potential_instruments
+
+
+def detect_discontinuities(df: pd.DataFrame) -> Dict[str, Any]:
+ """
+ Identify discontinuities in continuous variables (for RDD).
+
+ Args:
+ df: DataFrame to analyze
+
+ Returns:
+ Dict with information about detected discontinuities
+ """
+ discontinuities = []
+
+ # For each numeric column, check for potential discontinuities
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
+
+ for col in numeric_cols:
+ # Skip columns with too many unique values
+ if df[col].nunique() > 100:
+ continue
+
+ values = df[col].dropna().sort_values().values
+
+ # Calculate gaps between consecutive values
+ if len(values) > 10:
+ gaps = np.diff(values)
+ mean_gap = np.mean(gaps)
+ std_gap = np.std(gaps)
+
+ # Look for unusually large gaps (potential discontinuities)
+ large_gaps = np.where(gaps > mean_gap + 2*std_gap)[0]
+
+ if len(large_gaps) > 0:
+ for idx in large_gaps:
+ cutpoint = (values[idx] + values[idx+1]) / 2
+ discontinuities.append({
+ "variable": col,
+ "cutpoint": float(cutpoint),
+ "gap_size": float(gaps[idx]),
+ "mean_gap": float(mean_gap)
+ })
+
+ return {
+ "has_discontinuities": len(discontinuities) > 0,
+ "discontinuities": discontinuities
+ }
+
+
+def assess_variable_relationships(df: pd.DataFrame, corr_matrix: pd.DataFrame) -> Dict[str, Any]:
+ """
+ Assess relationships between variables in the dataset.
+
+ Args:
+ df: DataFrame to analyze
+ corr_matrix: Precomputed correlation matrix for numeric columns
+
+ Returns:
+ Dict with information about variable relationships:
+ - strongly_correlated_pairs: Pairs of strongly correlated variables
+ - potential_confounders: Variables that might be confounders
+ """
+ result = {"strongly_correlated_pairs": [], "potential_confounders": []}
+
+ numeric_cols = corr_matrix.columns.tolist()
+ if len(numeric_cols) < 2:
+ return result
+
+ # Use the precomputed correlation matrix
+ corr_matrix_abs = corr_matrix.abs()
+
+ # Find strongly correlated variable pairs
+ for i in range(len(numeric_cols)):
+ for j in range(i+1, len(numeric_cols)):
+ if abs(corr_matrix_abs.iloc[i, j]) > 0.7: # Correlation threshold
+ result["strongly_correlated_pairs"].append({
+ "variables": [numeric_cols[i], numeric_cols[j]],
+ "correlation": float(corr_matrix.iloc[i, j])
+ })
+
+ # Identify potential confounders (variables correlated with multiple others)
+ confounder_counts = {col: 0 for col in numeric_cols}
+
+ for pair in result["strongly_correlated_pairs"]:
+ confounder_counts[pair["variables"][0]] += 1
+ confounder_counts[pair["variables"][1]] += 1
+
+ # Variables correlated with multiple others are potential confounders
+ for col, count in confounder_counts.items():
+ if count >= 2:
+ result["potential_confounders"].append({"variable": col, "num_correlations": count})
+
+ return result
\ No newline at end of file
diff --git a/auto_causal/components/decision_tree.py b/auto_causal/components/decision_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..39ef7031cabc787b438d4449c276a47e86adc446
--- /dev/null
+++ b/auto_causal/components/decision_tree.py
@@ -0,0 +1,366 @@
+"""
+decision tree component for selecting causal inference methods
+
+this module implements the decision tree logic to select the most appropriate
+causal inference method based on dataset characteristics and available variables
+"""
+
+import logging
+from typing import Dict, List, Any, Optional
+import pandas as pd
+
+# define method names
+BACKDOOR_ADJUSTMENT = "backdoor_adjustment"
+LINEAR_REGRESSION = "linear_regression"
+DIFF_IN_MEANS = "diff_in_means"
+DIFF_IN_DIFF = "difference_in_differences"
+REGRESSION_DISCONTINUITY = "regression_discontinuity_design"
+PROPENSITY_SCORE_MATCHING = "propensity_score_matching"
+INSTRUMENTAL_VARIABLE = "instrumental_variable"
+CORRELATION_ANALYSIS = "correlation_analysis"
+PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting"
+GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score"
+FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment"
+
+
+logger = logging.getLogger(__name__)
+
+# method assumptions mapping
+METHOD_ASSUMPTIONS = {
+ BACKDOOR_ADJUSTMENT: [
+ "no unmeasured confounders (conditional ignorability given covariates)",
+ "correct model specification for outcome conditional on treatment and covariates",
+ "positivity/overlap (for all covariate values, units could potentially receive either treatment level)"
+ ],
+ LINEAR_REGRESSION: [
+ "linear relationship between treatment, covariates, and outcome",
+ "no unmeasured confounders (if observational)",
+ "correct model specification",
+ "homoscedasticity of errors",
+ "normally distributed errors (for inference)"
+ ],
+ DIFF_IN_MEANS: [
+ "treatment is randomly assigned (or as-if random)",
+ "no spillover effects",
+ "stable unit treatment value assumption (SUTVA)"
+ ],
+ DIFF_IN_DIFF: [
+ "parallel trends between treatment and control groups before treatment",
+ "no spillover effects between groups",
+ "no anticipation effects before treatment",
+ "stable composition of treatment and control groups",
+ "treatment timing is exogenous"
+ ],
+ REGRESSION_DISCONTINUITY: [
+ "units cannot precisely manipulate the running variable around the cutoff",
+ "continuity of conditional expectation functions of potential outcomes at the cutoff",
+ "no other changes occurring precisely at the cutoff"
+ ],
+ PROPENSITY_SCORE_MATCHING: [
+ "no unmeasured confounders (conditional ignorability)",
+ "sufficient overlap (common support) between treatment and control groups",
+ "correct propensity score model specification"
+ ],
+ INSTRUMENTAL_VARIABLE: [
+ "instrument is correlated with treatment (relevance)",
+ "instrument affects outcome only through treatment (exclusion restriction)",
+ "instrument is independent of unmeasured confounders (exogeneity/independence)"
+ ],
+ CORRELATION_ANALYSIS: [
+ "data represents a sample from the population of interest",
+ "variables are measured appropriately"
+ ],
+ PROPENSITY_SCORE_WEIGHTING: [
+ "no unmeasured confounders (conditional ignorability)",
+ "sufficient overlap (common support) between treatment and control groups",
+ "correct propensity score model specification",
+ "weights correctly specified (e.g., ATE, ATT)"
+ ],
+ GENERALIZED_PROPENSITY_SCORE: [
+ "conditional mean independence",
+ "positivity/common support for GPS",
+ "correct specification of the GPS model",
+ "correct specification of the outcome model",
+ "no unmeasured confounders affecting both treatment and outcome, given X",
+ "treatment variable is continuous"
+ ],
+ FRONTDOOR_ADJUSTMENT: [
+ "mediator is affected by treatment and affects outcome",
+ "mediator is not affected by any confounders of the treatment-outcome relationship"
+ ]
+}
+
+
+def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
+ excluded_methods = set(excluded_methods or [])
+ logger.info(f"Excluded methods: {sorted(excluded_methods)}")
+
+ treatment = dataset_properties.get("treatment_variable")
+ outcome = dataset_properties.get("outcome_variable")
+ if not treatment or not outcome:
+ raise ValueError("Both treatment and outcome variables must be specified")
+
+ instrument_var = dataset_properties.get("instrument_variable")
+ running_var = dataset_properties.get("running_variable")
+ cutoff_val = dataset_properties.get("cutoff_value")
+ time_var = dataset_properties.get("time_variable")
+ is_rct = dataset_properties.get("is_rct", False)
+ has_temporal = dataset_properties.get("has_temporal_structure", False)
+ frontdoor = dataset_properties.get("frontdoor_criterion", False)
+ covariate_overlap_result = dataset_properties.get("covariate_overlap_score")
+ covariates = dataset_properties.get("covariates", [])
+ treatment_variable_type = dataset_properties.get("treatment_variable_type", "binary")
+
+ # Helpers to collect candidates
+ candidates = [] # list of (method, priority_index)
+ justifications: Dict[str, str] = {}
+ assumptions: Dict[str, List[str]] = {}
+
+ def add(method: str, justification: str, prio_order: List[str]):
+ if method in justifications: # already added
+ return
+ justifications[method] = justification
+ assumptions[method] = METHOD_ASSUMPTIONS[method]
+ # priority index from provided order (fallback large if not present)
+ try:
+ idx = prio_order.index(method)
+ except ValueError:
+ idx = 10**6
+ candidates.append((method, idx))
+
+ # ----- Build candidate set (no returns here) -----
+
+ # RCT branch
+ if is_rct:
+ logger.info("Dataset is from a randomized controlled trial (RCT)")
+ rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS]
+
+ if instrument_var and instrument_var != treatment:
+ add(INSTRUMENTAL_VARIABLE,
+ f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.",
+ rct_priority)
+
+ if covariates:
+ add(LINEAR_REGRESSION,
+ "RCT with covariates—use OLS for precision.",
+ rct_priority)
+ else:
+ add(DIFF_IN_MEANS,
+ "Pure RCT without covariates—difference-in-means.",
+ rct_priority)
+
+ # Observational branch
+ obs_priority_binary = [
+ INSTRUMENTAL_VARIABLE,
+ PROPENSITY_SCORE_MATCHING,
+ PROPENSITY_SCORE_WEIGHTING,
+ FRONTDOOR_ADJUSTMENT,
+ LINEAR_REGRESSION,
+ ]
+ obs_priority_nonbinary = [
+ INSTRUMENTAL_VARIABLE,
+ FRONTDOOR_ADJUSTMENT,
+ LINEAR_REGRESSION,
+ ]
+
+ # Common early structural signals first (still only add as candidates)
+ if has_temporal and time_var:
+ add(DIFF_IN_DIFF,
+ f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).",
+ [DIFF_IN_DIFF]) # highest among itself
+
+ if running_var and cutoff_val is not None:
+ add(REGRESSION_DISCONTINUITY,
+ f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.",
+ [REGRESSION_DISCONTINUITY])
+
+ # Binary vs non-binary pathways
+ if treatment_variable_type == "binary":
+ if instrument_var:
+ add(INSTRUMENTAL_VARIABLE,
+ f"Instrumental variable '{instrument_var}' available.",
+ obs_priority_binary)
+
+ # Propensity score methods only if covariates exist
+ if covariates:
+ if covariate_overlap_result is not None:
+ ps_method = (PROPENSITY_SCORE_WEIGHTING
+ if covariate_overlap_result < 0.1
+ else PROPENSITY_SCORE_MATCHING)
+ else:
+ ps_method = PROPENSITY_SCORE_MATCHING
+ add(ps_method,
+ "Covariates observed; PS method chosen based on overlap.",
+ obs_priority_binary)
+
+ if frontdoor:
+ add(FRONTDOOR_ADJUSTMENT,
+ "Front-door criterion satisfied.",
+ obs_priority_binary)
+
+ add(LINEAR_REGRESSION,
+ "OLS as a fallback specification.",
+ obs_priority_binary)
+
+ else:
+ logger.info(f"Non-binary treatment variable detected: {treatment_variable_type}")
+ if instrument_var:
+ add(INSTRUMENTAL_VARIABLE,
+ f"Instrument '{instrument_var}' candidate for non-binary treatment.",
+ obs_priority_nonbinary)
+ if frontdoor:
+ add(FRONTDOOR_ADJUSTMENT,
+ "Front-door criterion satisfied.",
+ obs_priority_nonbinary)
+ add(LINEAR_REGRESSION,
+ "Fallback for non-binary treatment without stronger identification.",
+ obs_priority_nonbinary)
+
+ # ----- Centralized exclusion handling -----
+ # Remove excluded
+ filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods]
+
+ # If nothing survives, attempt a safe fallback not excluded
+ if not filtered:
+ logger.warning(f"All candidates excluded. Candidates were: {[m for m,_ in candidates]}. Excluded: {sorted(excluded_methods)}")
+ fallback_order = [
+ LINEAR_REGRESSION,
+ DIFF_IN_MEANS,
+ PROPENSITY_SCORE_MATCHING,
+ PROPENSITY_SCORE_WEIGHTING,
+ DIFF_IN_DIFF,
+ REGRESSION_DISCONTINUITY,
+ INSTRUMENTAL_VARIABLE,
+ FRONTDOOR_ADJUSTMENT,
+ ]
+ fallback = next((m for m in fallback_order if m in justifications and m not in excluded_methods), None)
+ if not fallback:
+ # truly nothing left; raise with context
+ raise RuntimeError("No viable method remains after exclusions.")
+ selected_method = fallback
+ alternatives = []
+ justifications[selected_method] = justifications.get(selected_method, "Fallback after exclusions.")
+ else:
+ # Pick by smallest priority index, then stable by insertion
+ filtered.sort(key=lambda x: x[1])
+ selected_method = filtered[0][0]
+ alternatives = [m for (m, _) in filtered[1:] if m != selected_method]
+
+ logger.info(f"Selected method: {selected_method}; alternatives: {alternatives}")
+
+ return {
+ "selected_method": selected_method,
+ "method_justification": justifications[selected_method],
+ "method_assumptions": assumptions[selected_method],
+ "alternatives": alternatives,
+ "excluded_methods": sorted(excluded_methods),
+ }
+
+
+
+def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None):
+ """
+ Wrapped function to select causal method based on dataset properties and query
+
+ Args:
+ dataset_analysis (Dict): results of dataset analysis
+ variables (Dict): dictionary of variable names and types
+ is_rct (bool): whether the dataset is from a randomized controlled trial
+ llm (BaseChatModel): language model instance for generating prompts
+ dataset_description (str): description of the dataset
+ original_query (str): the original user query
+ excluded_methods (List[str], optional): list of methods to exclude from selection
+ """
+
+ logger.info("Running rule-based method selection")
+
+
+ properties = {"treatment_variable": variables.get("treatment_variable"), "instrument_variable":variables.get("instrument_variable"),
+ "covariates": variables.get("covariates", []), "outcome_variable": variables.get("outcome_variable"),
+ "time_variable": variables.get("time_variable"), "running_variable": variables.get("running_variable"),
+ "treatment_variable_type": variables.get("treatment_variable_type", "binary"),
+ "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
+ "frontdoor_criterion": variables.get("frontdoor_criterion", False),
+ "cutoff_value": variables.get("cutoff_value"),
+ "covariate_overlap_score": variables.get("covariate_overlap_result", 0)}
+
+ properties["is_rct"] = is_rct
+ logger.info(f"Dataset properties for method selection: {properties}")
+
+ return select_method(properties, excluded_methods)
+
+
+
+class DecisionTreeEngine:
+ """
+ Engine for applying decision trees to select appropriate causal methods.
+
+ This class wraps the functional decision tree implementation to provide
+ an object-oriented interface for method selection.
+ """
+
+ def __init__(self, verbose=False):
+ self.verbose = verbose
+
+ def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str],
+ dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Apply decision tree to select appropriate causal method.
+ """
+
+ if self.verbose:
+ print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}")
+ print(f"Available covariates: {covariates}")
+
+ treatment_variable_type = query_details.get("treatment_variable_type")
+ covariate_overlap_result = query_details.get("covariate_overlap_result")
+ info = {"treatment_variable": treatment, "outcome_variable": outcome,
+ "covariates": covariates, "time_variable": query_details.get("time_variable"),
+ "group_variable": query_details.get("group_variable"),
+ "instrument_variable": query_details.get("instrument_variable"),
+ "running_variable": query_details.get("running_variable"),
+ "cutoff_value": query_details.get("cutoff_value"),
+ "is_rct": query_details.get("is_rct", False),
+ "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
+ "frontdoor_criterion": query_details.get("frontdoor_criterion", False),
+ "covariate_overlap_score": covariate_overlap_result,
+ "treatment_variable_type": treatment_variable_type}
+
+ result = select_method(info)
+
+ if self.verbose:
+ print(f"Selected method: {result['selected_method']}")
+ print(f"Justification: {result['method_justification']}")
+
+ result["decision_path"] = self._get_decision_path(result["selected_method"])
+ return result
+
+
+ def _get_decision_path(self, method):
+ if method == "linear_regression":
+ return ["Check if randomized experiment", "Data appears to be from a randomized experiment with covariates"]
+ elif method == "propensity_score_matching":
+ return ["Check if randomized experiment", "Data is observational",
+ "Check for sufficient covariate overlap", "Sufficient overlap exists"]
+ elif method == "propensity_score_weighting":
+ return ["Check if randomized experiment", "Data is observational",
+ "Check for sufficient covariate overlap", "Low overlap—weighting preferred"]
+ elif method == "backdoor_adjustment":
+ return ["Check if randomized experiment", "Data is observational",
+ "Check for sufficient covariate overlap", "Adjusting for covariates"]
+ elif method == "instrumental_variable":
+ return ["Check if randomized experiment", "Data is observational",
+ "Check for instrumental variables", "Instrument is available"]
+ elif method == "regression_discontinuity_design":
+ return ["Check if randomized experiment", "Data is observational",
+ "Check for discontinuity", "Discontinuity exists"]
+ elif method == "difference_in_differences":
+ return ["Check if randomized experiment", "Data is observational",
+ "Check for temporal structure", "Panel data structure exists"]
+ elif method == "frontdoor_adjustment":
+ return ["Check if randomized experiment", "Data is observational",
+ "Check front-door criterion", "Front-door path identified"]
+ elif method == "diff_in_means":
+ return ["Check if randomized experiment", "Pure RCT without covariates"]
+ else:
+ return ["Default method selection"]
\ No newline at end of file
diff --git a/auto_causal/components/decision_tree_llm.py b/auto_causal/components/decision_tree_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..41d92d6772b217ead47e6fee296841526d52574e
--- /dev/null
+++ b/auto_causal/components/decision_tree_llm.py
@@ -0,0 +1,218 @@
+"""
+LLM-based Decision tree component for selecting causal inference methods.
+
+This module implements the decision tree logic via an LLM prompt
+to select the most appropriate causal inference method based on
+dataset characteristics and available variables.
+"""
+
+import logging
+import json
+from typing import Dict, Any, Optional, List
+
+from langchain_core.messages import HumanMessage
+from langchain_core.language_models import BaseChatModel
+
+# Import constants and assumptions from the original decision_tree module
+from .decision_tree import (
+ METHOD_ASSUMPTIONS,
+ BACKDOOR_ADJUSTMENT,
+ LINEAR_REGRESSION,
+ DIFF_IN_MEANS,
+ DIFF_IN_DIFF,
+ REGRESSION_DISCONTINUITY,
+ PROPENSITY_SCORE_MATCHING,
+ INSTRUMENTAL_VARIABLE,
+ CORRELATION_ANALYSIS,
+ PROPENSITY_SCORE_WEIGHTING,
+ GENERALIZED_PROPENSITY_SCORE
+)
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+# Define a list of all known methods for the LLM prompt
+ALL_METHODS = [
+ DIFF_IN_MEANS,
+ LINEAR_REGRESSION,
+ DIFF_IN_DIFF,
+ REGRESSION_DISCONTINUITY,
+ INSTRUMENTAL_VARIABLE,
+ PROPENSITY_SCORE_MATCHING,
+ PROPENSITY_SCORE_WEIGHTING,
+ GENERALIZED_PROPENSITY_SCORE,
+ BACKDOOR_ADJUSTMENT, # Often a general approach rather than a specific model.
+ CORRELATION_ANALYSIS,
+]
+
+METHOD_DESCRIPTIONS_FOR_LLM = {
+ DIFF_IN_MEANS: "Appropriate for Randomized Controlled Trials (RCTs) with no covariates. Compares the average outcome between treated and control groups.",
+ LINEAR_REGRESSION: "Can be used for RCTs with covariates to increase precision, or for observational data assuming linear relationships and no unmeasured confounders. Models the outcome as a linear function of treatment and covariates.",
+ DIFF_IN_DIFF: "Suitable for observational data with a temporal structure (e.g., panel data with pre/post treatment periods). Requires the 'parallel trends' assumption: treatment and control groups would have followed similar trends in the outcome in the absence of treatment.",
+ REGRESSION_DISCONTINUITY: "Applicable when treatment assignment is determined by whether an observed 'running variable' crosses a specific cutoff point. Assumes individuals cannot precisely manipulate the running variable.",
+ INSTRUMENTAL_VARIABLE: "Used when there's an 'instrument' variable that is correlated with the treatment, affects the outcome only through the treatment, and is not confounded with the outcome. Useful for handling unobserved confounding.",
+ PROPENSITY_SCORE_MATCHING: "For observational data with covariates. Estimates the probability of receiving treatment (propensity score) for each unit and then matches treated and control units with similar scores. Aims to create balanced groups.",
+ PROPENSITY_SCORE_WEIGHTING: "Similar to PSM, for observational data with covariates. Uses propensity scores to weight units to create a pseudo-population where confounders are balanced. Can estimate ATE, ATT, or ATC.",
+ GENERALIZED_PROPENSITY_SCORE: "An extension of propensity scores for continuous treatment variables. Aims to estimate the dose-response function, assuming unconfoundedness given covariates.",
+ BACKDOOR_ADJUSTMENT: "A general strategy for causal inference in observational studies that involves statistically controlling for all common causes (confounders) of the treatment and outcome. Specific methods like regression or matching implement this.",
+ CORRELATION_ANALYSIS: "A fallback method when causal inference is not feasible due to data limitations (e.g., no clear design, no covariates for adjustment). Measures the statistical association between variables, but does not imply causation."
+}
+
+
+class DecisionTreeLLMEngine:
+ """
+ Engine for applying an LLM-based decision tree to select appropriate causal methods.
+ """
+
+ def __init__(self, verbose: bool = False):
+ """
+ Initialize the LLM decision tree engine.
+
+ Args:
+ verbose: Whether to print verbose information.
+ """
+ self.verbose = verbose
+
+ def _construct_prompt(self, dataset_analysis: Dict[str, Any], variables: Dict[str, Any], is_rct: bool, excluded_methods: Optional[List[str]] = None) -> str:
+ """
+ Constructs the detailed prompt for the LLM.
+ """
+ # Filter out excluded methods
+ excluded_methods = excluded_methods or []
+ available_methods = [method for method in ALL_METHODS if method not in excluded_methods]
+ methods_list_str = "\n".join([f"- {method}: {METHOD_DESCRIPTIONS_FOR_LLM[method]}" for method in available_methods if method in METHOD_DESCRIPTIONS_FOR_LLM])
+
+ excluded_info = ""
+ if excluded_methods:
+ excluded_info = f"\nEXCLUDED METHODS (do not select these): {', '.join(excluded_methods)}\nReason: These methods failed validation in previous attempts.\n"
+
+ prompt = f"""You are an expert in causal inference. Your task is to select the most appropriate causal inference method based on the provided dataset analysis and variable information.
+
+Dataset Analysis:
+{json.dumps(dataset_analysis, indent=2)}
+
+Identified Variables:
+{json.dumps(variables, indent=2)}
+
+Is the data from a Randomized Controlled Trial (RCT)? {'Yes' if is_rct else 'No'}{excluded_info}
+
+Available Causal Inference Methods and their descriptions:
+{methods_list_str}
+
+Instructions:
+1. Carefully review all the provided information: dataset analysis, variables, and RCT status.
+2. Reason step-by-step to determine the most suitable method. Consider the hierarchy of methods (e.g., specific designs like DiD, RDD, IV before general adjustment methods).
+3. Explain your reasoning for selecting a particular method.
+4. Identify any potential alternative methods if applicable.
+5. State the key assumptions for your *selected* method by referring to the general list of assumptions for all methods that will be provided to you separately (you don't need to list them here, just be aware that you need to select a method for which assumptions are known).
+
+Output your final decision as a JSON object with the following exact keys:
+- "selected_method": string (must be one of {', '.join(available_methods)})
+- "method_justification": string (your detailed reasoning)
+- "alternative_methods": list of strings (alternative method names, can be empty)
+
+Example JSON output format:
+{{
+ "selected_method": "difference_in_differences",
+ "method_justification": "The dataset has a clear time variable and group variable, indicating a panel structure suitable for DiD. The parallel trends assumption will need to be checked.",
+ "alternative_methods": ["instrumental_variable"]
+}}
+
+Please provide only the JSON object in your response.
+"""
+ return prompt
+
+ def select_method_llm(self, dataset_analysis: Dict[str, Any], variables: Dict[str, Any], is_rct: bool = False, llm: Optional[BaseChatModel] = None, excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
+ """
+ Apply LLM-based decision tree to select appropriate causal method.
+
+ Args:
+ dataset_analysis: Dataset analysis results.
+ variables: Identified variables from query_interpreter.
+ is_rct: Boolean indicating if the data comes from an RCT.
+ llm: Langchain BaseChatModel instance for making the call.
+ excluded_methods: Optional list of method names to exclude from selection.
+
+ Returns:
+ Dict with selected method, justification, and assumptions.
+ Example:
+ {{
+ "selected_method": "difference_in_differences",
+ "method_justification": "Reasoning...",
+ "method_assumptions": ["Assumption 1", ...],
+ "alternative_methods": ["instrumental_variable"]
+ }}
+ """
+ if not llm:
+ logger.error("LLM client not provided to DecisionTreeLLMEngine. Cannot select method.")
+ return {
+ "selected_method": CORRELATION_ANALYSIS,
+ "method_justification": "LLM client not provided. Defaulting to Correlation Analysis as causal inference method selection is not possible. This indicates association, not causation.",
+ "method_assumptions": METHOD_ASSUMPTIONS.get(CORRELATION_ANALYSIS, []),
+ "alternative_methods": []
+ }
+
+ prompt = self._construct_prompt(dataset_analysis, variables, is_rct, excluded_methods)
+ if self.verbose:
+ logger.info("LLM Prompt for method selection:")
+ logger.info(prompt)
+
+ messages = [HumanMessage(content=prompt)]
+
+ llm_output_str = "" # Initialize llm_output_str here
+ try:
+ response = llm.invoke(messages)
+ llm_output_str = response.content.strip()
+
+ if self.verbose:
+ logger.info(f"LLM Raw Output: {llm_output_str}")
+
+ # Attempt to parse the JSON output
+ # The LLM might sometimes include explanations outside the JSON block.
+ # Try to extract JSON from within ```json ... ``` if present.
+ if "```json" in llm_output_str:
+ json_str = llm_output_str.split("```json")[1].split("```")[0].strip()
+ elif "```" in llm_output_str and llm_output_str.startswith("{") == False : # if it doesn't start with { then likely ```{}```
+ json_str = llm_output_str.split("```")[1].strip()
+ else: # Assume the entire string is the JSON if no triple backticks
+ json_str = llm_output_str
+
+ parsed_response = json.loads(json_str)
+
+ selected_method = parsed_response.get("selected_method")
+ justification = parsed_response.get("method_justification", "No justification provided by LLM.")
+ alternatives = parsed_response.get("alternative_methods", [])
+
+ if selected_method and selected_method in METHOD_ASSUMPTIONS:
+ logger.info(f"LLM selected method: {selected_method}")
+ return {
+ "selected_method": selected_method,
+ "method_justification": justification,
+ "method_assumptions": METHOD_ASSUMPTIONS[selected_method],
+ "alternative_methods": alternatives
+ }
+ else:
+ logger.warning(f"LLM selected an invalid or unknown method: '{selected_method}'. Or method not in METHOD_ASSUMPTIONS. Raw response: {llm_output_str}")
+ fallback_justification = f"LLM output was problematic (selected: {selected_method}). Defaulting to Correlation Analysis. LLM Raw Response: {llm_output_str}"
+ selected_method = CORRELATION_ANALYSIS
+ justification = fallback_justification
+
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse JSON response from LLM: {e}. Raw response: {llm_output_str}", exc_info=True)
+ fallback_justification = f"LLM response was not valid JSON. Defaulting to Correlation Analysis. Error: {e}. LLM Raw Response: {llm_output_str}"
+ selected_method = CORRELATION_ANALYSIS
+ justification = fallback_justification
+ alternatives = []
+ except Exception as e:
+ logger.error(f"Error during LLM call for method selection: {e}. Raw response: {llm_output_str}", exc_info=True)
+ fallback_justification = f"An unexpected error occurred during LLM method selection. Defaulting to Correlation Analysis. Error: {e}. LLM Raw Response: {llm_output_str}"
+ selected_method = CORRELATION_ANALYSIS
+ justification = fallback_justification
+ alternatives = []
+
+ return {
+ "selected_method": selected_method,
+ "method_justification": justification,
+ "method_assumptions": METHOD_ASSUMPTIONS.get(selected_method, []),
+ "alternative_methods": alternatives
+ }
\ No newline at end of file
diff --git a/auto_causal/components/explanation_generator.py b/auto_causal/components/explanation_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..31bfc643dfe6f4c502b2c078ffde36a6d0eb5ec1
--- /dev/null
+++ b/auto_causal/components/explanation_generator.py
@@ -0,0 +1,404 @@
+"""
+Explanation generator component for causal inference methods.
+
+This module generates explanations for causal inference methods, including
+what the method does, its assumptions, and how it will be applied to the dataset.
+"""
+
+from typing import Dict, Any, List, Optional
+from langchain_core.language_models import BaseChatModel # For LLM type hint
+
+
+def generate_explanation(
+ method_info: Dict[str, Any],
+ validation_result: Dict[str, Any],
+ variables: Dict[str, Any],
+ results: Dict[str, Any],
+ dataset_analysis: Optional[Dict[str, Any]] = None,
+ dataset_description: Optional[str] = None,
+ llm: Optional[BaseChatModel] = None
+) -> Dict[str, str]:
+ """
+ Generates a comprehensive explanation text for the causal analysis.
+
+ Args:
+ method_info: Dictionary containing selected method details.
+ validation_result: Dictionary containing method validation results.
+ variables: Dictionary containing identified variables.
+ results: Dictionary containing numerical results from the method execution.
+ dataset_analysis: Optional dictionary with dataset analysis details.
+ dataset_description: Optional string describing the dataset.
+ llm: Optional language model instance (for potential future use in generation).
+
+ Returns:
+ Dictionary containing the final explanation text.
+ """
+ method = method_info.get("method_name")
+
+ # Handle potential None for validation_result
+ if validation_result and validation_result.get("valid") is False:
+ method = validation_result.get("recommended_method", method)
+
+ # Get components
+ method_explanation = get_method_explanation(method)
+ assumption_explanations = explain_assumptions(method_info.get("assumptions", []))
+ application_explanation = explain_application(method, variables.get("treatment_variable"),
+ variables.get("outcome_variable"),
+ variables.get("covariates", []), variables)
+ limitations_explanation = explain_limitations(method, validation_result.get("concerns", []) if validation_result else [])
+ interpretation_guide = generate_interpretation_guide(method, variables.get("treatment_variable"),
+ variables.get("outcome_variable"))
+
+ # --- Extract Numerical Results ---
+ effect_estimate = results.get("effect_estimate")
+ effect_se = results.get("effect_se")
+ ci = results.get("confidence_interval")
+ p_value = results.get("p_value") # Assuming method executor returns p_value
+
+ # --- Assemble Final Text ---
+ final_text = f"**Method Used:** {method_info.get('method_name', method)}\n\n"
+ final_text += f"**Method Explanation:**\n{method_explanation}\n\n"
+
+ # Add Results Section
+ final_text += "**Results:**\n"
+ if effect_estimate is not None:
+ final_text += f"- Estimated Causal Effect: {effect_estimate:.4f}\n"
+ if effect_se is not None:
+ final_text += f"- Standard Error: {effect_se:.4f}\n"
+ if ci and ci[0] is not None and ci[1] is not None:
+ final_text += f"- 95% Confidence Interval: [{ci[0]:.4f}, {ci[1]:.4f}]\n"
+ if p_value is not None:
+ final_text += f"- P-value: {p_value:.4f}\n"
+ final_text += "\n"
+
+ final_text += f"**Interpretation Guide:**\n{interpretation_guide}\n\n"
+ final_text += f"**Assumptions:**\n"
+ for item in assumption_explanations:
+ final_text += f"- {item['assumption']}: {item['explanation']}\n"
+ final_text += "\n"
+ final_text += f"**Limitations:**\n{limitations_explanation}\n\n"
+
+ return {
+ "final_explanation_text": final_text
+ # Return only the final text, the tool wrapper adds workflow state
+ }
+
+
+def get_method_explanation(method: str) -> str:
+ """
+ Get explanation for what the method does.
+
+ Args:
+ method: Causal inference method name
+
+ Returns:
+ String explaining what the method does
+ """
+ explanations = {
+ "propensity_score_matching": (
+ "Propensity Score Matching is a statistical technique that attempts to estimate the effect "
+ "of a treatment by accounting for covariates that predict receiving the treatment. "
+ "It creates matched sets of treated and untreated subjects who share similar characteristics, "
+ "allowing for a more fair comparison between groups."
+ ),
+ "regression_adjustment": (
+ "Regression Adjustment is a method that uses regression models to estimate causal effects "
+ "by controlling for covariates. It models the outcome as a function of the treatment and "
+ "other potential confounding variables, allowing the isolation of the treatment effect."
+ ),
+ "instrumental_variable": (
+ "The Instrumental Variable method addresses issues of endogeneity or unmeasured confounding "
+ "by using an 'instrument' - a variable that affects the treatment but not the outcome directly. "
+ "It effectively finds the natural experiment hidden in your data to estimate causal effects."
+ ),
+ "difference_in_differences": (
+ "Difference-in-Differences compares the changes in outcomes over time between a group that "
+ "receives a treatment and a group that does not. It controls for time-invariant unobserved "
+ "confounders by looking at differences in trends rather than absolute values."
+ ),
+ "regression_discontinuity": (
+ "Regression Discontinuity Design exploits a threshold or cutoff rule that determines treatment "
+ "assignment. By comparing observations just above and below this threshold, where treatment "
+ "status changes but other characteristics remain similar, it estimates the local causal effect."
+ ),
+ "backdoor_adjustment": (
+ "Backdoor Adjustment controls for confounding variables that create 'backdoor paths' between "
+ "treatment and outcome variables in a causal graph. By conditioning on these variables, "
+ "it blocks the non-causal associations, allowing for identification of the causal effect."
+ ),
+ }
+
+ return explanations.get(method,
+ f"The {method} method is a causal inference technique used to estimate "
+ f"causal effects from observational data.")
+
+
+def explain_assumptions(assumptions: List[str]) -> List[Dict[str, str]]:
+ """
+ Explain each assumption of the method.
+
+ Args:
+ assumptions: List of assumption names
+
+ Returns:
+ List of dictionaries with assumption name and explanation
+ """
+ assumption_details = {
+ "Treatment is randomly assigned": (
+ "This assumes that treatment assignment is not influenced by any factors "
+ "related to the outcome, similar to a randomized controlled trial. "
+ "In observational data, this assumption rarely holds without conditioning on confounders."
+ ),
+ "No systematic differences between treatment and control groups": (
+ "Treatment and control groups should be balanced on all relevant characteristics "
+ "except for the treatment itself. Any systematic differences could bias the estimate."
+ ),
+ "No unmeasured confounders (conditional ignorability)": (
+ "All variables that simultaneously affect the treatment and outcome are measured and "
+ "included in the analysis. If important confounders are missing, the estimated causal "
+ "effect will be biased."
+ ),
+ "Sufficient overlap between treatment and control groups": (
+ "For each combination of covariate values, there should be both treated and untreated "
+ "units. Without overlap, the model must extrapolate, which can lead to biased estimates."
+ ),
+ "Treatment assignment is not deterministic given covariates": (
+ "No combination of covariates should perfectly predict treatment assignment. "
+ "If treatment is deterministic for some units, causal comparisons become impossible."
+ ),
+ "Instrument is correlated with treatment (relevance)": (
+ "The instrumental variable must have a clear and preferably strong effect on the "
+ "treatment variable. Weak instruments lead to imprecise and potentially biased estimates."
+ ),
+ "Instrument affects outcome only through treatment (exclusion restriction)": (
+ "The instrumental variable must not directly affect the outcome except through its "
+ "effect on the treatment. If this assumption fails, the causal estimate will be biased."
+ ),
+ "Instrument is as good as randomly assigned (exogeneity)": (
+ "The instrumental variable must not be correlated with any confounders of the "
+ "treatment-outcome relationship. It should be as good as randomly assigned."
+ ),
+ "Parallel trends between treatment and control groups": (
+ "In the absence of treatment, the difference between treatment and control groups "
+ "would have remained constant over time. This is the key identifying assumption for "
+ "difference-in-differences and cannot be directly tested for the post-treatment period."
+ ),
+ "No spillover effects between groups": (
+ "The treatment of one unit should not affect the outcomes of other units. "
+ "If spillovers exist, they can bias the estimated treatment effect."
+ ),
+ "No anticipation effects before treatment": (
+ "Units should not change their behavior in anticipation of future treatment. "
+ "If anticipation effects exist, the pre-treatment trends may already reflect treatment effects."
+ ),
+ "Stable composition of treatment and control groups": (
+ "The composition of treatment and control groups should remain stable over time. "
+ "If units move between groups based on outcomes, this can bias the estimates."
+ ),
+ "Units cannot precisely manipulate their position around the cutoff": (
+ "In regression discontinuity, units must not be able to precisely control their position "
+ "relative to the cutoff. If they can, the randomization-like property of the design fails."
+ ),
+ "No other variables change discontinuously at the cutoff": (
+ "Any discontinuity in outcomes at the cutoff should be attributable only to the change "
+ "in treatment status. If other relevant variables also change at the cutoff, the causal "
+ "interpretation is compromised."
+ ),
+ "The relationship between running variable and outcome is continuous at the cutoff": (
+ "In the absence of treatment, the relationship between the running variable and the "
+ "outcome would be continuous at the cutoff. This allows attributing any observed "
+ "discontinuity to the treatment effect."
+ ),
+ "The model correctly specifies the relationship between variables": (
+ "The functional form of the relationship between variables in the model should correctly "
+ "capture the true relationship in the data. Misspecification can lead to biased estimates."
+ ),
+ "No reverse causality": (
+ "The treatment must cause the outcome, not the other way around. If the outcome affects "
+ "the treatment, the estimated relationship will not have a causal interpretation."
+ ),
+ }
+
+ return [
+ {"assumption": assumption, "explanation": assumption_details.get(assumption,
+ "This is a key assumption for the selected causal inference method.")}
+ for assumption in assumptions
+ ]
+
+
+def explain_application(method: str, treatment: str, outcome: str,
+ covariates: List[str], variables: Dict[str, Any]) -> str:
+ """
+ Explain how the method will be applied to the dataset.
+
+ Args:
+ method: Causal inference method name
+ treatment: Treatment variable name
+ outcome: Outcome variable name
+ covariates: List of covariate names
+ variables: Dictionary of identified variables
+
+ Returns:
+ String explaining the application
+ """
+ covariate_str = ", ".join(covariates[:3])
+ if len(covariates) > 3:
+ covariate_str += f", and {len(covariates) - 3} other variables"
+
+ applications = {
+ "propensity_score_matching": (
+ f"I will estimate the propensity scores (probability of receiving treatment) for each "
+ f"observation based on the covariates ({covariate_str}). Then, I'll match treated and "
+ f"untreated units with similar propensity scores to create balanced comparison groups. "
+ f"Finally, I'll calculate the difference in {outcome} between these matched groups to "
+ f"estimate the causal effect of {treatment}."
+ ),
+ "regression_adjustment": (
+ f"I will build a regression model with {outcome} as the dependent variable and "
+ f"{treatment} as the independent variable of interest, while controlling for "
+ f"potential confounders ({covariate_str}). The coefficient of {treatment} will "
+ f"represent the estimated causal effect after adjusting for these covariates."
+ ),
+ "instrumental_variable": (
+ f"I will use {variables.get('instrument_variable')} as an instrumental variable for "
+ f"{treatment}. First, I'll estimate how the instrument affects {treatment} (first stage). "
+ f"Then, I'll use these predictions to estimate how changes in {treatment} that are induced "
+ f"by the instrument affect {outcome} (second stage). This two-stage approach helps "
+ f"address potential unmeasured confounding."
+ ),
+ "difference_in_differences": (
+ f"I will compare the change in {outcome} before and after the intervention for the "
+ f"group receiving {treatment}, relative to the change in a control group that didn't "
+ f"receive the treatment. This approach controls for time-invariant confounders and "
+ f"common time trends that affect both groups."
+ ),
+ "regression_discontinuity": (
+ f"I will focus on observations close to the cutoff value "
+ f"({variables.get('cutoff_value')}) of the running variable "
+ f"({variables.get('running_variable')}), where treatment assignment changes. "
+ f"By comparing outcomes just above and below this threshold, I can estimate "
+ f"the local causal effect of {treatment} on {outcome}."
+ ),
+ "backdoor_adjustment": (
+ f"I will control for the identified confounding variables ({covariate_str}) to "
+ f"block all backdoor paths between {treatment} and {outcome}. This may involve "
+ f"stratification, regression adjustment, or inverse probability weighting, depending "
+ f"on the data characteristics."
+ ),
+ }
+
+ return applications.get(method,
+ f"I will apply the {method} method to estimate the causal effect of "
+ f"{treatment} on {outcome}, controlling for relevant confounding factors "
+ f"where appropriate.")
+
+
+def explain_limitations(method: str, concerns: List[str]) -> str:
+ """
+ Explain the limitations of the method based on validation concerns.
+
+ Args:
+ method: Causal inference method name
+ concerns: List of concerns from validation
+
+ Returns:
+ String explaining the limitations
+ """
+ method_limitations = {
+ "propensity_score_matching": (
+ "Propensity Score Matching can only account for observed confounders, and its "
+ "effectiveness depends on having good overlap between treatment and control groups. "
+ "It may also be sensitive to model specification for the propensity score estimation."
+ ),
+ "regression_adjustment": (
+ "Regression Adjustment relies heavily on correct model specification and can only "
+ "control for observed confounders. Extrapolation to regions with limited data can lead "
+ "to unreliable estimates, and the method may be sensitive to outliers."
+ ),
+ "instrumental_variable": (
+ "Instrumental Variable estimation can be imprecise with weak instruments and is "
+ "sensitive to violations of the exclusion restriction. The estimated effect is a local "
+ "average treatment effect for 'compliers', which may not generalize to the entire population."
+ ),
+ "difference_in_differences": (
+ "Difference-in-Differences relies on the parallel trends assumption, which cannot be fully "
+ "tested for the post-treatment period. It may be sensitive to the choice of comparison group "
+ "and can be biased if there are time-varying confounders or anticipation effects."
+ ),
+ "regression_discontinuity": (
+ "Regression Discontinuity provides estimates that are local to the cutoff point and may not "
+ "generalize to units far from this threshold. It also requires sufficient data around the "
+ "cutoff and is sensitive to the choice of bandwidth and functional form."
+ ),
+ "backdoor_adjustment": (
+ "Backdoor Adjustment requires correctly identifying all confounding variables and their "
+ "relationships. It depends on the assumption of no unmeasured confounders and may be "
+ "sensitive to model misspecification in complex settings."
+ ),
+ }
+
+ base_limitation = method_limitations.get(method,
+ f"The {method} method has general limitations in terms of its assumptions and applicability.")
+
+ # Add specific concerns if any
+ if concerns:
+ concern_text = " Additionally, specific concerns for this analysis include: " + \
+ "; ".join(concerns) + "."
+ return base_limitation + concern_text
+
+ return base_limitation
+
+
+def generate_interpretation_guide(method: str, treatment: str, outcome: str) -> str:
+ """
+ Generate guide for interpreting the results.
+
+ Args:
+ method: Causal inference method name
+ treatment: Treatment variable name
+ outcome: Outcome variable name
+
+ Returns:
+ String with interpretation guide
+ """
+ interpretation_guides = {
+ "propensity_score_matching": (
+ f"The estimated effect represents the Average Treatment Effect (ATE) or the Average "
+ f"Treatment Effect on the Treated (ATT), depending on the specific matching approach. "
+ f"It can be interpreted as the expected change in {outcome} if a unit were to receive "
+ f"{treatment}, compared to not receiving it, for units with similar covariate values."
+ ),
+ "regression_adjustment": (
+ f"The coefficient of {treatment} in the regression model represents the estimated "
+ f"average causal effect on {outcome}, holding all included covariates constant. "
+ f"For binary treatments, it's the expected difference in outcomes between treated "
+ f"and untreated units with the same covariate values."
+ ),
+ "instrumental_variable": (
+ f"The estimated effect represents the Local Average Treatment Effect (LATE) for 'compliers' "
+ f"- units whose treatment status is influenced by the instrument. It can be interpreted as "
+ f"the average effect of {treatment} on {outcome} for this specific subpopulation."
+ ),
+ "difference_in_differences": (
+ f"The estimated effect represents the average causal impact of {treatment} on {outcome}, "
+ f"under the assumption that treatment and control groups would have followed parallel "
+ f"trends in the absence of treatment. It accounts for both time-invariant differences "
+ f"between groups and common time trends."
+ ),
+ "regression_discontinuity": (
+ f"The estimated effect represents the local causal impact of {treatment} on {outcome} "
+ f"at the cutoff point. It can be interpreted as the expected difference in outcomes "
+ f"for units just above versus just below the threshold, where treatment status changes."
+ ),
+ "backdoor_adjustment": (
+ f"The estimated effect represents the average causal effect of {treatment} on {outcome} "
+ f"after controlling for all identified confounding variables. It can be interpreted as "
+ f"the expected difference in outcomes if a unit were to receive versus not receive the "
+ f"treatment, holding all confounding factors constant."
+ ),
+ }
+
+ return interpretation_guides.get(method,
+ f"The estimated effect represents the causal impact of {treatment} on {outcome}, "
+ f"given the assumptions of the method are met. Careful consideration of these "
+ f"assumptions is needed for valid causal interpretation.")
\ No newline at end of file
diff --git a/auto_causal/components/input_parser.py b/auto_causal/components/input_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..394efa8a8a5071a490d865a06308db6821e2b193
--- /dev/null
+++ b/auto_causal/components/input_parser.py
@@ -0,0 +1,456 @@
+"""
+Input parser component for extracting information from causal queries.
+
+This module provides functionality to parse user queries and extract key
+elements such as the causal question, relevant variables, and constraints.
+"""
+
+import re
+import os
+import json
+import logging # Added for better logging
+from typing import Dict, List, Any, Optional, Union
+import pandas as pd
+from pydantic import BaseModel, Field, ValidationError
+from functools import partial # Import partial
+
+# Add dotenv import
+from dotenv import load_dotenv
+
+# LangChain Imports
+from langchain_openai import ChatOpenAI # Example, replace if using another provider
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.exceptions import OutputParserException # Correct path
+from langchain_core.language_models import BaseChatModel # Import BaseChatModel
+
+# --- Load .env file ---
+load_dotenv() # Load environment variables from .env file
+
+# --- Configure Logging ---
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# --- Instantiate LLM Client ---
+# Ensure OPENAI_API_KEY environment variable is set
+# Consider making model name configurable
+try:
+ # Using with_structured_output later, so instantiate base model here
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+ # Add a check or allow configuration for different providers if needed
+except ImportError:
+ logger.error("langchain_openai not installed. Please install it to use OpenAI models.")
+ llm = None
+except Exception as e:
+ logger.error(f"Error initializing LLM: {e}. Input parsing will rely on fallbacks.")
+ llm = None
+
+# --- Pydantic Models for Structured Output ---
+class ParsedVariables(BaseModel):
+ treatment: List[str] = Field(default_factory=list, description="Variable(s) representing the treatment/intervention.")
+ outcome: List[str] = Field(default_factory=list, description="Variable(s) representing the outcome/result.")
+ covariates_mentioned: Optional[List[str]] = Field(default_factory=list, description="Covariate/control variable(s) explicitly mentioned in the query.")
+ grouping_vars: Optional[List[str]] = Field(default_factory=list, description="Variable(s) identifying groups or units for analysis.")
+ instruments_mentioned: Optional[List[str]] = Field(default_factory=list, description="Potential instrumental variable(s) mentioned.")
+
+class ParsedQueryInfo(BaseModel):
+ query_type: str = Field(..., description="Type of query (e.g., EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER). Required.")
+ variables: ParsedVariables = Field(..., description="Variables identified in the query.")
+ constraints: Optional[List[str]] = Field(default_factory=list, description="Constraints or conditions mentioned (e.g., 'X > 10', 'country = USA').")
+ dataset_path_mentioned: Optional[str] = Field(None, description="Dataset path explicitly mentioned in the query, if any.")
+
+# Add Pydantic model for path extraction
+class ExtractedPath(BaseModel):
+ dataset_path: Optional[str] = Field(None, description="File path or URL for the dataset mentioned in the query.")
+
+# --- End Pydantic Models ---
+
+def _build_llm_prompt(query: str, dataset_info: Optional[Dict] = None) -> str:
+ """Builds the prompt for the LLM to extract query information."""
+ dataset_context = "No dataset context provided."
+ if dataset_info:
+ columns = dataset_info.get('columns', [])
+ column_details = "\n".join([f"- {col} (Type: {dataset_info.get('column_types', {}).get(col, 'Unknown')})" for col in columns])
+ sample_rows = dataset_info.get('sample_rows', 'Not available')
+ # Ensure sample rows are formatted reasonably
+ if isinstance(sample_rows, list):
+ sample_rows_str = json.dumps(sample_rows[:3], indent=2) # Show first 3 sample rows
+ elif isinstance(sample_rows, str):
+ sample_rows_str = sample_rows
+ else:
+ sample_rows_str = 'Not available'
+
+ dataset_context = f"""
+Dataset Context:
+Columns:
+{column_details}
+Sample Rows (first few):
+{sample_rows_str}
+"""
+
+ prompt = f"""
+Analyze the following causal query **strictly in the context of the provided dataset information (if available)**. Identify the query type, key variables (mapping query terms to actual column names when possible), constraints, and any explicitly mentioned dataset path.
+
+User Query: "{query}"
+
+{dataset_context}
+
+# Add specific guidance for query types
+Guidance for Identifying Query Type:
+- EFFECT_ESTIMATION: Look for keywords like 'effect', 'impact', 'influence', 'cause', 'affect', 'consequence'. Also consider questions asking "how does X affect Y?" or comparing outcomes between groups based on an intervention.
+- COUNTERFACTUAL: Look for hypothetical scenarios, often using phrases like 'what if', 'if X had been', 'would Y have changed', 'imagine if', 'counterfactual'.
+- CORRELATION: Look for keywords like 'correlation', 'association', 'relationship', 'linked to', 'related to'. These queries ask about statistical relationships without necessarily implying causality.
+- DESCRIPTIVE: These queries ask for summaries, descriptions, trends, or statistics about the data without investigating causal links or relationships (e.g., "Show sales over time", "What is the average age?").
+- OTHER: Use this if the query does not fit any of the above categories.
+
+Choose the most appropriate type from: EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER.
+
+Variable Roles to Identify:
+- treatment: The intervention or variable whose effect is being studied.
+- outcome: The result or variable being measured.
+- covariates_mentioned: Variables explicitly mentioned to control for or adjust for.
+- grouping_vars: Variables identifying specific subgroups for analysis (e.g., 'for men', 'in the sales department').
+- instruments_mentioned: Variables explicitly mentioned as potential instruments.
+
+Constraints: Conditions applied to the analysis (e.g., filters on columns, specific time periods).
+
+Dataset Path Mentioned: Extract the file path or URL if explicitly stated in the query.
+
+**Output ONLY a valid JSON object** matching this exact schema (no explanations, notes, or surrounding text):
+```json
+{{
+ "query_type": "",
+ "variables": {{
+ "treatment": [""],
+ "outcome": [""],
+ "covariates_mentioned": [""],
+ "grouping_vars": [""],
+ "instruments_mentioned": [""]
+ }},
+ "constraints": ["", ""],
+ "dataset_path_mentioned": ""
+}}
+```
+If Dataset Context is provided, ensure variable names in the output JSON correspond to actual column names where possible. If no context is provided, or if a mentioned variable doesn't map directly, use the phrasing from the query.
+Respond with only the JSON object.
+"""
+ return prompt
+
+def _validate_llm_output(parsed_info: ParsedQueryInfo, dataset_info: Optional[Dict] = None) -> bool:
+ """Perform basic assertions on the parsed LLM output."""
+ # 1. Check required fields exist (Pydantic handles this on parsing)
+ # 2. Check query type is one of the allowed types (can add enum to Pydantic later)
+ allowed_types = {"EFFECT_ESTIMATION", "COUNTERFACTUAL", "CORRELATION", "DESCRIPTIVE", "OTHER"}
+ print(parsed_info)
+ assert parsed_info.query_type in allowed_types, f"Invalid query_type: {parsed_info.query_type}"
+
+ # 3. Check that if it's an effect query, treatment and outcome are likely present
+ if parsed_info.query_type == "EFFECT_ESTIMATION":
+ # Check that the lists are not empty
+ assert parsed_info.variables.treatment, "Treatment variable list is empty for effect query."
+ assert parsed_info.variables.outcome, "Outcome variable list is empty for effect query."
+
+ # 4. If dataset_info provided, check if extracted variables exist in columns
+ if dataset_info and (columns := dataset_info.get('columns')):
+ all_extracted_vars = set()
+ for var_list in parsed_info.variables.model_dump().values(): # Iterate through variable lists
+ if var_list: # Ensure var_list is not None or empty
+ all_extracted_vars.update(var_list)
+
+ unknown_vars = all_extracted_vars - set(columns)
+ # Allow for non-column variables if context is missing? Maybe relax this.
+ # For now, strict check if columns are provided.
+ if unknown_vars:
+ logger.warning(f"LLM mentioned variables potentially not in dataset columns: {unknown_vars}")
+ # Decide if this should be a hard failure (AssertionError) or just a warning.
+ # Let's make it a hard failure for now to enforce mapping.
+ raise AssertionError(f"LLM hallucinated variables not in dataset columns: {unknown_vars}")
+
+ logger.info("LLM output validation passed.")
+ return True
+
+def _extract_query_information_with_llm(query: str, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None, max_retries: int = 3) -> Optional[ParsedQueryInfo]:
+ """Extracts query type, variables, and constraints using LLM with retries and validation."""
+ if not llm:
+ logger.error("LLM client not provided. Cannot perform LLM extraction.")
+ return None
+
+ last_error = None
+ # Bind the Pydantic model to the LLM for structured output
+ structured_llm = llm.with_structured_output(ParsedQueryInfo)
+
+ # Initial prompt construction
+ system_prompt_content = _build_llm_prompt(query, dataset_info)
+ messages = [HumanMessage(content=system_prompt_content)] # Start with just the detailed prompt as Human message
+
+ for attempt in range(max_retries):
+ logger.info(f"LLM Extraction Attempt {attempt + 1}/{max_retries}...")
+ try:
+ # --- Invoke LangChain LLM with structured output (using passed llm) ---
+ parsed_info = structured_llm.invoke(messages)
+ # ---------------------------------------------------
+ print(messages)
+ print('---------------------------------------------------')
+ print(parsed_info)
+ # Perform custom assertions/validation
+ if _validate_llm_output(parsed_info, dataset_info):
+ return parsed_info # Success!
+
+ # Catch errors specific to structured output parsing or Pydantic validation
+ except (OutputParserException, ValidationError, AssertionError) as e:
+ logger.warning(f"Validation/Parsing Error (Attempt {attempt + 1}): {e}")
+ last_error = e
+ # Add feedback message for retry
+ messages.append(SystemMessage(content=f"Your previous response failed validation: {str(e)}. Please revise your response to be valid JSON conforming strictly to the schema and ensure variable names exist in the dataset context."))
+ continue # Go to next retry
+ except Exception as e: # Catch other potential LLM API errors
+ logger.error(f"Unexpected LLM Error (Attempt {attempt + 1}): {e}", exc_info=True)
+ last_error = e
+ break # Stop retrying on unexpected API errors
+
+ logger.error(f"LLM extraction failed after {max_retries} attempts.")
+ if last_error:
+ logger.error(f"Last error: {last_error}")
+ return None # Indicate failure
+
+# Add helper function to call LLM for path - needs llm argument
+def _call_llm_for_path(query: str, llm: Optional[BaseChatModel] = None, max_retries: int = 2) -> Optional[str]:
+ """Uses LLM as a fallback to extract just the dataset path."""
+ if not llm:
+ logger.warning("LLM client not provided. Cannot perform LLM path fallback.")
+ return None
+
+ logger.info("Attempting LLM fallback for dataset path extraction...")
+ path_extractor_llm = llm.with_structured_output(ExtractedPath)
+ prompt = f"Extract the dataset file path (e.g., /path/to/file.csv or https://...) mentioned in the following query. Respond ONLY with the JSON object.\nQuery: \"{query}\""
+ messages = [HumanMessage(content=prompt)]
+ last_error = None
+
+ for attempt in range(max_retries):
+ try:
+ parsed_info = path_extractor_llm.invoke(messages)
+ if parsed_info.dataset_path:
+ logger.info(f"LLM fallback extracted path: {parsed_info.dataset_path}")
+ return parsed_info.dataset_path
+ else:
+ logger.info("LLM fallback did not find a path.")
+ return None # LLM explicitly found no path
+ except (OutputParserException, ValidationError) as e:
+ logger.warning(f"LLM path extraction parsing/validation error (Attempt {attempt+1}): {e}")
+ last_error = e
+ messages.append(SystemMessage(content=f"Parsing Error: {e}. Please ensure you provide valid JSON with only the 'dataset_path' key."))
+ continue
+ except Exception as e:
+ logger.error(f"Unexpected LLM Error during path fallback (Attempt {attempt+1}): {e}", exc_info=True)
+ last_error = e
+ break # Don't retry on unexpected errors
+
+ logger.error(f"LLM path fallback failed after {max_retries} attempts. Last error: {last_error}")
+ return None
+
+# Renamed and modified function for regex path extraction + LLM fallback - needs llm argument
+def extract_dataset_path(query: str, llm: Optional[BaseChatModel] = None) -> Optional[str]:
+ """
+ Extract dataset path from the query using regex patterns, with LLM fallback.
+
+ Args:
+ query: The user's causal question text
+ llm: The shared LLM client instance for fallback.
+
+ Returns:
+ String with dataset path or None if not found
+ """
+ # --- Regex Part (existing logic) ---
+ # Check for common patterns indicating dataset paths
+ path_patterns = [
+ # More specific patterns first
+ r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?", # Handles subdirs in path
+ r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
+ # Simpler patterns
+ r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']", # Path in quotes
+ r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)", # More generic path-like structure ending in .csv
+ r"([^\"\'.,\s]+\.csv)" # Just a .csv file name (least specific)
+ ]
+
+ for pattern in path_patterns:
+ matches = re.search(pattern, query, re.IGNORECASE)
+ if matches:
+ path = matches.group(1).strip()
+
+ # Basic check if it looks like a path
+ if '/' in path or '\\' in path or os.path.exists(path):
+ # Check if this is a valid file path immediately
+ if os.path.exists(path):
+ logger.info(f"Regex found existing path: {path}")
+ return path
+
+ # Check if it's in common data directories
+ data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
+ for data_dir in data_dir_paths:
+ potential_path = os.path.join(data_dir, os.path.basename(path))
+ if os.path.exists(potential_path):
+ logger.info(f"Regex found path in {data_dir}: {potential_path}")
+ return potential_path
+
+ # If not found but looks like a path, return it anyway - let downstream handle non-existence
+ logger.info(f"Regex found potential path (existence not verified): {path}")
+ return path
+ # Else: it might just be a word ending in .csv, ignore unless it exists
+ elif os.path.exists(path):
+ logger.info(f"Regex found existing path (simple pattern): {path}")
+ return path
+
+ # --- LLM Fallback ---
+ logger.info("Regex did not find dataset path. Trying LLM fallback...")
+ llm_fallback_path = _call_llm_for_path(query, llm=llm)
+ if llm_fallback_path:
+ # Optional: Add existence check here too? Or let downstream handle it.
+ # For now, return what LLM found.
+ return llm_fallback_path
+
+ logger.info("No dataset path found via regex or LLM fallback.")
+ return None
+
+def parse_input(query: str, dataset_path_arg: Optional[str] = None, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
+ """
+ Parse the user's causal query using LLM and regex.
+
+ Args:
+ query: The user's causal question text.
+ dataset_path_arg: Path to dataset if provided directly as an argument.
+ dataset_info: Dictionary with dataset context (columns, types, etc.).
+ llm: The shared LLM client instance.
+
+ Returns:
+ Dict containing parsed query information.
+ """
+ result = {
+ "original_query": query,
+ "dataset_path": dataset_path_arg, # Start with argument path
+ "query_type": "OTHER", # Default values
+ "extracted_variables": {},
+ "constraints": []
+ }
+
+ # --- 1. Use LLM for core NLP tasks ---
+ parsed_llm_info = _extract_query_information_with_llm(query, dataset_info, llm=llm)
+
+ if parsed_llm_info:
+ result["query_type"] = parsed_llm_info.query_type
+ result["extracted_variables"] = {k: v if v is not None else [] for k, v in parsed_llm_info.variables.model_dump().items()}
+ result["constraints"] = parsed_llm_info.constraints if parsed_llm_info.constraints is not None else []
+ llm_mentioned_path = parsed_llm_info.dataset_path_mentioned
+ else:
+ logger.warning("LLM-based query information extraction failed.")
+ llm_mentioned_path = None
+ # Consider falling back to old regex methods here if critical
+ # logger.info("Falling back to regex-based parsing (if implemented).")
+
+ # --- 2. Determine Dataset Path (Hybrid Approach) ---
+ final_dataset_path = dataset_path_arg # Priority 1: Explicit argument
+
+ # Pass llm instance to the path extractor for its fallback mechanism
+ path_extractor = partial(extract_dataset_path, llm=llm)
+
+ if not final_dataset_path:
+ # Priority 2: Path mentioned in query (extracted by main LLM call)
+ if llm_mentioned_path and os.path.exists(llm_mentioned_path):
+ logger.info(f"Using dataset path mentioned by LLM: {llm_mentioned_path}")
+ final_dataset_path = llm_mentioned_path
+ elif llm_mentioned_path: # Check data dirs if path not absolute
+ data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
+ base_name = os.path.basename(llm_mentioned_path)
+ for data_dir in data_dir_paths:
+ potential_path = os.path.join(data_dir, base_name)
+ if os.path.exists(potential_path):
+ logger.info(f"Using dataset path mentioned by LLM (found in {data_dir}): {potential_path}")
+ final_dataset_path = potential_path
+ break
+ if not final_dataset_path:
+ logger.warning(f"LLM mentioned path '{llm_mentioned_path}' but it was not found.")
+
+ if not final_dataset_path:
+ # Priority 3: Path extracted by dedicated Regex + LLM fallback function
+ logger.info("Attempting dedicated dataset path extraction (Regex + LLM Fallback)...")
+ extracted_path = path_extractor(query) # Call the partial function with llm bound
+ if extracted_path:
+ final_dataset_path = extracted_path
+
+ result["dataset_path"] = final_dataset_path
+
+ # Check if a path was found ultimately
+ if not result["dataset_path"]:
+ logger.warning("Could not determine dataset path from query or arguments.")
+ else:
+ logger.info(f"Final dataset path determined: {result['dataset_path']}")
+
+ return result
+
+# --- Old Regex-based functions (Commented out or removed) ---
+# def determine_query_type(query: str) -> str:
+# ... (implementation removed)
+
+# def extract_variables(query: str) -> Dict[str, Any]:
+# ... (implementation removed)
+
+# def detect_constraints(query: str) -> List[str]:
+# ... (implementation removed)
+# --- End Old Functions ---
+
+# Renamed function for regex path extraction
+def extract_dataset_path_regex(query: str) -> Optional[str]:
+ """
+ Extract dataset path from the query using regex patterns.
+
+ Args:
+ query: The user's causal question text
+
+ Returns:
+ String with dataset path or None if not found
+ """
+ # Check for common patterns indicating dataset paths
+ path_patterns = [
+ # More specific patterns first
+ r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?", # Handles subdirs in path
+ r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
+ # Simpler patterns
+ r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']", # Path in quotes
+ r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)", # More generic path-like structure ending in .csv
+ r"([^\"\'.,\s]+\.csv)" # Just a .csv file name (least specific)
+ ]
+
+ for pattern in path_patterns:
+ matches = re.search(pattern, query, re.IGNORECASE)
+ if matches:
+ path = matches.group(1).strip()
+
+ # Basic check if it looks like a path
+ if '/' in path or '\\' in path or os.path.exists(path):
+ # Check if this is a valid file path immediately
+ if os.path.exists(path):
+ logger.info(f"Regex found existing path: {path}")
+ return path
+
+ # Check if it's in common data directories
+ data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
+ # Also check relative to current dir (often useful)
+ # base_name = os.path.basename(path)
+ for data_dir in data_dir_paths:
+ potential_path = os.path.join(data_dir, os.path.basename(path))
+ if os.path.exists(potential_path):
+ logger.info(f"Regex found path in {data_dir}: {potential_path}")
+ return potential_path
+
+ # If not found but looks like a path, return it anyway - let downstream handle non-existence
+ logger.info(f"Regex found potential path (existence not verified): {path}")
+ return path
+ # Else: it might just be a word ending in .csv, ignore unless it exists
+ elif os.path.exists(path):
+ logger.info(f"Regex found existing path (simple pattern): {path}")
+ return path
+
+ # TODO: Optional: Add LLM fallback call here if regex fails
+ # if no path found:
+ # llm_fallback_path = call_llm_for_path(query)
+ # return llm_fallback_path
+
+ return None
\ No newline at end of file
diff --git a/auto_causal/components/method_validator.py b/auto_causal/components/method_validator.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b6512cc4a513d4f0e593f846e706388a6b5820
--- /dev/null
+++ b/auto_causal/components/method_validator.py
@@ -0,0 +1,327 @@
+"""
+Method validator component for causal inference methods.
+
+This module validates the selected causal inference method against
+dataset characteristics and available variables.
+"""
+
+from typing import Dict, List, Any, Optional
+
+
+def validate_method(method_info: Dict[str, Any], dataset_analysis: Dict[str, Any],
+ variables: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Validate the selected causal method against dataset characteristics.
+
+ Args:
+ method_info: Information about the selected method from decision_tree
+ dataset_analysis: Dataset analysis results from dataset_analyzer
+ variables: Identified variables from query_interpreter
+
+ Returns:
+ Dict with validation results:
+ - valid: Boolean indicating if method is valid
+ - concerns: List of concerns/issues with the selected method
+ - alternative_suggestions: Alternative methods if the selected method is problematic
+ - recommended_method: Updated method recommendation if issues are found
+ """
+ method = method_info.get("selected_method")
+ assumptions = method_info.get("method_assumptions", [])
+
+ # Get required variables
+ treatment = variables.get("treatment_variable")
+ outcome = variables.get("outcome_variable")
+ covariates = variables.get("covariates", [])
+ time_variable = variables.get("time_variable")
+ group_variable = variables.get("group_variable")
+ instrument_variable = variables.get("instrument_variable")
+ running_variable = variables.get("running_variable")
+ cutoff_value = variables.get("cutoff_value")
+
+ # Initialize validation result
+ validation_result = {
+ "valid": True,
+ "concerns": [],
+ "alternative_suggestions": [],
+ "recommended_method": method,
+ }
+
+ # Common validations for all methods
+ if treatment is None:
+ validation_result["valid"] = False
+ validation_result["concerns"].append("Treatment variable is not identified")
+
+ if outcome is None:
+ validation_result["valid"] = False
+ validation_result["concerns"].append("Outcome variable is not identified")
+
+ # Method-specific validations
+ if method == "propensity_score_matching":
+ validate_propensity_score_matching(validation_result, dataset_analysis, variables)
+
+ elif method == "regression_adjustment":
+ validate_regression_adjustment(validation_result, dataset_analysis, variables)
+
+ elif method == "instrumental_variable":
+ validate_instrumental_variable(validation_result, dataset_analysis, variables)
+
+ elif method == "difference_in_differences":
+ validate_difference_in_differences(validation_result, dataset_analysis, variables)
+
+ elif method == "regression_discontinuity_design":
+ validate_regression_discontinuity(validation_result, dataset_analysis, variables)
+
+ elif method == "backdoor_adjustment":
+ validate_backdoor_adjustment(validation_result, dataset_analysis, variables)
+
+ # If there are serious concerns, recommend alternatives
+ if not validation_result["valid"]:
+ validation_result["recommended_method"] = recommend_alternative(
+ method, validation_result["concerns"], method_info.get("alternatives", [])
+ )
+
+ # Make sure assumptions are listed in the validation result
+ validation_result["assumptions"] = assumptions
+ print("--------------------------")
+ print("Validation result:", validation_result)
+ print("--------------------------")
+ return validation_result
+
+
+def validate_propensity_score_matching(validation_result: Dict[str, Any],
+ dataset_analysis: Dict[str, Any],
+ variables: Dict[str, Any]) -> None:
+ """
+ Validate propensity score matching method requirements.
+
+ Args:
+ validation_result: Current validation result to update
+ dataset_analysis: Dataset analysis results
+ variables: Identified variables
+ """
+ treatment = variables.get("treatment_variable")
+ covariates = variables.get("covariates", [])
+
+ # Check if treatment is binary using column_categories
+ is_binary = dataset_analysis.get("column_categories", {}).get(treatment) == "binary"
+
+ # Fallback to check if the column has only two unique values (0 and 1)
+ if not is_binary:
+ column_types = dataset_analysis.get("column_types", {})
+ if column_types.get(treatment) == "int64" or column_types.get(treatment) == "int32":
+ # Assuming int type with only 0s and 1s is binary
+ is_binary = True
+
+ if not is_binary:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "Treatment variable is not binary, which is required for propensity score matching"
+ )
+
+ # Check if there are sufficient covariates
+ if len(covariates) < 2:
+ validation_result["concerns"].append(
+ "Few covariates identified, which may limit the effectiveness of propensity score matching"
+ )
+
+ # Check for sufficient overlap
+ variable_relationships = dataset_analysis.get("variable_relationships", {})
+ treatment_imbalance = variable_relationships.get("treatment_imbalance", 0.5)
+
+ if treatment_imbalance < 0.1 or treatment_imbalance > 0.9:
+ validation_result["concerns"].append(
+ "Treatment groups are highly imbalanced, which may lead to poor matching quality"
+ )
+ validation_result["alternative_suggestions"].append("regression_adjustment")
+
+
+def validate_regression_adjustment(validation_result: Dict[str, Any],
+ dataset_analysis: Dict[str, Any],
+ variables: Dict[str, Any]) -> None:
+ """
+ Validate regression adjustment method requirements.
+
+ Args:
+ validation_result: Current validation result to update
+ dataset_analysis: Dataset analysis results
+ variables: Identified variables
+ """
+ outcome = variables.get("outcome_variable")
+
+ # Check outcome type for appropriate regression model
+ outcome_data = dataset_analysis.get("variable_types", {}).get(outcome, {})
+ outcome_type = outcome_data.get("type")
+
+ if outcome_type == "categorical" and outcome_data.get("n_categories", 0) > 2:
+ validation_result["concerns"].append(
+ "Outcome is categorical with multiple categories, which may require multinomial regression"
+ )
+
+ # Check for potential nonlinear relationships
+ nonlinear_relationships = dataset_analysis.get("nonlinear_relationships", False)
+
+ if nonlinear_relationships:
+ validation_result["concerns"].append(
+ "Potential nonlinear relationships detected, which may require more flexible models"
+ )
+
+
+def validate_instrumental_variable(validation_result: Dict[str, Any],
+ dataset_analysis: Dict[str, Any],
+ variables: Dict[str, Any]) -> None:
+ """
+ Validate instrumental variable method requirements.
+
+ Args:
+ validation_result: Current validation result to update
+ dataset_analysis: Dataset analysis results
+ variables: Identified variables
+ """
+ instrument_variable = variables.get("instrument_variable")
+ treatment = variables.get("treatment_variable")
+
+ if instrument_variable is None:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "No instrumental variable identified, which is required for this method"
+ )
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
+ return
+
+ # Check for instrument strength (correlation with treatment)
+ variable_relationships = dataset_analysis.get("variable_relationships", {})
+ instrument_correlation = next(
+ (corr.get("correlation", 0) for corr in variable_relationships.get("correlations", [])
+ if corr.get("var1") == instrument_variable and corr.get("var2") == treatment
+ or corr.get("var1") == treatment and corr.get("var2") == instrument_variable),
+ 0
+ )
+
+ if abs(instrument_correlation) < 0.2:
+ validation_result["concerns"].append(
+ "Instrument appears weak (low correlation with treatment), which may lead to bias"
+ )
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
+
+
+def validate_difference_in_differences(validation_result: Dict[str, Any],
+ dataset_analysis: Dict[str, Any],
+ variables: Dict[str, Any]) -> None:
+ """
+ Validate difference-in-differences method requirements.
+
+ Args:
+ validation_result: Current validation result to update
+ dataset_analysis: Dataset analysis results
+ variables: Identified variables
+ """
+ time_variable = variables.get("time_variable")
+ group_variable = variables.get("group_variable")
+
+ if time_variable is None:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "No time variable identified, which is required for difference-in-differences"
+ )
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
+
+ if group_variable is None:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "No group variable identified, which is required for difference-in-differences"
+ )
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
+
+ # Check for parallel trends
+ temporal_structure = dataset_analysis.get("temporal_structure", {})
+ parallel_trends = temporal_structure.get("parallel_trends", False)
+
+ if not parallel_trends:
+ validation_result["concerns"].append(
+ "No evidence of parallel trends, which is a key assumption for difference-in-differences"
+ )
+ validation_result["alternative_suggestions"].append("synthetic_control")
+
+
+def validate_regression_discontinuity(validation_result: Dict[str, Any],
+ dataset_analysis: Dict[str, Any],
+ variables: Dict[str, Any]) -> None:
+ """
+ Validate regression discontinuity method requirements.
+
+ Args:
+ validation_result: Current validation result to update
+ dataset_analysis: Dataset analysis results
+ variables: Identified variables
+ """
+ running_variable = variables.get("running_variable")
+ cutoff_value = variables.get("cutoff_value")
+
+ if running_variable is None:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "No running variable identified, which is required for regression discontinuity"
+ )
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
+
+ if cutoff_value is None:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "No cutoff value identified, which is required for regression discontinuity"
+ )
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
+
+ # Check for discontinuity at threshold
+ discontinuities = dataset_analysis.get("discontinuities", {})
+ has_discontinuity = discontinuities.get("has_discontinuities", False)
+
+ if not has_discontinuity:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "No clear discontinuity detected at the threshold, which is necessary for this method"
+ )
+ validation_result["alternative_suggestions"].append("regression_adjustment")
+
+def validate_backdoor_adjustment(validation_result: Dict[str, Any],
+ dataset_analysis: Dict[str, Any],
+ variables: Dict[str, Any]) -> None:
+ """
+ Validate backdoor adjustment method requirements.
+
+ Args:
+ validation_result: Current validation result to update
+ dataset_analysis: Dataset analysis results
+ variables: Identified variables
+ """
+ covariates = variables.get("covariates", [])
+
+ if len(covariates) == 0:
+ validation_result["valid"] = False
+ validation_result["concerns"].append(
+ "No covariates identified for backdoor adjustment"
+ )
+ validation_result["alternative_suggestions"].append("regression_adjustment")
+
+
+def recommend_alternative(method: str, concerns: List[str], alternatives: List[str]) -> str:
+ """
+ Recommend an alternative method if the current one has issues.
+
+ Args:
+ method: Current method
+ concerns: List of concerns with the current method
+ alternatives: List of alternative methods suggested by the decision tree
+
+ Returns:
+ String with the recommended method
+ """
+ # If there are alternatives, recommend the first one
+ if alternatives:
+ return alternatives[0]
+
+ # If no alternatives, use regression adjustment as a fallback
+ if method != "regression_adjustment":
+ return "regression_adjustment"
+
+ # If regression adjustment is also problematic, use propensity score matching
+ return "propensity_score_matching"
\ No newline at end of file
diff --git a/auto_causal/components/output_formatter.py b/auto_causal/components/output_formatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d90267fbf3065caef99c42e3348abdf8997156f1
--- /dev/null
+++ b/auto_causal/components/output_formatter.py
@@ -0,0 +1,138 @@
+"""
+Output formatter component for causal inference results.
+
+This module formats the results of causal analysis into a clear,
+structured output for presentation to the user.
+"""
+
+from typing import Dict, List, Any, Optional
+import json # Add this import at the top of the file
+
+# Import the new model
+from auto_causal.models import FormattedOutput
+
+# Add this module-level variable, typically near imports or at the top
+CURRENT_OUTPUT_LOG_FILE = None
+
+# Revert signature and logic to handle results and structured explanation
+def format_output(
+ query: str,
+ method: str,
+ results: Dict[str, Any],
+ explanation: Dict[str, Any],
+ dataset_analysis: Optional[Dict[str, Any]] = None,
+ dataset_description: Optional[str] = None
+) -> FormattedOutput:
+ """
+ Format final results including numerical estimates and explanations.
+
+ Args:
+ query: Original user query
+ method: Causal inference method used (string name)
+ results: Numerical results from method_executor_tool
+ explanation: Structured explanation object from explainer_tool
+ dataset_analysis: Optional dictionary of dataset analysis results
+ dataset_description: Optional string description of the dataset
+
+ Returns:
+ Dict with formatted output fields ready for presentation.
+ """
+ # Extract numerical results
+ effect_estimate = results.get("effect_estimate")
+ confidence_interval = results.get("confidence_interval")
+ p_value = results.get("p_value")
+ effect_se = results.get("standard_error") # Get SE if available
+
+ # Format method name for readability
+ method_name_formatted = _format_method_name(method)
+
+ # Extract explanation components (assuming explainer returns structured dict again)
+ # If explainer returns single string, adjust this
+ method_explanation_text = explanation.get("method_explanation", "")
+ interpretation_guide = explanation.get("interpretation_guide", "")
+ limitations = explanation.get("limitations", [])
+ assumptions_discussion = explanation.get("assumptions", "") # Assuming key is 'assumptions'
+ practical_implications = explanation.get("practical_implications", "")
+ # Add back final_explanation_text if explainer provides it
+ # final_explanation_text = explanation.get("final_explanation_text")
+
+ # Create summary using numerical results
+ ci_text = ""
+ if confidence_interval and confidence_interval[0] is not None and confidence_interval[1] is not None:
+ ci_text = f" (95% CI: [{confidence_interval[0]:.4f}, {confidence_interval[1]:.4f}])"
+
+ p_value_text = f", p={p_value:.4f}" if p_value is not None else ""
+ effect_text = f"{effect_estimate:.4f}" if effect_estimate is not None else "N/A"
+
+ summary = (
+ f"Based on {method_name_formatted}, the estimated causal effect is {effect_text}"
+ f"{ci_text}{p_value_text}. {_create_effect_interpretation(effect_estimate, p_value)}"
+ f" See details below regarding assumptions and limitations."
+ )
+
+ # Assemble formatted output dictionary
+ results_dict = {
+ "query": query,
+ "method_used": method_name_formatted,
+ "causal_effect": effect_estimate,
+ "standard_error": effect_se,
+ "confidence_interval": confidence_interval,
+ "p_value": p_value,
+ "summary": summary,
+ "method_explanation": method_explanation_text,
+ "interpretation_guide": interpretation_guide,
+ "limitations": limitations,
+ "assumptions": assumptions_discussion,
+ "practical_implications": practical_implications,
+ # "full_explanation_text": final_explanation_text # Optionally include combined text
+ }
+ final_results_dict = {key : results_dict[key] for key in {"query", "method_used", "causal_effect", "standard_error", "confidence_interval"}}
+ # print(final_results_dict)
+
+ # Validate and instantiate the Pydantic model
+ try:
+ formatted_output_model = FormattedOutput(**results_dict)
+ except Exception as e: # Catch validation errors specifically if needed
+ # Handle validation error - perhaps log and return a default or raise
+ print(f"Error creating FormattedOutput model: {e}") # Or use logger
+ # Decide on error handling: raise, return None, return default?
+ # For now, re-raising might be simplest if the structure is expected
+ raise ValueError(f"Failed to create FormattedOutput from results: {e}")
+
+ return formatted_output_model # Return the Pydantic model instance
+
+
+def _format_method_name(method: str) -> str:
+ """Format method name for readability."""
+ method_names = {
+ "propensity_score_matching": "Propensity Score Matching",
+ "regression_adjustment": "Regression Adjustment",
+ "instrumental_variable": "Instrumental Variable Analysis",
+ "difference_in_differences": "Difference-in-Differences",
+ "regression_discontinuity": "Regression Discontinuity Design",
+ "backdoor_adjustment": "Backdoor Adjustment",
+ "propensity_score_weighting": "Propensity Score Weighting"
+ }
+ return method_names.get(method, method.replace("_", " ").title())
+
+# Reinstate helper function for interpretation
+def _create_effect_interpretation(effect: Optional[float], p_value: Optional[float] = None) -> str:
+ """Create a basic interpretation of the effect."""
+ if effect is None:
+ return "Effect estimate not available."
+
+ significance = ""
+ if p_value is not None:
+ significance = "statistically significant" if p_value < 0.05 else "not statistically significant"
+
+ magnitude = ""
+ if abs(effect) < 0.01:
+ magnitude = "no practical effect"
+ elif abs(effect) < 0.1:
+ magnitude = "a small effect"
+ elif abs(effect) < 0.5:
+ magnitude = "a moderate effect"
+ else:
+ magnitude = "a substantial effect"
+
+ return f"This suggests {magnitude}{f' and is {significance}' if significance else ''}."
\ No newline at end of file
diff --git a/auto_causal/components/query_interpreter.py b/auto_causal/components/query_interpreter.py
new file mode 100644
index 0000000000000000000000000000000000000000..33b8f53ddd63a5f07b18df319b7e2095a9699b89
--- /dev/null
+++ b/auto_causal/components/query_interpreter.py
@@ -0,0 +1,580 @@
+"""
+Query interpreter component for causal inference.
+
+This module provides functionality to match query concepts to actual dataset variables,
+identifying treatment, outcome, and covariate variables for causal inference analysis.
+"""
+
+import re
+from typing import Dict, List, Any, Optional, Union, Tuple
+import pandas as pd
+import logging
+import numpy as np
+from auto_causal.config import get_llm_client
+# Import LLM and message types
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import HumanMessage
+from langchain_core.exceptions import OutputParserException
+# Import base Pydantic models needed directly
+from pydantic import BaseModel, ValidationError
+from dowhy import CausalModel
+import json
+
+# Import shared Pydantic models from the central location
+from auto_causal.models import (
+ LLMSelectedVariable,
+ LLMSelectedCovariates,
+ LLMIVars,
+ LLMRDDVars,
+ LLMRCTCheck,
+ LLMTreatmentReferenceLevel,
+ LLMInteractionSuggestion,
+ LLMEstimand,
+ # LLMDIDCheck,
+ # LLMDiDTemporalVars,
+ # LLMDiDGroupVars,
+ # LLMRDDCheck,
+ # LLMRDDVarsExtended
+)
+
+# Import the new prompt templates
+from auto_causal.prompts.method_identification_prompts import (
+ IV_IDENTIFICATION_PROMPT_TEMPLATE,
+ RDD_IDENTIFICATION_PROMPT_TEMPLATE,
+ RCT_IDENTIFICATION_PROMPT_TEMPLATE,
+ TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE,
+ INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE,
+ TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE,
+ OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE,
+ COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE,
+ ESTIMAND_PROMPT_TEMPLATE,
+ CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE,
+ DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE)
+
+
+# Assume central models are defined elsewhere or keep local definitions for now
+# from ..models import ...
+
+# --- Pydantic models for LLM structured output ---
+# REMOVED - Now defined in causalscientist/auto_causal/models.py
+# class LLMSelectedVariable(BaseModel): ...
+# class LLMSelectedCovariates(BaseModel): ...
+# class LLMIVars(BaseModel): ...
+# class LLMRDDVars(BaseModel): ...
+# class LLMRCTCheck(BaseModel): ...
+
+
+logger = logging.getLogger(__name__)
+
+def infer_treatment_variable_type(treatment_variable: str, column_categories: Dict[str, str],
+ dataset_analysis: Dict[str, Any]) -> str:
+ """
+ Determine treatment variable type from column category and unique value count
+ Args:
+ treatment_variable: name of the treatment variable
+ column_categories: mapping of column names to their categories
+ dataset_analysis: exploratory analysis results
+
+ Returns:
+ str: type of the treatment variable (e.g., "binary", "continuous", etc
+ """
+
+ treatment_variable_type = "unknown"
+ if treatment_variable and treatment_variable in column_categories:
+ category = column_categories[treatment_variable]
+ logger.info(f"Category for treatment '{treatment_variable}' is '{category}'.")
+
+ if category == "continuous_numeric":
+ treatment_variable_type = "continuous"
+
+ elif category == "discrete_numeric":
+ num_unique = dataset_analysis.get("column_nunique_counts", {}).get(treatment_variable, -1)
+ if num_unique > 10:
+ logger.info(f"'{treatment_variable}' has {num_unique} unique values, treating as continuous.")
+ treatment_variable_type = "continuous"
+ elif num_unique == 2:
+ logger.info(f"'{treatment_variable}' has 2 unique values, treating as binary.")
+ treatment_variable_type = "binary"
+ elif num_unique > 0:
+ logger.info(f"'{treatment_variable}' has {num_unique} unique values, treating as discrete_multi_value.")
+ treatment_variable_type = "discrete_multi_value"
+ else:
+ logger.info(f"'{treatment_variable}' unique value count unknown or too few.")
+ treatment_variable_type = "discrete_numeric_unknown_cardinality"
+
+ elif category in ["binary", "binary_categorical"]:
+ treatment_variable_type = "binary"
+
+ elif category in ["categorical", "categorical_numeric"]:
+ num_unique = dataset_analysis.get("column_nunique_counts", {}).get(treatment_variable, -1)
+ if num_unique == 2:
+ treatment_variable_type = "binary"
+ elif num_unique > 0:
+ treatment_variable_type = "categorical_multi_value"
+ else:
+ treatment_variable_type = "categorical_unknown_cardinality"
+
+ else:
+ logger.warning(f"Unmapped category '{category}' for '{treatment_variable}', setting as 'other'.")
+ treatment_variable_type = "other"
+
+ elif treatment_variable:
+ logger.warning(f"'{treatment_variable}' not found in column_categories.")
+ else:
+ logger.info("No treatment variable identified.")
+
+ logger.info(f"Final Determined Treatment Variable Type: {treatment_variable_type}")
+ return treatment_variable_type
+
+def determine_treatment_reference_level(is_rct: Optional[bool], llm: Optional[BaseChatModel], treatment_variable: Optional[str],
+ query_text: str, dataset_description: Optional[str], file_path: Optional[str],
+ columns: List[str]) -> Optional[str]:
+ """
+ Determines the treatment reference level
+ """
+
+ # If LLM didn't explicitly say RCT, default to False or keep None?
+ # Let's default to False if LLM didn't provide a boolean value.
+ if is_rct is None: is_rct = False
+ treatment_reference_level = None
+
+ if llm and treatment_variable and treatment_variable in columns:
+ treatment_values_sample = []
+ if file_path:
+ try:
+ df = pd.read_csv(file_path)
+ if treatment_variable in df.columns:
+ unique_vals = df[treatment_variable].unique()
+ treatment_values_sample = [item.item() if hasattr(item, 'item') else item for item in unique_vals][:10]
+ if treatment_values_sample:
+ logger.info(f"Successfully read treatment values sample from dataset at '{file_path}' for variable '{treatment_variable}'.")
+ else:
+ logger.info(f"'{treatment_variable}' in '{file_path}' has no unique values or is empty.")
+ else:
+ logger.warning(f"'{treatment_variable}' not found in dataset columns at '{file_path}'.")
+ except FileNotFoundError:
+ logger.warning(f"File not found at: {file_path}")
+ except pd.errors.EmptyDataError:
+ logger.warning(f"Empty file at: {file_path}")
+ except Exception as e:
+ logger.warning(f"Error reading dataset at '{file_path}' for '{treatment_variable}': {e}")
+
+ if not treatment_values_sample:
+ logger.warning(f"No unique values found for treatment '{treatment_variable}'. LLM prompt will receive empty list.")
+ else:
+ logger.info(f"Final treatment values sample: {treatment_values_sample}")
+
+ try:
+ prompt = TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description or 'N/A', treatment_variable=treatment_variable, treatment_variable_values=treatment_values_sample)
+ ref_result = _call_llm_for_var(llm, prompt, LLMTreatmentReferenceLevel)
+ if ref_result and ref_result.reference_level:
+ if treatment_values_sample and ref_result.reference_level not in treatment_values_sample:
+ logger.warning(f"LLM reference level '{ref_result.reference_level}' not in sampled values for '{treatment_variable}'.")
+ treatment_reference_level = ref_result.reference_level
+ logger.info(f"LLM identified reference level: {treatment_reference_level} (Reason: {ref_result.reasoning})")
+ elif ref_result:
+ logger.info(f"LLM returned no reference level (Reason: {ref_result.reasoning})")
+ except Exception as e:
+ logger.error(f"LLM error for treatment reference level: {e}")
+
+ return treatment_reference_level
+
+def identify_interaction_term(llm: Optional[BaseChatModel], treatment_variable: Optional[str], covariates: List[str],
+ column_categories: Dict[str, str], query_text: str,
+ dataset_description: Optional[str]) -> Tuple[bool, Optional[str]]:
+ """
+ Identifies the interaction term based on the query and the dataset information
+ """
+
+ interaction_term_suggested, interaction_variable_candidate = False, None
+
+ if llm and treatment_variable and covariates:
+ try:
+ covariates_list_str = "\n".join([f"- {cov}: {column_categories.get(cov, 'Unknown')}" for cov in covariates]) or "No covariates identified or available."
+ prompt = INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description or 'N/A', treatment_variable=treatment_variable, covariates_list_with_types=covariates_list_str)
+ result = _call_llm_for_var(llm, prompt, LLMInteractionSuggestion)
+ if result:
+ interaction_term_suggested = result.interaction_needed if result.interaction_needed is not None else False
+ if interaction_term_suggested and result.interaction_variable:
+ if result.interaction_variable in covariates:
+ interaction_variable_candidate = result.interaction_variable
+ logger.info(f"LLM suggested interaction: needed={interaction_term_suggested}, variable='{interaction_variable_candidate}' (Reason: {result.reasoning})")
+ else:
+ logger.warning(f"LLM suggested variable '{result.interaction_variable}' not in covariates {covariates}. Ignoring.")
+ interaction_term_suggested = False
+ elif interaction_term_suggested:
+ logger.info(f"LLM suggested interaction is needed but no variable provided (Reason: {result.reasoning})")
+ else:
+ logger.info(f"LLM suggested no interaction is needed (Reason: {result.reasoning})")
+ else:
+ logger.warning("LLM returned no result for interaction term suggestion.")
+ except Exception as e:
+ logger.error(f"LLM error during interaction term check: {e}")
+
+ return interaction_term_suggested, interaction_variable_candidate
+
+
+def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any],
+ dataset_description: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Interpret query using hybrid heuristic/LLM approach to identify variables.
+
+ Args:
+ query_info: Information extracted from the user's query (text, hints).
+ dataset_analysis: Information about the dataset structure (columns, types, etc.).
+ dataset_description: Optional textual description of the dataset.
+ llm: Optional language model instance.
+
+ Returns:
+ Dict containing identified variables (treatment, outcome, covariates, etc., and is_rct).
+ """
+
+ logger.info("Interpreting query with hybrid approach...")
+ llm = get_llm_client()
+
+ query_text = query_info.get("query_text", "")
+ columns = dataset_analysis.get("columns", [])
+ column_categories = dataset_analysis.get("column_categories", {})
+ file_path = dataset_analysis["dataset_info"]["file_path"]
+
+
+ # --- Identify Treatment ---
+ treatment_hints = query_info.get("potential_treatments", [])
+ dataset_treatments = dataset_analysis.get("potential_treatments", [])
+ treatment_variable = _identify_variable_hybrid(role="treatment", query_hints=treatment_hints,
+ dataset_suggestions=dataset_treatments, columns=columns,
+ column_categories=column_categories,
+ prioritize_types=["binary", "binary_categorical", "discrete_numeric","continuous_numeric"], # Prioritize binary/discrete
+ query_text=query_text, dataset_description=dataset_description,llm=llm)
+ logger.info(f"Identified Treatment: {treatment_variable}")
+ treatment_variable_type = infer_treatment_variable_type(treatment_variable, column_categories, dataset_analysis)
+
+
+ # --- Identify Outcome ---
+ outcome_hints = query_info.get("outcome_hints", [])
+ dataset_outcomes = dataset_analysis.get("potential_outcomes", [])
+ outcome_variable = _identify_variable_hybrid(role="outcome", query_hints=outcome_hints, dataset_suggestions=dataset_outcomes,
+ columns=columns, column_categories=column_categories,
+ prioritize_types=["continuous_numeric", "discrete_numeric"], # Prioritize numeric
+ exclude_vars=[treatment_variable], # Exclude treatment
+ query_text=query_text, dataset_description=dataset_description, llm=llm)
+ logger.info(f"Identified Outcome: {outcome_variable}")
+
+ # --- Identify Covariates ---
+ covariate_hints = query_info.get("covariates_hints", [])
+ covariates = _identify_covariates_hybrid("covars", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
+ columns=columns, column_categories=column_categories, query_hints=covariate_hints,
+ query_text=query_text, dataset_description=dataset_description, llm=llm)
+ logger.info(f"Identified Covariates: {covariates}")
+
+ # --- Identify Confounders ---
+ confounder_hints = query_info.get("covariates_hints", [])
+ confounders = _identify_covariates_hybrid("confounders", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
+ columns=columns, column_categories=column_categories, query_hints=confounder_hints,
+ query_text=query_text, dataset_description=dataset_description, llm=llm)
+ logger.info(f"Identified Confounders: {confounders}")
+
+ # --- Identify Time/Group (from dataset analysis) ---
+ time_variable = None
+ group_variable = None
+ has_temporal = dataset_analysis.get("temporal_structure", {}).get("has_temporal_structure", False)
+ temporal_structure = dataset_analysis.get("temporal_structure", {})
+ if temporal_structure.get("has_temporal_structure", False):
+ time_variable = temporal_structure.get("time_column") or temporal_structure.get("temporal_columns", [None])[0]
+ if temporal_structure.get("is_panel_data", False):
+ group_variable = temporal_structure.get("id_column")
+ logger.info(f"Identified Time Var: {time_variable}, Group Var: {group_variable}, temporal structure: {temporal_structure}")
+
+ # --- Identify IV/RDD/RCT using LLM ---
+ instrument_variable = None
+ running_variable = None
+ cutoff_value = None
+ is_rct = None
+ smd_score = None
+
+ if llm:
+ try:
+ # Check for RCT
+ prompt_rct = _create_identify_prompt("whether data is from RCT", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
+ rct_result = _call_llm_for_var(llm, prompt_rct, LLMRCTCheck)
+ is_rct = rct_result.is_rct if rct_result else None
+ logger.info(f"LLM identified RCT: {is_rct}")
+
+ # Check for IV
+ prompt_iv = _create_identify_prompt("instrumental variable", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
+ iv_result = _call_llm_for_var(llm, prompt_iv, LLMIVars)
+ instrument_variable = iv_result.instrument_variable if iv_result else None
+ if instrument_variable not in columns:
+ instrument_variable = None
+ logger.info(f"LLM identified IV: {instrument_variable}")
+
+ # Check for RDD
+ prompt_rdd = _create_identify_prompt("regression discontinuity (running variable and cutoff)", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
+ rdd_result = _call_llm_for_var(llm, prompt_rdd, LLMRDDVars)
+ if rdd_result:
+ running_variable = rdd_result.running_variable
+ cutoff_value = rdd_result.cutoff_value
+ if running_variable not in columns or cutoff_value is None:
+ running_variable = None
+ cutoff_value = None
+ logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}")
+
+ ## For graph based methods
+ exclude_cols = [treatment_variable, outcome_variable]
+ potential_covariates = [col for col in columns if col not in exclude_cols and col is not None]
+ usable_covariates = [col for col in potential_covariates if column_categories.get(col) not in ["text_or_other"]]
+ logger.info(f"Usable covariates for graph: {usable_covariates}")
+
+ estimand_prompt = ESTIMAND_PROMPT_TEMPLATE.format(query=query_text,dataset_description=dataset_description,
+ dataset_columns=usable_covariates,
+ treatment=treatment_variable, outcome=outcome_variable)
+
+ estimand_result = _call_llm_for_var(llm, estimand_prompt, LLMEstimand)
+ estimand = "ate" if "ate" in estimand_result.estimand.strip().lower() else "att"
+ logger.info(f"LLM identified estimand: {estimand}")
+
+ ## Did Term
+ did_term_prompt = DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
+ column_info=columns, time_variable=time_variable,
+ group_variable=group_variable, column_types=column_categories)
+ did_term_result = _call_llm_for_var(llm, did_term_prompt, LLMRDDVars)
+ did_term_result = did_term_result.did_term if did_term_result in columns else None
+ logger.info(f"LLM identified DiD term: {did_term_result}")
+
+
+
+ #smd_score_all = compute_smd(dataset_analysis.get("data", pd.DataFrame()), treatment_variable, usable_covariates)
+ #smd_score = smd_score_all.get("ate", 0.0) if smd_score_all else 0.0
+ #logger.info(f"Computed SMD score: {smd_score}")
+
+ #logger.debug(f"Computed SMD score for {estimand}: {smd_score}")
+
+
+ except Exception as e:
+ logger.error(f"Error during LLM checks for IV/RDD/RCT: {e}")
+
+
+
+ # --- Identify Treatment Reference Level ---
+ treatment_reference_level = determine_treatment_reference_level(is_rct=is_rct, llm=llm, treatment_variable=treatment_variable,
+ query_text=query_text, dataset_description=dataset_description,
+ file_path=file_path, columns=columns)
+
+ # --- Identify Interaction Term Suggestion ---
+ interaction_term_suggested, interaction_variable_candidate = identify_interaction_term(llm=llm, treatment_variable=treatment_variable,
+ covariates=covariates,
+ column_categories=column_categories, query_text=query_text,
+ dataset_description=dataset_description)
+
+
+ # --- Consolidate ---
+ return {
+ "treatment_variable": treatment_variable,
+ "treatment_variable_type": treatment_variable_type,
+ "outcome_variable": outcome_variable,
+ "covariates": covariates,
+ "time_variable": time_variable,
+ "group_variable": group_variable,
+ "instrument_variable": instrument_variable,
+ "running_variable": running_variable,
+ "cutoff_value": cutoff_value,
+ "is_rct": is_rct,
+ "treatment_reference_level": treatment_reference_level,
+ "interaction_term_suggested": interaction_term_suggested,
+ "interaction_variable_candidate": interaction_variable_candidate,
+ "confounders": confounders,
+ "did_term": did_term_result
+ }
+
+def compute_smd(df: pd.DataFrame, treat, covars_list) -> Dict[str, float]:
+ """
+ Computed the standardized mean differences (SMD) for the treatment variable
+ Args:
+ df (pd.DataFrame): The dataset.
+ treat (str): Name of the binary treatment column (0/1).
+ covars_list (List[str]): List of covariate names to consider for SMD calculation
+
+ Returns:
+ Dict{str ->float}: the standardized mean difference (SMD)
+ """
+ logger.info(f"Computing SMD for treatment variable '{treat}' with covariates: {covars_list}")
+ df_t = df[df[treat] == 1]
+ df_c = df[df[treat] == 0]
+
+ covariates = covars_list if covars_list else df.columns.tolist()
+ smd_ate = np.zeros(len(covariates))
+ smd_att = np.zeros(len(covariates))
+
+ for i, col in enumerate(covariates):
+ try:
+ m_t, m_c = df_t[col].mean(), df_c[col].mean()
+ s_t, s_c = df_t[col].std(ddof=0), df_c[col].std(ddof=0)
+ pooled = np.sqrt((s_t**2 + s_c**2) / 2)
+
+ ate_val = 0.0 if pooled == 0 else (m_t - m_c) / pooled
+ att_val = 0.0 if s_t == 0 else (m_t - m_c) / s_t
+
+ smd_ate.append(ate_val)
+ smd_att.append(att_val)
+ except Exception as e:
+ logger.warning(f"SMD computation failed for column '{col}': {e}")
+ continue
+
+ avg_ate = np.nanmean(np.abs(smd_ate))
+ avg_att = np.nanmean(np.abs(smd_att))
+
+ return {"ate":avg_ate, "att":avg_att}
+
+
+
+# --- Helper Functions for Hybrid Identification ---
+def _identify_variable_hybrid(role: str, query_hints: List[str], dataset_suggestions: List[str],
+ columns: List[str], column_categories: Dict[str, str],
+ prioritize_types: List[str], query_text: str,
+ dataset_description: Optional[str],llm: Optional[BaseChatModel],
+ exclude_vars: Optional[List[str]] = None) -> Optional[str]:
+ """
+ Used to identify a variable from the avaiable information by prompting the LLM. In case of failure,
+ it will fallback to a programmatic selection (heuristics)
+
+ Args:
+ role: variable type (treatment or outcome)
+ query_hints: hints from the query for this variable
+ dataset_suggestions: dataset-specific suggestions for this variable
+ columns: list of available columns in the dataset
+ column_categories: mapping of column names to their categories
+ prioritize_types: types to prioritize for this variable
+ query_text: the original query text
+ dataset_description: description of the dataset
+ llm: language model
+ exclude_vars: list of variables to exclude from selection (e.g., treatment for outcome)
+ Returns:
+ str: name of the identified variable, or None if not found
+ """
+
+ candidates = set()
+ available_columns = [c for c in columns if c not in (exclude_vars or [])]
+ if not available_columns: return None
+
+ # 1. Exact matches from hints
+ for hint in query_hints:
+ if hint in available_columns:
+ candidates.add(hint)
+ # 2. Add dataset suggestions
+ for sugg in dataset_suggestions:
+ if sugg in available_columns:
+ candidates.add(sugg)
+
+ # 3. Programmatic Filtering based on type
+ plausible_candidates = [c for c in candidates if column_categories.get(c) in prioritize_types]
+
+ if llm:
+ if role == "treatment":
+ prompt_template = TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE
+ elif role == "outcome":
+ prompt_template = OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE
+ else:
+ raise ValueError(f"Unsupported role for LLM variable identification: {role}")
+
+ prompt = prompt_template.format(query=query_text, description=dataset_description,
+ column_info=available_columns)
+ llm_choice = _call_llm_for_var(llm, prompt, LLMSelectedVariable)
+
+ if llm_choice and llm_choice.variable_name in available_columns:
+ logger.info(f"LLM selected {role}: {llm_choice.variable_name}")
+ return llm_choice.variable_name
+ else:
+ fallback = plausible_candidates[0] if plausible_candidates else None
+ logger.warning(f"LLM failed to select valid {role}. Falling back to: {fallback}")
+ return fallback
+
+ if plausible_candidates:
+ logger.info(f"No LLM provided. Using first plausible {role}: {plausible_candidates[0]}")
+ return plausible_candidates[0]
+
+ logger.warning(f"No plausible candidates for {role}. Cannot identify variable.")
+ return None
+
+
+def _identify_covariates_hybrid(role, treatment_variable: Optional[str], outcome_variable: Optional[str],
+ columns: List[str], column_categories: Dict[str, str], query_hints: List[str],
+ query_text: str, dataset_description: Optional[str], llm: Optional[BaseChatModel]) -> List[str]:
+ """
+ Prompts an LLM to identify the covariates
+ """
+
+ # 1. Initial Programmatic Filtering
+ exclude_cols = [treatment_variable, outcome_variable]
+ potential_covariates = [col for col in columns if col not in exclude_cols and col is not None]
+
+ # Filter out unusable types
+ usable_covariates = [col for col in potential_covariates if column_categories.get(col) not in ["text_or_other"]]
+ logger.debug(f"Initial usable covariates: {usable_covariates}")
+
+ # 2. LLM Refinement (if LLM available)
+ if llm:
+ logger.info("Using LLM to refine covariate list...")
+ prompt = ""
+ if role == "covars":
+ prompt = COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE.format("covars", query=query_text, description=dataset_description,
+ column_info=", ".join(usable_covariates),
+ treatment=treatment_variable, outcome=outcome_variable)
+ elif role == "confounders":
+ prompt = CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
+ column_info=", ".join(usable_covariates),
+ treatment=treatment_variable, outcome=outcome_variable)
+ llm_selection = _call_llm_for_var(llm, prompt, LLMSelectedCovariates)
+
+ if llm_selection and llm_selection.covariates:
+ # Validate LLM output against available columns
+ valid_llm_covs = [c for c in llm_selection.covariates if c in usable_covariates]
+ if len(valid_llm_covs) < len(llm_selection.covariates):
+ logger.warning("LLM suggested covariates not found in initial usable list.")
+ if valid_llm_covs: # Use LLM selection if it's valid and non-empty
+ logger.info(f"LLM refined covariates to: {valid_llm_covs}")
+ return valid_llm_covs[:10] # Cap at 10
+ else:
+ logger.warning("LLM refinement failed or returned empty/invalid list. Falling back.")
+ else:
+ logger.warning("LLM refinement call failed or returned no covariates. Falling back.")
+
+ # 3. Fallback to Programmatic List (Capped)
+ logger.info(f"Using programmatically determined covariates (capped at 10): {usable_covariates[:10]}")
+ return usable_covariates[:10]
+
+def _create_identify_prompt(target: str, query: str, description: Optional[str], columns: List[str],
+ categories: Dict[str,str], treatment: Optional[str], outcome: Optional[str]) -> str:
+ """
+ Creates a prompt to ask LLM to identify specific roles like IV, RDD, or RCT by selecting and formatting a specific template
+ """
+ column_info = "\n".join([f"- '{c}' (Type: {categories.get(c, 'Unknown')})" for c in columns])
+
+ # Select the appropriate detailed prompt template based on the target
+ if "instrumental variable" in target.lower():
+ template = IV_IDENTIFICATION_PROMPT_TEMPLATE
+ elif "regression discontinuity" in target.lower():
+ template = RDD_IDENTIFICATION_PROMPT_TEMPLATE
+ elif "rct" in target.lower():
+ template = RCT_IDENTIFICATION_PROMPT_TEMPLATE
+ else:
+ # Fallback or error? For now, let's raise an error if target is unexpected.
+ logger.error(f"Unsupported target for _create_identify_prompt: {target}")
+ raise ValueError(f"Unsupported target for specific identification prompt: {target}")
+
+ # Format the selected template with the provided context
+ prompt = template.format(query=query, description=description or 'N/A', column_info=column_info,
+ treatment=treatment or 'N/A', outcome=outcome or 'N/A')
+ return prompt
+
+def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
+ """Helper to call LLM with structured output and handle errors."""
+ try:
+ messages = [HumanMessage(content=prompt)]
+ structured_llm = llm.with_structured_output(pydantic_model)
+ parsed_result = structured_llm.invoke(messages)
+ return parsed_result
+ except (OutputParserException, ValidationError) as e:
+ logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {e}")
+ except Exception as e:
+ logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {e}", exc_info=True)
+ return None
diff --git a/auto_causal/components/state_manager.py b/auto_causal/components/state_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb3c23594b7a8ea47a58de310b458c969fefa3c3
--- /dev/null
+++ b/auto_causal/components/state_manager.py
@@ -0,0 +1,40 @@
+"""
+State management utilities for the auto_causal workflow.
+
+This module provides utility functions to create standardized state updates
+for passing between tools in the auto_causal agent workflow.
+"""
+
+from typing import Dict, Any, Optional
+
+def create_workflow_state_update(
+ current_step: str,
+ step_completed_flag: bool,
+ next_tool: str,
+ next_step_reason: str,
+ error: Optional[str] = None
+) -> Dict[str, Any]:
+ """
+ Create a standardized workflow state update dictionary.
+
+ Args:
+ current_step: Current step in the workflow (e.g., "input_processing")
+ step_completed_flag: Flag indicating which step was completed (e.g., "query_parsed")
+ next_tool: Name of the next tool to call
+ next_step_reason: Reason message for the next step
+ error: Optional error message if the step failed
+
+ Returns:
+ Dictionary containing the workflow_state sub-dictionary
+ """
+ state_update = {
+ "workflow_state": {
+ "current_step": current_step,
+ current_step + "_completed": step_completed_flag,
+ "next_tool": next_tool,
+ "next_step_reason": next_step_reason
+ }
+ }
+ if error:
+ state_update["workflow_state"]["error_message"] = error
+ return state_update
\ No newline at end of file
diff --git a/auto_causal/config.py b/auto_causal/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f0224fb8ed4ddf81a4948e497d2ed03825f6e28
--- /dev/null
+++ b/auto_causal/config.py
@@ -0,0 +1,97 @@
+# auto_causal/config.py
+"""Central configuration for AutoCausal, including LLM client setup."""
+
+import os
+import logging
+from typing import Optional
+
+# Langchain imports
+from langchain_core.language_models import BaseChatModel
+from langchain_openai import ChatOpenAI # Default
+from langchain_anthropic import ChatAnthropic # Example
+from langchain_google_genai import ChatGoogleGenerativeAI
+# Add other providers if needed, e.g.:
+# from langchain_community.chat_models import ChatOllama
+from dotenv import load_dotenv
+from langchain_deepseek import ChatDeepSeek
+# Create a disk-backed SQLite cache:
+# Import Together provider
+from langchain_together import ChatTogether
+
+logger = logging.getLogger(__name__)
+
+# Load .env file when this module is loaded
+load_dotenv()
+
+def get_llm_client(provider: Optional[str] = None, model_name: Optional[str] = None, **kwargs) -> BaseChatModel:
+ """Initializes and returns the chosen LLM client based on provider.
+
+ Reads provider, model, and API keys from environment variables if not passed directly.
+ Defaults to OpenAI GPT-4o-mini if no provider/model specified.
+ """
+ # Prioritize arguments, then environment variables, then defaults
+ provider = provider or os.getenv("LLM_PROVIDER", "openai")
+ provider = provider.lower()
+
+ # Default model depends on provider
+ default_models = {
+ "openai": "gpt-4o-mini",
+ "anthropic": "claude-3-5-sonnet-latest",
+ "together": "deepseek-ai/DeepSeek-V3", # Default Together model
+ "gemini" : "gemini-2.5-flash",
+ "deepseek" : "deepseek-chat"
+ }
+
+ model_name = model_name or os.getenv("LLM_MODEL", default_models.get(provider, default_models["openai"]))
+
+ api_key = None
+ if model_name not in ['o3-mini', 'o3', 'o4-mini']:
+ kwargs.setdefault("temperature", 0) # Default temperature if not provided
+
+ logger.info(f"Initializing LLM client: Provider='{provider}', Model='{model_name}'")
+
+ try:
+ if provider == "openai":
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ raise ValueError("OPENAI_API_KEY not found in environment.")
+ return ChatOpenAI(model=model_name, api_key=api_key, **kwargs)
+
+ elif provider == "anthropic":
+ api_key = os.getenv("ANTHROPIC_API_KEY")
+ if not api_key:
+ raise ValueError("ANTHROPIC_API_KEY not found in environment.")
+ return ChatAnthropic(model=model_name, api_key=api_key, **kwargs, streaming=False)
+
+ elif provider == "together":
+ api_key = os.getenv("TOGETHER_API_KEY")
+ if not api_key:
+ raise ValueError("TOGETHER_API_KEY not found in environment.")
+ return ChatTogether(model=model_name, api_key=api_key, **kwargs)
+
+ elif provider == "gemini":
+ api_key = os.getenv("GEMINI_API_KEY")
+ if not api_key:
+ raise ValueError("GEMINI_API_KEY not found in environment.")
+ return ChatGoogleGenerativeAI(model=model_name, **kwargs, function_calling="auto")
+
+ elif provider == "deepseek":
+ api_key = os.getenv("DEEPSEEK_API_KEY")
+ if not api_key:
+ raise ValueError("DEEPSEEK_API_KEY not found in environment.")
+ return ChatDeepSeek(model=model_name, **kwargs)
+
+ # Example for Ollama (ensure langchain_community is installed)
+ # elif provider == "ollama":
+ # try:
+ # from langchain_community.chat_models import ChatOllama
+ # return ChatOllama(model=model_name, **kwargs)
+ # except ImportError:
+ # raise ValueError("langchain_community needed for Ollama. Run `pip install langchain-community`")
+
+ else:
+ raise ValueError(f"Unsupported LLM provider: {provider}")
+
+ except Exception as e:
+ logger.error(f"Failed to initialize LLM (Provider: {provider}, Model: {model_name}): {e}")
+ raise # Re-raise the exception
\ No newline at end of file
diff --git a/auto_causal/methods/__init__.py b/auto_causal/methods/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d220f0d592e6caeb8f6169c48f027dac38485c
--- /dev/null
+++ b/auto_causal/methods/__init__.py
@@ -0,0 +1,44 @@
+"""
+Causal inference methods for the auto_causal module.
+
+This package contains implementations of various causal inference methods
+that can be selected and applied by the auto_causal pipeline.
+"""
+
+from .causal_method import CausalMethod
+from .propensity_score.matching import estimate_effect as psm_estimate_effect
+from .propensity_score.weighting import estimate_effect as psw_estimate_effect
+from .instrumental_variable.estimator import estimate_effect as iv_estimate_effect
+from .difference_in_differences.estimator import estimate_effect as did_estimate_effect
+from .diff_in_means.estimator import estimate_effect as dim_estimate_effect
+from .linear_regression.estimator import estimate_effect as lr_estimate_effect
+from .backdoor_adjustment.estimator import estimate_effect as ba_estimate_effect
+from .regression_discontinuity.estimator import estimate_effect as rdd_estimate_effect
+from .generalized_propensity_score.estimator import estimate_effect_gps
+
+# Mapping of method names to their implementation functions
+METHOD_MAPPING = {
+ "propensity_score_matching": psm_estimate_effect,
+ "propensity_score_weighting": psw_estimate_effect,
+ "instrumental_variable": iv_estimate_effect,
+ "difference_in_differences": did_estimate_effect,
+ "regression_discontinuity_design": rdd_estimate_effect,
+ "backdoor_adjustment": ba_estimate_effect,
+ "linear_regression": lr_estimate_effect,
+ "diff_in_means": dim_estimate_effect,
+ "generalized_propensity_score": estimate_effect_gps,
+}
+
+__all__ = [
+ "CausalMethod",
+ "psm_estimate_effect",
+ "psw_estimate_effect",
+ "iv_estimate_effect",
+ "did_estimate_effect",
+ "rdd_estimate_effect",
+ "dim_estimate_effect",
+ "lr_estimate_effect",
+ "ba_estimate_effect",
+ "METHOD_MAPPING",
+ "estimate_effect_gps",
+]
diff --git a/auto_causal/methods/backdoor_adjustment/__init__.py b/auto_causal/methods/backdoor_adjustment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/auto_causal/methods/backdoor_adjustment/diagnostics.py b/auto_causal/methods/backdoor_adjustment/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..635844c78b142cd863010eb9958c72d02d886fc1
--- /dev/null
+++ b/auto_causal/methods/backdoor_adjustment/diagnostics.py
@@ -0,0 +1,92 @@
+"""
+Diagnostic checks for Backdoor Adjustment models (typically OLS).
+"""
+
+from typing import Dict, Any, List
+import statsmodels.api as sm
+from statsmodels.stats.diagnostic import het_breuschpagan
+from statsmodels.stats.stattools import jarque_bera, durbin_watson
+from statsmodels.regression.linear_model import RegressionResultsWrapper
+from statsmodels.stats.outliers_influence import variance_inflation_factor
+import pandas as pd
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+
+def run_backdoor_diagnostics(results: RegressionResultsWrapper, X: pd.DataFrame) -> Dict[str, Any]:
+ """
+ Runs diagnostic checks on a fitted OLS model used for backdoor adjustment.
+
+ Args:
+ results: A fitted statsmodels OLS results object.
+ X: The design matrix (including constant and all predictors) used.
+
+ Returns:
+ Dictionary containing diagnostic metrics.
+ """
+ diagnostics = {}
+ details = {}
+
+ try:
+ details['r_squared'] = results.rsquared
+ details['adj_r_squared'] = results.rsquared_adj
+ details['f_statistic'] = results.fvalue
+ details['f_p_value'] = results.f_pvalue
+ details['n_observations'] = int(results.nobs)
+ details['degrees_of_freedom_resid'] = int(results.df_resid)
+ details['durbin_watson'] = durbin_watson(results.resid) if results.nobs > 5 else 'N/A (Too few obs)' # Autocorrelation
+
+ # --- Normality of Residuals (Jarque-Bera) ---
+ try:
+ if results.nobs >= 2:
+ jb_value, jb_p_value, skew, kurtosis = jarque_bera(results.resid)
+ details['residuals_normality_jb_stat'] = jb_value
+ details['residuals_normality_jb_p_value'] = jb_p_value
+ details['residuals_skewness'] = skew
+ details['residuals_kurtosis'] = kurtosis
+ details['residuals_normality_status'] = "Normal" if jb_p_value > 0.05 else "Non-Normal"
+ else:
+ details['residuals_normality_status'] = "N/A (Too few obs)"
+ except Exception as e:
+ logger.warning(f"Could not run Jarque-Bera test: {e}")
+ details['residuals_normality_status'] = "Test Failed"
+
+ # --- Homoscedasticity (Breusch-Pagan) ---
+ try:
+ if X.shape[0] > X.shape[1]: # Needs more observations than predictors
+ lm_stat, lm_p_value, f_stat, f_p_value = het_breuschpagan(results.resid, X)
+ details['homoscedasticity_bp_lm_stat'] = lm_stat
+ details['homoscedasticity_bp_lm_p_value'] = lm_p_value
+ details['homoscedasticity_status'] = "Homoscedastic" if lm_p_value > 0.05 else "Heteroscedastic"
+ else:
+ details['homoscedasticity_status'] = "N/A (Too few obs or too many predictors)"
+ except Exception as e:
+ logger.warning(f"Could not run Breusch-Pagan test: {e}")
+ details['homoscedasticity_status'] = "Test Failed"
+
+ # --- Multicollinearity (VIF - Placeholder/Basic) ---
+ # Full VIF requires calculating for each predictor vs others.
+ # Providing a basic status based on condition number as a proxy.
+ try:
+ cond_no = np.linalg.cond(results.model.exog)
+ details['model_condition_number'] = cond_no
+ if cond_no > 30:
+ details['multicollinearity_status'] = "High (Cond. No. > 30)"
+ elif cond_no > 10:
+ details['multicollinearity_status'] = "Moderate (Cond. No. > 10)"
+ else:
+ details['multicollinearity_status'] = "Low"
+ except Exception as e:
+ logger.warning(f"Could not calculate condition number: {e}")
+ details['multicollinearity_status'] = "Check Failed"
+ # details['VIF'] = "Not Fully Implemented"
+
+ # --- Linearity (Still requires visual inspection) ---
+ details['linearity_check'] = "Requires visual inspection (e.g., residual vs fitted plot)"
+
+ return {"status": "Success", "details": details}
+
+ except Exception as e:
+ logger.error(f"Error running Backdoor Adjustment diagnostics: {e}")
+ return {"status": "Failed", "error": str(e), "details": details}
diff --git a/auto_causal/methods/backdoor_adjustment/estimator.py b/auto_causal/methods/backdoor_adjustment/estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..37a43c661e77e202223e7bd362ef742ca3088705
--- /dev/null
+++ b/auto_causal/methods/backdoor_adjustment/estimator.py
@@ -0,0 +1,105 @@
+"""
+Backdoor Adjustment Estimator using Regression.
+
+Estimates the Average Treatment Effect (ATE) by regressing the outcome on the
+treatment and a set of covariates assumed to satisfy the backdoor criterion.
+"""
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from typing import Dict, Any, List, Optional
+import logging
+from langchain.chat_models.base import BaseChatModel # For type hinting llm
+
+# Import diagnostics and llm assist (placeholders for now)
+from .diagnostics import run_backdoor_diagnostics
+from .llm_assist import interpret_backdoor_results, identify_backdoor_set
+
+logger = logging.getLogger(__name__)
+
+def estimate_effect(
+ df: pd.DataFrame,
+ treatment: str,
+ outcome: str,
+ covariates: List[str], # Backdoor set - Required for this method
+ query: Optional[str] = None, # For potential LLM use
+ llm: Optional[BaseChatModel] = None, # For potential LLM use
+ **kwargs # To capture any other potential arguments
+) -> Dict[str, Any]:
+ """
+ Estimates the causal effect using Backdoor Adjustment (via OLS regression).
+
+ Assumes the provided `covariates` list satisfies the backdoor criterion.
+
+ Args:
+ df: Input DataFrame.
+ treatment: Name of the treatment variable column.
+ outcome: Name of the outcome variable column.
+ covariates: List of covariate names forming the backdoor adjustment set.
+ query: Optional user query for context (e.g., for LLM).
+ llm: Optional Language Model instance.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ Dictionary containing estimation results:
+ - 'effect_estimate': The estimated coefficient for the treatment variable.
+ - 'p_value': The p-value associated with the treatment coefficient.
+ - 'confidence_interval': The 95% confidence interval for the effect.
+ - 'standard_error': The standard error of the treatment coefficient.
+ - 'formula': The regression formula used.
+ - 'model_summary': Summary object from statsmodels.
+ - 'diagnostics': Placeholder for diagnostic results.
+ - 'interpretation': LLM interpretation.
+ """
+ if not covariates: # Check if the list is empty or None
+ raise ValueError("Backdoor Adjustment requires a non-empty list of covariates (adjustment set).")
+
+ required_cols = [treatment, outcome] + covariates
+ missing_cols = [col for col in required_cols if col not in df.columns]
+ if missing_cols:
+ raise ValueError(f"Missing required columns for Backdoor Adjustment: {missing_cols}")
+
+ # Prepare data for statsmodels (add constant, handle potential NaNs)
+ df_analysis = df[required_cols].dropna()
+ if df_analysis.empty:
+ raise ValueError("No data remaining after dropping NaNs for required columns.")
+
+ X = df_analysis[[treatment] + covariates]
+ X = sm.add_constant(X) # Add intercept
+ y = df_analysis[outcome]
+
+ # Build the formula string for reporting
+ formula = f"{outcome} ~ {treatment} + " + " + ".join(covariates) + " + const"
+ logger.info(f"Running Backdoor Adjustment regression: {formula}")
+
+ try:
+ model = sm.OLS(y, X)
+ results = model.fit()
+
+ effect_estimate = results.params[treatment]
+ p_value = results.pvalues[treatment]
+ conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
+ std_err = results.bse[treatment]
+
+ # Run diagnostics (Placeholders)
+ # Pass the full design matrix X for potential VIF checks etc.
+ diag_results = run_backdoor_diagnostics(results, X)
+
+ # Get interpretation
+ interpretation = interpret_backdoor_results(results, diag_results, treatment, covariates, llm=llm)
+
+ return {
+ 'effect_estimate': effect_estimate,
+ 'p_value': p_value,
+ 'confidence_interval': conf_int,
+ 'standard_error': std_err,
+ 'formula': formula,
+ 'model_summary': results.summary(),
+ 'diagnostics': diag_results,
+ 'interpretation': interpretation,
+ 'method_used': 'Backdoor Adjustment (OLS)'
+ }
+
+ except Exception as e:
+ logger.error(f"Backdoor Adjustment failed: {e}")
+ raise
diff --git a/auto_causal/methods/backdoor_adjustment/llm_assist.py b/auto_causal/methods/backdoor_adjustment/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..6385dbd2c5043fb874ac39bebde4770691a80f3e
--- /dev/null
+++ b/auto_causal/methods/backdoor_adjustment/llm_assist.py
@@ -0,0 +1,176 @@
+"""
+LLM assistance functions for Backdoor Adjustment analysis.
+"""
+
+from typing import List, Dict, Any, Optional
+import logging
+
+# Imported for type hinting
+from langchain.chat_models.base import BaseChatModel
+from statsmodels.regression.linear_model import RegressionResultsWrapper
+
+# Import shared LLM helpers
+from auto_causal.utils.llm_helpers import call_llm_with_json_output
+
+logger = logging.getLogger(__name__)
+
+def identify_backdoor_set(
+ df_cols: List[str],
+ treatment: str,
+ outcome: str,
+ query: Optional[str] = None,
+ existing_covariates: Optional[List[str]] = None, # Allow user to provide some
+ llm: Optional[BaseChatModel] = None
+) -> List[str]:
+ """
+ Use LLM to suggest a potential backdoor adjustment set (confounders).
+
+ Tries to identify variables that affect both treatment and outcome.
+
+ Args:
+ df_cols: List of available column names in the dataset.
+ treatment: Treatment variable name.
+ outcome: Outcome variable name.
+ query: User's causal query text (provides context).
+ existing_covariates: Covariates already considered/provided by user.
+ llm: Optional LLM model instance.
+
+ Returns:
+ List of suggested variable names for the backdoor adjustment set.
+ """
+ if llm is None:
+ logger.warning("No LLM provided for backdoor set identification.")
+ return existing_covariates or []
+
+ # Exclude treatment and outcome from potential confounders
+ potential_confounders = [c for c in df_cols if c not in [treatment, outcome]]
+ if not potential_confounders:
+ return existing_covariates or []
+
+ prompt = f"""
+ You are assisting with identifying a backdoor adjustment set for causal inference.
+ The goal is to find observed variables that confound the relationship between the treatment and outcome.
+ Assume the causal effect of '{treatment}' on '{outcome}' is of interest.
+
+ User query context (optional): {query}
+ Available variables in the dataset (excluding treatment and outcome): {potential_confounders}
+ Variables already specified as covariates by user (if any): {existing_covariates}
+
+ Based *only* on the variable names and the query context, identify which of the available variables are likely to be common causes (confounders) of both '{treatment}' and '{outcome}'.
+ These variables should be included in the backdoor adjustment set.
+ Consider variables that likely occurred *before* or *at the same time as* the treatment.
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "suggested_backdoor_set": ["confounder1", "confounder2", ...]
+ }}
+ Include variables from the user-provided list if they seem appropriate as confounders.
+ If no plausible confounders are identified among the available variables, return an empty list.
+ """
+
+ response = call_llm_with_json_output(llm, prompt)
+
+ suggested_set = []
+ if response and "suggested_backdoor_set" in response and isinstance(response["suggested_backdoor_set"], list):
+ # Basic validation
+ valid_vars = [item for item in response["suggested_backdoor_set"] if isinstance(item, str)]
+ if len(valid_vars) != len(response["suggested_backdoor_set"]):
+ logger.warning("LLM returned non-string items in suggested_backdoor_set list.")
+ suggested_set = valid_vars
+ else:
+ logger.warning(f"Failed to get valid backdoor set recommendations from LLM. Response: {response}")
+
+ # Combine with existing covariates, removing duplicates
+ final_set = list(dict.fromkeys((existing_covariates or []) + suggested_set))
+ return final_set
+
+def interpret_backdoor_results(
+ results: RegressionResultsWrapper,
+ diagnostics: Dict[str, Any],
+ treatment_var: str,
+ covariates: List[str],
+ llm: Optional[BaseChatModel] = None
+) -> str:
+ """
+ Use LLM to interpret Backdoor Adjustment results.
+
+ Args:
+ results: Fitted statsmodels OLS results object.
+ diagnostics: Dictionary of diagnostic results.
+ treatment_var: Name of the treatment variable.
+ covariates: List of covariates used in the adjustment set.
+ llm: Optional LLM model instance.
+
+ Returns:
+ String containing natural language interpretation.
+ """
+ default_interpretation = "LLM interpretation not available for Backdoor Adjustment."
+ if llm is None:
+ logger.info("LLM not provided for Backdoor Adjustment interpretation.")
+ return default_interpretation
+
+ try:
+ # --- Prepare summary for LLM ---
+ results_summary = {}
+ diag_details = diagnostics.get('details', {})
+
+ effect = results.params.get(treatment_var)
+ pval = results.pvalues.get(treatment_var)
+
+ results_summary['Treatment Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
+ results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
+ try:
+ conf_int = results.conf_int().loc[treatment_var]
+ results_summary['95% Confidence Interval'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
+ except KeyError:
+ results_summary['95% Confidence Interval'] = "Not Found"
+ except Exception as ci_e:
+ results_summary['95% Confidence Interval'] = f"Error ({ci_e})"
+
+ results_summary['Adjustment Set (Covariates Used)'] = covariates
+ results_summary['Model R-squared'] = f"{diagnostics.get('details', {}).get('r_squared', 'N/A'):.3f}" if isinstance(diagnostics.get('details', {}).get('r_squared'), (int, float)) else "N/A"
+
+ diag_summary = {}
+ if diagnostics.get("status") == "Success":
+ diag_summary['Residuals Normality Status'] = diag_details.get('residuals_normality_status', 'N/A')
+ diag_summary['Homoscedasticity Status'] = diag_details.get('homoscedasticity_status', 'N/A')
+ diag_summary['Multicollinearity Status'] = diag_details.get('multicollinearity_status', 'N/A')
+ else:
+ diag_summary['Status'] = diagnostics.get("status", "Unknown")
+
+ # --- Construct Prompt ---
+ prompt = f"""
+ You are assisting with interpreting Backdoor Adjustment (Regression) results.
+ The key assumption is that the specified adjustment set (covariates) blocks all confounding paths between the treatment ('{treatment_var}') and outcome.
+
+ Results Summary:
+ {results_summary}
+
+ Diagnostics Summary (OLS model checks):
+ {diag_summary}
+
+ Explain these results in 2-4 concise sentences. Focus on:
+ 1. The estimated average treatment effect after adjusting for the specified covariates (magnitude, direction, statistical significance based on p-value < 0.05).
+ 2. **Crucially, mention that this estimate relies heavily on the assumption that the included covariates ('{str(covariates)[:100]}...') are sufficient to control for confounding (i.e., satisfy the backdoor criterion).**
+ 3. Briefly mention any major OLS diagnostic issues noted (e.g., non-normal residuals, heteroscedasticity, high multicollinearity).
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "interpretation": ""
+ }}
+ """
+
+ # --- Call LLM ---
+ response = call_llm_with_json_output(llm, prompt)
+
+ # --- Process Response ---
+ if response and isinstance(response, dict) and \
+ "interpretation" in response and isinstance(response["interpretation"], str):
+ return response["interpretation"]
+ else:
+ logger.warning(f"Failed to get valid interpretation from LLM for Backdoor Adj. Response: {response}")
+ return default_interpretation
+
+ except Exception as e:
+ logger.error(f"Error during LLM interpretation for Backdoor Adj: {e}")
+ return f"Error generating interpretation: {e}"
diff --git a/auto_causal/methods/causal_method.py b/auto_causal/methods/causal_method.py
new file mode 100644
index 0000000000000000000000000000000000000000..17bb5dd5f4d14ad0f16b20fc6e6e436927b79d43
--- /dev/null
+++ b/auto_causal/methods/causal_method.py
@@ -0,0 +1,88 @@
+"""
+Abstract base class for all causal inference methods.
+
+This module defines the interface that all causal inference methods
+must implement, ensuring consistent behavior across different methods.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Any
+import pandas as pd
+
+
+class CausalMethod(ABC):
+ """Base class for all causal inference methods.
+
+ This abstract class defines the required methods that all causal
+ inference implementations must provide. It ensures a consistent
+ interface across different methods like propensity score matching,
+ instrumental variables, etc.
+
+ Each implementation should handle the specifics of the causal
+ inference method while conforming to this interface.
+ """
+
+ @abstractmethod
+ def validate_assumptions(self, df: pd.DataFrame, treatment: str,
+ outcome: str, covariates: List[str]) -> Dict[str, Any]:
+ """Validate method assumptions against the dataset.
+
+ Args:
+ df: DataFrame containing the dataset
+ treatment: Name of the treatment variable column
+ outcome: Name of the outcome variable column
+ covariates: List of covariate column names
+
+ Returns:
+ Dict containing validation results with keys:
+ - assumptions_valid (bool): Whether all assumptions are met
+ - failed_assumptions (List[str]): List of failed assumptions
+ - warnings (List[str]): List of warnings
+ - suggestions (List[str]): Suggestions for addressing issues
+ """
+ pass
+
+ @abstractmethod
+ def estimate_effect(self, df: pd.DataFrame, treatment: str,
+ outcome: str, covariates: List[str]) -> Dict[str, Any]:
+ """Estimate causal effect using this method.
+
+ Args:
+ df: DataFrame containing the dataset
+ treatment: Name of the treatment variable column
+ outcome: Name of the outcome variable column
+ covariates: List of covariate column names
+
+ Returns:
+ Dict containing estimation results with keys:
+ - effect_estimate (float): Estimated causal effect
+ - confidence_interval (tuple): Confidence interval (lower, upper)
+ - p_value (float): P-value of the estimate
+ - additional_metrics (Dict): Any method-specific metrics
+ """
+ pass
+
+ @abstractmethod
+ def generate_code(self, dataset_path: str, treatment: str,
+ outcome: str, covariates: List[str]) -> str:
+ """Generate executable code for this causal method.
+
+ Args:
+ dataset_path: Path to the dataset file
+ treatment: Name of the treatment variable column
+ outcome: Name of the outcome variable column
+ covariates: List of covariate column names
+
+ Returns:
+ String containing executable Python code implementing this method
+ """
+ pass
+
+ @abstractmethod
+ def explain(self) -> str:
+ """Explain this causal method, its assumptions, and when to use it.
+
+ Returns:
+ String with detailed explanation of the method
+ """
+ pass
\ No newline at end of file
diff --git a/auto_causal/methods/diff_in_means/__init__.py b/auto_causal/methods/diff_in_means/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/auto_causal/methods/diff_in_means/diagnostics.py b/auto_causal/methods/diff_in_means/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bee7489aadcd6e2452eeb611b48624a4aca3c45
--- /dev/null
+++ b/auto_causal/methods/diff_in_means/diagnostics.py
@@ -0,0 +1,60 @@
+"""
+Basic descriptive statistics for Difference in Means.
+"""
+
+from typing import Dict, Any
+import pandas as pd
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+
+def run_dim_diagnostics(df: pd.DataFrame, treatment: str, outcome: str) -> Dict[str, Any]:
+ """
+ Calculates basic descriptive statistics for treatment and control groups.
+
+ Args:
+ df: Input DataFrame (should already be filtered for NaNs in treatment/outcome).
+ treatment: Name of the binary treatment variable column.
+ outcome: Name of the outcome variable column.
+
+ Returns:
+ Dictionary containing group means, standard deviations, and counts.
+ """
+ details = {}
+ try:
+ grouped = df.groupby(treatment)[outcome]
+ stats = grouped.agg(['mean', 'std', 'count'])
+
+ # Ensure both groups (0 and 1) are present if possible
+ control_stats = stats.loc[0].to_dict() if 0 in stats.index else {'mean': np.nan, 'std': np.nan, 'count': 0}
+ treated_stats = stats.loc[1].to_dict() if 1 in stats.index else {'mean': np.nan, 'std': np.nan, 'count': 0}
+
+ details['control_group_stats'] = control_stats
+ details['treated_group_stats'] = treated_stats
+
+ if control_stats['count'] == 0 or treated_stats['count'] == 0:
+ logger.warning("One or both treatment groups have zero observations.")
+ return {"status": "Warning - Empty Group(s)", "details": details}
+
+ # Simple check for variance difference (Levene's test could be added)
+ control_std = control_stats.get('std', 0)
+ treated_std = treated_stats.get('std', 0)
+ if control_std > 0 and treated_std > 0:
+ ratio = (control_std**2) / (treated_std**2)
+ details['variance_ratio_control_div_treated'] = ratio
+ if ratio > 4 or ratio < 0.25: # Rule of thumb
+ details['variance_homogeneity_status'] = "Potentially Unequal (ratio > 4 or < 0.25)"
+ else:
+ details['variance_homogeneity_status'] = "Likely Similar"
+ else:
+ details['variance_homogeneity_status'] = "Could not calculate (zero variance in a group)"
+
+ return {"status": "Success", "details": details}
+
+ except KeyError as ke:
+ logger.error(f"KeyError during diagnostics: {ke}. Treatment levels might not be 0/1.")
+ return {"status": "Failed", "error": f"Treatment levels might not be 0/1: {ke}", "details": details}
+ except Exception as e:
+ logger.error(f"Error running Difference in Means diagnostics: {e}")
+ return {"status": "Failed", "error": str(e), "details": details}
diff --git a/auto_causal/methods/diff_in_means/estimator.py b/auto_causal/methods/diff_in_means/estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..12bcd49f76c3e996b1242368a8bab2842626d662
--- /dev/null
+++ b/auto_causal/methods/diff_in_means/estimator.py
@@ -0,0 +1,107 @@
+"""
+Difference in Means / Simple Linear Regression Estimator.
+
+Estimates the Average Treatment Effect (ATE) by comparing the mean outcome
+between the treated and control groups. This is equivalent to a simple OLS
+regression of the outcome on the treatment indicator.
+
+Assumes no confounding (e.g., suitable for RCT data).
+"""
+import pandas as pd
+import statsmodels.api as sm
+import numpy as np
+import warnings
+from typing import Dict, Any, Optional
+import logging
+from langchain.chat_models.base import BaseChatModel # For type hinting llm
+
+from .diagnostics import run_dim_diagnostics
+from .llm_assist import interpret_dim_results
+
+logger = logging.getLogger(__name__)
+
+def estimate_effect(
+ df: pd.DataFrame,
+ treatment: str,
+ outcome: str,
+ query: Optional[str] = None, # For potential LLM use
+ llm: Optional[BaseChatModel] = None, # For potential LLM use
+ **kwargs # To capture any other potential arguments (e.g., covariates - which are ignored)
+) -> Dict[str, Any]:
+ """
+ Estimates the causal effect using Difference in Means (via OLS).
+
+ Ignores any provided covariates.
+
+ Args:
+ df: Input DataFrame.
+ treatment: Name of the binary treatment variable column (should be 0 or 1).
+ outcome: Name of the outcome variable column.
+ query: Optional user query for context.
+ llm: Optional Language Model instance.
+ **kwargs: Additional keyword arguments (ignored).
+
+ Returns:
+ Dictionary containing estimation results:
+ - 'effect_estimate': The difference in means (treatment coefficient).
+ - 'p_value': The p-value associated with the difference.
+ - 'confidence_interval': The 95% confidence interval for the difference.
+ - 'standard_error': The standard error of the difference.
+ - 'formula': The regression formula used.
+ - 'model_summary': Summary object from statsmodels.
+ - 'diagnostics': Basic group statistics.
+ - 'interpretation': LLM interpretation.
+ """
+ required_cols = [treatment, outcome]
+ missing_cols = [col for col in required_cols if col not in df.columns]
+ if missing_cols:
+ raise ValueError(f"Missing required columns: {missing_cols}")
+
+ # Validate treatment is binary (or close to it)
+ treat_vals = df[treatment].dropna().unique()
+ if not np.all(np.isin(treat_vals, [0, 1])):
+ warnings.warn(f"Treatment column '{treatment}' contains values other than 0 and 1: {treat_vals}. Proceeding, but results may be unreliable.", UserWarning)
+ # Optional: could raise ValueError here if strict binary is required
+
+ # Prepare data for statsmodels (add constant, handle potential NaNs)
+ df_analysis = df[required_cols].dropna()
+ if df_analysis.empty:
+ raise ValueError("No data remaining after dropping NaNs for required columns.")
+
+ X = df_analysis[[treatment]]
+ X = sm.add_constant(X) # Add intercept
+ y = df_analysis[outcome]
+
+ formula = f"{outcome} ~ {treatment} + const"
+ logger.info(f"Running Difference in Means regression: {formula}")
+
+ try:
+ model = sm.OLS(y, X)
+ results = model.fit()
+
+ effect_estimate = results.params[treatment]
+ p_value = results.pvalues[treatment]
+ conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
+ std_err = results.bse[treatment]
+
+ # Run basic diagnostics (group means, stds, counts)
+ diag_results = run_dim_diagnostics(df_analysis, treatment, outcome)
+
+ # Get interpretation
+ interpretation = interpret_dim_results(results, diag_results, treatment, llm=llm)
+
+ return {
+ 'effect_estimate': effect_estimate,
+ 'p_value': p_value,
+ 'confidence_interval': conf_int,
+ 'standard_error': std_err,
+ 'formula': formula,
+ 'model_summary': results.summary(),
+ 'diagnostics': diag_results,
+ 'interpretation': interpretation,
+ 'method_used': 'Difference in Means (OLS)'
+ }
+
+ except Exception as e:
+ logger.error(f"Difference in Means failed: {e}")
+ raise
diff --git a/auto_causal/methods/diff_in_means/llm_assist.py b/auto_causal/methods/diff_in_means/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f02aff2dfc58db77c2e8383a3c3439c7e3a5164
--- /dev/null
+++ b/auto_causal/methods/diff_in_means/llm_assist.py
@@ -0,0 +1,95 @@
+"""
+LLM assistance functions for Difference in Means analysis.
+"""
+
+from typing import Dict, Any, Optional
+import logging
+
+# Imported for type hinting
+from langchain.chat_models.base import BaseChatModel
+from statsmodels.regression.linear_model import RegressionResultsWrapper
+
+# Import shared LLM helpers
+from auto_causal.utils.llm_helpers import call_llm_with_json_output
+
+logger = logging.getLogger(__name__)
+
+def interpret_dim_results(
+ results: RegressionResultsWrapper,
+ diagnostics: Dict[str, Any],
+ treatment_var: str,
+ llm: Optional[BaseChatModel] = None
+) -> str:
+ """
+ Use LLM to interpret Difference in Means results.
+
+ Args:
+ results: Fitted statsmodels OLS results object (from outcome ~ treatment).
+ diagnostics: Dictionary of diagnostic results (group stats).
+ treatment_var: Name of the treatment variable.
+ llm: Optional LLM model instance.
+
+ Returns:
+ String containing natural language interpretation.
+ """
+ default_interpretation = "LLM interpretation not available for Difference in Means."
+ if llm is None:
+ logger.info("LLM not provided for Difference in Means interpretation.")
+ return default_interpretation
+
+ try:
+ # --- Prepare summary for LLM ---
+ results_summary = {}
+ diag_details = diagnostics.get('details', {})
+ control_stats = diag_details.get('control_group_stats', {})
+ treated_stats = diag_details.get('treated_group_stats', {})
+
+ effect = results.params.get(treatment_var)
+ pval = results.pvalues.get(treatment_var)
+
+ results_summary['Effect Estimate (Difference in Means)'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
+ results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
+ try:
+ conf_int = results.conf_int().loc[treatment_var]
+ results_summary['95% Confidence Interval'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
+ except KeyError:
+ results_summary['95% Confidence Interval'] = "Not Found"
+ except Exception as ci_e:
+ results_summary['95% Confidence Interval'] = f"Error ({ci_e})"
+
+ results_summary['Control Group Mean Outcome'] = f"{control_stats.get('mean', 'N/A'):.3f}" if isinstance(control_stats.get('mean'), (int, float)) else str(control_stats.get('mean'))
+ results_summary['Treated Group Mean Outcome'] = f"{treated_stats.get('mean', 'N/A'):.3f}" if isinstance(treated_stats.get('mean'), (int, float)) else str(treated_stats.get('mean'))
+ results_summary['Control Group Size'] = control_stats.get('count', 'N/A')
+ results_summary['Treated Group Size'] = treated_stats.get('count', 'N/A')
+
+ # --- Construct Prompt ---
+ prompt = f"""
+ You are assisting with interpreting Difference in Means results, likely from an RCT.
+
+ Results Summary:
+ {results_summary}
+
+ Explain these results in 1-3 concise sentences. Focus on:
+ 1. The estimated average treatment effect (magnitude, direction, statistical significance based on p-value < 0.05).
+ 2. Compare the mean outcomes between the treated and control groups.
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "interpretation": ""
+ }}
+ """
+
+ # --- Call LLM ---
+ response = call_llm_with_json_output(llm, prompt)
+
+ # --- Process Response ---
+ if response and isinstance(response, dict) and \
+ "interpretation" in response and isinstance(response["interpretation"], str):
+ return response["interpretation"]
+ else:
+ logger.warning(f"Failed to get valid interpretation from LLM for Difference in Means. Response: {response}")
+ return default_interpretation
+
+ except Exception as e:
+ logger.error(f"Error during LLM interpretation for Difference in Means: {e}")
+ return f"Error generating interpretation: {e}"
diff --git a/auto_causal/methods/difference_in_differences/diagnostics.py b/auto_causal/methods/difference_in_differences/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f7377d58b985e5885c0f4e4d03b938c5944dbd4
--- /dev/null
+++ b/auto_causal/methods/difference_in_differences/diagnostics.py
@@ -0,0 +1,345 @@
+"""Diagnostic functions for Difference-in-Differences method."""
+
+import pandas as pd
+import numpy as np
+from typing import Dict, Any, Optional, List
+import logging
+import statsmodels.formula.api as smf # Import statsmodels
+from patsy import PatsyError # To catch formula errors
+
+# Import helper function from estimator -> Change to utils
+from .utils import create_post_indicator
+
+logger = logging.getLogger(__name__)
+
+def validate_parallel_trends(df: pd.DataFrame, time_var: str, outcome: str,
+ group_indicator_col: str, treatment_period_start: Any,
+ dataset_description: Optional[str] = None,
+ time_varying_covariates: Optional[List[str]] = None) -> Dict[str, Any]:
+ """Validates the parallel trends assumption using pre-treatment data.
+
+ Regresses the outcome on group-specific time trends before the treatment period.
+ Tests if the interaction terms between group and pre-treatment time periods are jointly significant.
+
+ Args:
+ df: DataFrame containing the data.
+ time_var: Name of the time variable column.
+ outcome: Name of the outcome variable column.
+ group_indicator_col: Name of the binary treatment group indicator column (0/1).
+ treatment_period_start: The time period value when treatment starts.
+ dataset_description: Optional dictionary for additional dataset description.
+ time_varying_covariates: Optional list of time-varying covariates to include.
+
+ Returns:
+ Dictionary with validation results.
+ """
+ logger.info("Validating parallel trends...")
+ validation_result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
+
+ try:
+ # Filter pre-treatment data
+ pre_df = df[df[time_var] < treatment_period_start].copy()
+
+ if len(pre_df) < 20 or pre_df[group_indicator_col].nunique() < 2 or pre_df[time_var].nunique() < 2:
+ validation_result["details"] = "Insufficient pre-treatment data or variation to perform test."
+ logger.warning(validation_result["details"])
+ # Assume valid if cannot test? Or invalid? Let's default to True if we can't test
+ validation_result["valid"] = True
+ validation_result["details"] += " Defaulting to assuming parallel trends (unable to test)."
+ return validation_result
+
+ # Check if group indicator is binary
+ if pre_df[group_indicator_col].nunique() > 2:
+ validation_result["details"] = f"Group indicator '{group_indicator_col}' has more than 2 unique values. Using simple visual assessment."
+ logger.warning(validation_result["details"])
+ # Use visual assessment method instead (check if trends look roughly parallel)
+ validation_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
+ # Ensure p_value is set
+ if validation_result["p_value"] is None:
+ validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
+ return validation_result
+
+ # Use a robust approach first - test for pre-trend differences using a simpler model
+ try:
+ # Create a linear time trend
+ pre_df['time_trend'] = pre_df[time_var].astype(float)
+
+ # Create interaction between trend and group
+ pre_df['group_trend'] = pre_df['time_trend'] * pre_df[group_indicator_col].astype(float)
+
+ # Simple regression with linear trend interaction
+ simple_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}') + time_trend + group_trend"
+ simple_model = smf.ols(simple_formula, data=pre_df)
+ simple_results = simple_model.fit()
+
+ # Check if trend interaction coefficient is significant
+ group_trend_pvalue = simple_results.pvalues['group_trend']
+
+ # If p > 0.05, trends are not significantly different
+ validation_result["valid"] = group_trend_pvalue > 0.05
+ validation_result["p_value"] = group_trend_pvalue
+ validation_result["details"] = f"Simple linear trend test: p-value for group-trend interaction: {group_trend_pvalue:.4f}. Parallel trends: {validation_result['valid']}."
+ logger.info(validation_result["details"])
+
+ # If we've successfully validated with the simple approach, return
+ return validation_result
+
+ except Exception as e:
+ logger.warning(f"Simple trend test failed: {e}. Trying alternative approach.")
+ # Continue to more complex method if simple method fails
+
+ # Try more complex approach with period-specific interactions
+ try:
+ # Create period dummies to avoid issues with categorical variables
+ time_periods = sorted(pre_df[time_var].unique())
+
+ # Create dummy variables for time periods (except first)
+ for period in time_periods[1:]:
+ period_col = f'period_{period}'
+ pre_df[period_col] = (pre_df[time_var] == period).astype(int)
+
+ # Create interaction with group
+ pre_df[f'group_x_{period_col}'] = pre_df[period_col] * pre_df[group_indicator_col].astype(float)
+
+ # Construct formula with manual dummies
+ interaction_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}')"
+
+ # Add period dummies except first (reference)
+ for period in time_periods[1:]:
+ period_col = f'period_{period}'
+ interaction_formula += f" + {period_col}"
+
+ # Add interactions
+ interaction_terms = []
+ for period in time_periods[1:]:
+ interaction_col = f'group_x_period_{period}'
+ interaction_formula += f" + {interaction_col}"
+ interaction_terms.append(interaction_col)
+
+ # Add covariates if provided
+ if time_varying_covariates:
+ for cov in time_varying_covariates:
+ interaction_formula += f" + Q('{cov}')"
+
+ # Fit model
+ complex_model = smf.ols(interaction_formula, data=pre_df)
+ complex_results = complex_model.fit()
+
+ # Test joint significance of interaction terms
+ if interaction_terms:
+ from statsmodels.formula.api import ols
+ from statsmodels.stats.anova import anova_lm
+
+ # Create models with and without interactions
+ formula_with = interaction_formula
+ formula_without = interaction_formula
+ for term in interaction_terms:
+ formula_without = formula_without.replace(f" + {term}", "")
+
+ model_with = smf.ols(formula_with, data=pre_df).fit()
+ model_without = smf.ols(formula_without, data=pre_df).fit()
+
+ # Compare models
+ try:
+ from scipy import stats
+ df_model = len(interaction_terms)
+ df_residual = model_with.df_resid
+ f_value = ((model_without.ssr - model_with.ssr) / df_model) / (model_with.ssr / df_residual)
+ p_value = 1 - stats.f.cdf(f_value, df_model, df_residual)
+
+ validation_result["valid"] = p_value > 0.05
+ validation_result["p_value"] = p_value
+ validation_result["details"] = f"Manual F-test for pre-treatment interactions: F({df_model}, {df_residual})={f_value:.4f}, p={p_value:.4f}. Parallel trends: {validation_result['valid']}."
+ logger.info(validation_result["details"])
+
+ except Exception as e:
+ logger.warning(f"Manual F-test failed: {e}. Using individual coefficient significance.")
+
+ # If F-test fails, check individual coefficients
+ significant_interactions = 0
+ for term in interaction_terms:
+ if term in complex_results.pvalues and complex_results.pvalues[term] < 0.05:
+ significant_interactions += 1
+
+ validation_result["valid"] = significant_interactions == 0
+ # Set a dummy p-value based on proportion of significant interactions
+ if len(interaction_terms) > 0:
+ validation_result["p_value"] = 1.0 - (significant_interactions / len(interaction_terms))
+ else:
+ validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
+ validation_result["details"] = f"{significant_interactions} out of {len(interaction_terms)} pre-treatment interactions are significant at p<0.05. Parallel trends: {validation_result['valid']}."
+ logger.info(validation_result["details"])
+ else:
+ validation_result["valid"] = True
+ validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
+ validation_result["details"] = "No pre-treatment interaction terms could be tested. Defaulting to assuming parallel trends."
+ logger.warning(validation_result["details"])
+
+ except Exception as e:
+ logger.warning(f"Complex trend test failed: {e}. Falling back to visual assessment.")
+ tmp_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
+ # Copy over values from visual assessment ensuring p_value is set
+ validation_result.update(tmp_result)
+ # Ensure p_value is set
+ if validation_result["p_value"] is None:
+ validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
+
+ except Exception as e:
+ error_msg = f"Error during parallel trends validation: {e}"
+ logger.error(error_msg, exc_info=True)
+ validation_result["details"] = error_msg
+ validation_result["error"] = str(e)
+ # Default to assuming valid if test fails completely
+ validation_result["valid"] = True
+ validation_result["p_value"] = 1.0 # Default to 1.0 if test fails
+ validation_result["details"] += " Defaulting to assuming parallel trends (test failed)."
+
+ return validation_result
+
+def assess_trends_visually(df: pd.DataFrame, time_var: str, outcome: str,
+ group_indicator_col: str) -> Dict[str, Any]:
+ """Simple visual assessment of parallel trends by comparing group means over time.
+
+ This is a fallback method when statistical tests fail.
+ """
+ result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
+
+ try:
+ # Group by time and treatment group, calculate means
+ grouped = df.groupby([time_var, group_indicator_col])[outcome].mean().reset_index()
+
+ # Pivot to get time series for each group
+ if df[group_indicator_col].nunique() <= 10: # Only if reasonable number of groups
+ pivot = grouped.pivot(index=time_var, columns=group_indicator_col, values=outcome)
+
+ # Calculate slopes between consecutive periods for each group
+ slopes = {}
+ time_values = sorted(df[time_var].unique())
+
+ if len(time_values) >= 3: # Need at least 3 periods to compare slopes
+ for group in pivot.columns:
+ group_slopes = []
+ for i in range(len(time_values) - 1):
+ t1, t2 = time_values[i], time_values[i+1]
+ if t1 in pivot.index and t2 in pivot.index:
+ slope = (pivot.loc[t2, group] - pivot.loc[t1, group]) / (t2 - t1)
+ group_slopes.append(slope)
+ if group_slopes:
+ slopes[group] = group_slopes
+
+ # Compare slopes between groups
+ if len(slopes) >= 2:
+ slope_diffs = []
+ groups = list(slopes.keys())
+ for i in range(len(slopes[groups[0]])):
+ if i < len(slopes[groups[1]]):
+ slope_diffs.append(abs(slopes[groups[0]][i] - slopes[groups[1]][i]))
+
+ # If average slope difference is small relative to outcome scale
+ outcome_scale = df[outcome].std()
+ avg_slope_diff = sum(slope_diffs) / len(slope_diffs) if slope_diffs else 0
+ relative_diff = avg_slope_diff / outcome_scale if outcome_scale > 0 else 0
+
+ result["valid"] = relative_diff < 0.2 # Threshold for "parallel enough"
+ # Set p-value based on relative difference
+ result["p_value"] = 1.0 - (relative_diff * 5) if relative_diff < 0.2 else 0.04
+ result["details"] = f"Visual assessment: relative slope difference = {relative_diff:.4f}. Parallel trends: {result['valid']}."
+ else:
+ result["valid"] = True
+ result["p_value"] = 1.0
+ result["details"] = "Visual assessment: insufficient group data for comparison. Defaulting to assuming parallel trends."
+ else:
+ result["valid"] = True
+ result["p_value"] = 1.0
+ result["details"] = "Visual assessment: insufficient time periods for comparison. Defaulting to assuming parallel trends."
+ else:
+ result["valid"] = True
+ result["p_value"] = 1.0
+ result["details"] = f"Visual assessment: too many groups ({df[group_indicator_col].nunique()}) for visual comparison. Defaulting to assuming parallel trends."
+
+ except Exception as e:
+ result["error"] = str(e)
+ result["valid"] = True
+ result["p_value"] = 1.0
+ result["details"] = f"Visual assessment failed: {e}. Defaulting to assuming parallel trends."
+
+ logger.info(result["details"])
+ return result
+
+def run_placebo_test(df: pd.DataFrame, time_var: str, group_var: str, outcome: str,
+ treated_unit_indicator: str, covariates: List[str],
+ treatment_period_start: Any,
+ placebo_period_start: Any) -> Dict[str, Any]:
+ """Runs a placebo test for DiD by assigning a fake earlier treatment period.
+
+ Re-runs the DiD estimation using the placebo period and checks if the effect is non-significant.
+
+ Args:
+ df: Original DataFrame.
+ time_var: Name of the time variable column.
+ group_var: Name of the unit/group ID column (for clustering SE).
+ outcome: Name of the outcome variable column.
+ treated_unit_indicator: Name of the binary treatment group indicator column (0/1).
+ covariates: List of covariate names.
+ treatment_period_start: The actual treatment start period.
+ placebo_period_start: The fake treatment start period (must be before actual start).
+
+ Returns:
+ Dictionary with placebo test results.
+ """
+ logger.info(f"Running placebo test assigning treatment start at {placebo_period_start}...")
+ placebo_result = {"passed": False, "effect_estimate": None, "p_value": None, "details": "", "error": None}
+
+ if placebo_period_start >= treatment_period_start:
+ error_msg = "Placebo period must be before the actual treatment period."
+ logger.error(error_msg)
+ placebo_result["error"] = error_msg
+ placebo_result["details"] = error_msg
+ return placebo_result
+
+ try:
+ df_placebo = df.copy()
+ # Create placebo post and interaction terms
+ post_placebo_col = 'post_placebo'
+ interaction_placebo_col = 'did_interaction_placebo'
+
+ df_placebo[post_placebo_col] = create_post_indicator(df_placebo, time_var, placebo_period_start)
+ df_placebo[interaction_placebo_col] = df_placebo[treated_unit_indicator] * df_placebo[post_placebo_col]
+
+ # Construct formula for placebo regression
+ formula = f"`{outcome}` ~ `{treated_unit_indicator}` + `{post_placebo_col}` + `{interaction_placebo_col}`"
+ if covariates:
+ formula += f" + {' + '.join([f'`{c}`' for c in covariates])}"
+ formula += f" + C(`{group_var}`) + C(`{time_var}`)" # Include FEs
+
+ logger.debug(f"Placebo test formula: {formula}")
+
+ # Fit the placebo model with clustered SE
+ ols_model = smf.ols(formula=formula, data=df_placebo)
+ results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_placebo[group_var]})
+
+ # Check the significance of the placebo interaction term
+ placebo_effect = float(results.params[interaction_placebo_col])
+ placebo_p_value = float(results.pvalues[interaction_placebo_col])
+
+ # Test passes if the placebo effect is not statistically significant (e.g., p > 0.1)
+ passed_test = placebo_p_value > 0.10
+
+ placebo_result["passed"] = passed_test
+ placebo_result["effect_estimate"] = placebo_effect
+ placebo_result["p_value"] = placebo_p_value
+ placebo_result["details"] = f"Placebo treatment effect estimated at {placebo_effect:.4f} (p={placebo_p_value:.4f}). Test passed: {passed_test}."
+ logger.info(placebo_result["details"])
+
+ except (KeyError, PatsyError, ValueError, Exception) as e:
+ error_msg = f"Error during placebo test execution: {e}"
+ logger.error(error_msg, exc_info=True)
+ placebo_result["details"] = error_msg
+ placebo_result["error"] = str(e)
+
+ return placebo_result
+
+# TODO: Add function for Event Study plot (plot_event_study)
+# This would involve estimating effects for leads and lags around the treatment period.
+
+# Add other diagnostic functions as needed (e.g., plot_event_study)
\ No newline at end of file
diff --git a/auto_causal/methods/difference_in_differences/estimator.py b/auto_causal/methods/difference_in_differences/estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3bf42f26433ade9804ef6dc1de973b1f5be042d
--- /dev/null
+++ b/auto_causal/methods/difference_in_differences/estimator.py
@@ -0,0 +1,463 @@
+"""
+Difference-in-Differences Estimator using DoWhy with Statsmodels fallback.
+"""
+
+import logging
+import pandas as pd
+import numpy as np
+from typing import Dict, List, Optional, Any, Tuple
+from auto_causal.config import get_llm_client # IMPORT LLM Client Factory
+
+# DoWhy imports (Commented out for simplification)
+# from dowhy import CausalModel
+# from dowhy.causal_estimators import CausalEstimator
+# from dowhy.causal_estimator import CausalEstimate
+# Statsmodels import for estimation
+import statsmodels.formula.api as smf
+
+# Local imports
+from .llm_assist import (
+ identify_time_variable,
+ determine_treatment_period,
+ identify_treatment_group,
+ interpret_did_results
+)
+from .diagnostics import validate_parallel_trends # Import diagnostics
+# Import from the new utils module
+from .utils import create_post_indicator
+
+logger = logging.getLogger(__name__)
+
+# --- Helper functions moved from old file ---
+def format_did_results(statsmodels_results: Any, interaction_term_key: str,
+ validation_results: Dict[str, Any],
+ method_details: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
+ '''Formats the DiD results from statsmodels results into a standard dictionary.'''
+
+ try:
+ # Use the interaction_term_key passed directly
+ effect = float(statsmodels_results.params[interaction_term_key])
+ stderr = float(statsmodels_results.bse[interaction_term_key])
+ pval = float(statsmodels_results.pvalues[interaction_term_key])
+ ci = statsmodels_results.conf_int().loc[interaction_term_key].values.tolist()
+ ci_lower, ci_upper = float(ci[0]), float(ci[1])
+ logger.info(f"Extracted effect for '{interaction_term_key}'")
+
+ except KeyError:
+ logger.error(f"Interaction term '{interaction_term_key}' not found in statsmodels results. Available params: {statsmodels_results.params.index.tolist()}")
+ # Fallback to NaN if term not found
+ effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
+ except Exception as e:
+ logger.error(f"Error extracting results from statsmodels object: {e}")
+ effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
+
+ # Create a standardized results dictionary
+ results = {
+ "effect_estimate": effect,
+ "standard_error": stderr,
+ "p_value": pval,
+ "confidence_interval": [ci_lower, ci_upper],
+ "diagnostics": validation_results,
+ "parameters": parameters,
+ "details": str(statsmodels_results.summary())
+ }
+
+ return results
+
+# Comment out unused DoWhy result formatter
+# def format_dowhy_results(estimate: CausalEstimate,
+# validation_results: Dict[str, Any],
+# parameters: Dict[str, Any]) -> Dict[str, Any]:
+# '''Formats the DiD results from DoWhy causal estimate into a standard dictionary.'''
+
+# try:
+# # Extract values from DoWhy estimate
+# effect = float(estimate.value)
+# stderr = float(estimate.get_standard_error()) if hasattr(estimate, 'get_standard_error') else np.nan
+# ci_lower, ci_upper = estimate.get_confidence_intervals() if hasattr(estimate, 'get_confidence_intervals') else (np.nan, np.nan)
+# # Extract p-value if available, otherwise use NaN
+# pval = estimate.get_significance_test_results().get('p_value', np.nan) if hasattr(estimate, 'get_significance_test_results') else np.nan
+
+# # Get available details from estimate
+# details = str(estimate)
+# if hasattr(estimate, 'summary'):
+# details = str(estimate.summary())
+
+# logger.info(f"Extracted effect from DoWhy estimate: {effect}")
+
+# except Exception as e:
+# logger.error(f"Error extracting results from DoWhy estimate: {e}")
+# effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
+# details = f"Error extracting DoWhy results: {e}"
+
+# # Create a standardized results dictionary
+# results = {
+# "effect_estimate": effect,
+# "effect_se": stderr,
+# "p_value": pval,
+# "confidence_interval": [ci_lower, ci_upper],
+# "diagnostics": validation_results,
+# "parameters": parameters,
+# "details": details,
+# "estimator": "dowhy"
+# }
+
+# return results
+
+# --- Main `estimate_effect` function ---
+
+def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
+ covariates: List[str],
+ dataset_description: Optional[str] = None,
+ query: Optional[str] = None,
+ **kwargs) -> Dict[str, Any]:
+ """Difference-in-Differences estimation using DoWhy with Statsmodels fallback.
+
+ Args:
+ df: Dataset containing causal variables
+ treatment: Name of treatment variable (or variable indicating treated group)
+ outcome: Name of outcome variable
+ covariates: List of covariate names
+ dataset_description: Optional dictionary describing the dataset
+ **kwargs: Method-specific parameters (e.g., time_var, group_var, query, llm instance if needed)
+
+ Returns:
+ Dictionary with effect estimate and diagnostics
+ """
+ query = kwargs.get('query_str')
+ # llm_instance = kwargs.get('llm') # Pass llm if helpers need it
+ df_processed = df.copy() # Work on a copy
+
+ logger.info("Starting DiD estimation using DoWhy with Statsmodels fallback...")
+
+ # --- Step 1: Identify Key Variables (using LLM Assist placeholders) ---
+ # Pass llm_instance to helpers if they are implemented to use it
+ llm_instance = get_llm_client() # Get llm instance if passed
+ time_var = kwargs.get('time_variable', identify_time_variable(df_processed, query, dataset_description, llm=llm_instance))
+ if time_var is None:
+ raise ValueError("Time variable could not be identified for DiD.")
+ if time_var not in df_processed.columns:
+ raise ValueError(f"Identified time variable '{time_var}' not found in DataFrame.")
+
+ # Determine the variable that identifies the panel unit (for grouping/FE)
+ group_var = kwargs.get('group_variable', identify_treatment_group(df_processed, treatment, query, dataset_description, llm=llm_instance))
+ if group_var is None:
+ raise ValueError("Group/Unit variable could not be identified for DiD.")
+ if group_var not in df_processed.columns:
+ raise ValueError(f"Identified group/unit variable '{group_var}' not found in DataFrame.")
+
+ # Check outcome exists before proceeding further
+ if outcome not in df_processed.columns:
+ raise ValueError(f"Outcome variable '{outcome}' not found in DataFrame.")
+
+ # Determine treatment period start
+ treatment_period = kwargs.get('treatment_period_start', kwargs.get('treatment_period',
+ determine_treatment_period(df_processed, time_var, treatment, query, dataset_description, llm=llm_instance)))
+
+ # --- Identify the TRUE binary treatment group indicator column ---
+ treated_group_col_for_formula = None
+
+ # Priority 1: Check if the 'treatment' argument itself is a valid binary indicator
+ if treatment in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed[treatment]):
+ unique_treat_vals = set(df_processed[treatment].dropna().unique())
+ if unique_treat_vals.issubset({0, 1}):
+ treated_group_col_for_formula = treatment
+ logger.info(f"Using the provided 'treatment' argument '{treatment}' as binary group indicator.")
+
+ # Priority 2: Check if a column explicitly named 'group' exists and is binary
+ if treated_group_col_for_formula is None and 'group' in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed['group']):
+ unique_group_vals = set(df_processed['group'].dropna().unique())
+ if unique_group_vals.issubset({0, 1}):
+ treated_group_col_for_formula = 'group'
+ logger.info(f"Using column 'group' as binary group indicator.")
+
+ # Priority 3: Fallback - Search other columns (excluding known roles and time-related ones)
+ if treated_group_col_for_formula is None:
+ logger.warning(f"Provided 'treatment' arg '{treatment}' is not binary 0/1 and no 'group' column found. Searching other columns...")
+ potential_group_cols = []
+ # Exclude outcome, time var, unit ID var, and common time indicators like 'post'
+ excluded_cols = [outcome, time_var, group_var, 'post', 'is_post_treatment', 'did_interaction']
+ for col_name in df_processed.columns:
+ if col_name in excluded_cols:
+ continue
+ try:
+ col_data = df_processed[col_name]
+ # Ensure we are working with a Series
+ if isinstance(col_data, pd.DataFrame):
+ if col_data.shape[1] == 1:
+ col_data = col_data.iloc[:, 0] # Extract the Series
+ else:
+ logger.warning(f"Skipping multi-column DataFrame slice for '{col_name}'.")
+ continue
+
+ # Check if the Series can be interpreted as binary 0/1
+ if not pd.api.types.is_numeric_dtype(col_data) and not pd.api.types.is_bool_dtype(col_data):
+ continue # Skip non-numeric/non-boolean columns
+
+ unique_vals = set(col_data.dropna().unique())
+ # Simplified check: directly test if unique values are a subset of {0, 1}
+ if unique_vals.issubset({0, 1}):
+ logger.info(f" Found potential binary indicator: {col_name}")
+ potential_group_cols.append(col_name)
+
+ except AttributeError as ae:
+ # Catch attribute errors likely due to unexpected types
+ logger.warning(f"Attribute error checking column '{col_name}': {ae}. Skipping.")
+ except Exception as e:
+ logger.warning(f"Unexpected error checking column '{col_name}' during group ID search: {e}")
+
+ if potential_group_cols:
+ treated_group_col_for_formula = potential_group_cols[0] # Take the first suitable one found
+ logger.info(f"Using column '{treated_group_col_for_formula}' found during search as binary group indicator.")
+ else:
+ # Final fallback: Use the originally identified group_var, but warn heavily
+ treated_group_col_for_formula = group_var
+ logger.error(f"CRITICAL WARNING: Could not find suitable binary treatment group indicator. Using '{group_var}', but this is likely incorrect and will produce invalid DiD estimates.")
+
+ # --- Final Check ---
+ if treated_group_col_for_formula not in df_processed.columns:
+ # This case should ideally not happen with the logic above but added defensively
+ raise ValueError(f"Determined treatment group column '{treated_group_col_for_formula}' not found in DataFrame.")
+ if df_processed[treated_group_col_for_formula].nunique(dropna=True) > 2:
+ logger.warning(f"Selected treatment group column '{treated_group_col_for_formula}' is not binary (has {df_processed[treated_group_col_for_formula].nunique()} unique values). DiD requires binary treatment group.")
+
+ # --- Step 2: Create Indicator Variables ---
+ post_indicator_col = 'post'
+ if post_indicator_col not in df_processed.columns:
+ # Create the post indicator if it doesn't exist
+ df_processed[post_indicator_col] = create_post_indicator(df_processed, time_var, treatment_period)
+
+ # Interaction term is treatment group * post
+ interaction_term_col = 'did_interaction' # Keep explicit interaction term
+ df_processed[interaction_term_col] = df_processed[treated_group_col_for_formula] * df_processed[post_indicator_col]
+
+ # --- Step 3: Validate Parallel Trends (using the group column) ---
+ parallel_trends_validation = validate_parallel_trends(df_processed, time_var, outcome,
+ treated_group_col_for_formula, treatment_period, dataset_description)
+ # Note: The validation result is currently just a placeholder
+ if not parallel_trends_validation.get('valid', False):
+ logger.warning("Parallel trends assumption potentially violated (based on placeholder check). Proceeding with estimation, but results may be biased.")
+ # Add this info to the final results diagnostics
+
+ # --- Step 4: Prepare for Statsmodels Estimation ---
+ # (DoWhy section commented out for simplicity)
+ # all_common_causes = covariates + [time_var, group_var] # group_var is unit ID
+ # use_dowhy_estimate = False
+ # dowhy_estimate = None
+
+ # try:
+ # # Create DoWhy CausalModel
+ # model = CausalModel(
+ # data=df_processed,
+ # treatment=treated_group_col_for_formula, # Use group indicator here
+ # outcome=outcome,
+ # common_causes=all_common_causes,
+ # )
+ # logger.info("DoWhy CausalModel created for DiD estimation.")
+
+ # # Identify estimand
+ # identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
+ # logger.info(f"DoWhy identified estimand: {identified_estimand.estimand_type}")
+
+ # # Try to estimate using DiD estimator if available in DoWhy
+ # try:
+ # logger.info("Attempting to use DoWhy's DiD estimator...")
+
+ # # Debug info - print DataFrame info to help diagnose possible issues
+ # logger.debug(f"DataFrame shape before DoWhy DiD: {df_processed.shape}")
+ # # ... (rest of DoWhy debug logs commented out) ...
+
+ # # Create params dictionary for DoWhy DiD estimator
+ # did_params = {
+ # 'time_var': time_var,
+ # 'treatment_period': treatment_period,
+ # 'unit_var': group_var
+ # }
+
+ # # Add control variables if available
+ # if covariates:
+ # did_params['control_vars'] = covariates
+
+ # logger.debug(f"DoWhy DiD params: {did_params}")
+
+ # # Try to use DiD estimator from DoWhy (requires recent version of DoWhy)
+ # if hasattr(model, 'estimate_effect'):
+ # try:
+ # # First check if difference_in_differences method is available
+ # available_methods = model.get_available_effect_estimators() if hasattr(model, 'get_available_effect_estimators') else []
+ # logger.debug(f"Available DoWhy estimators: {available_methods}")
+
+ # if "difference_in_differences" not in str(available_methods):
+ # logger.warning("'difference_in_differences' estimator not found in available DoWhy estimators. Falling back to statsmodels.")
+ # else:
+ # # Try the estimation with more error handling
+ # logger.info("Calling DoWhy DiD estimator...")
+ # estimate = model.estimate_effect(
+ # identified_estimand,
+ # method_name="difference_in_differences",
+ # method_params=did_params
+ # )
+
+ # if estimate:
+ # # Extra check to verify estimate has expected attributes
+ # if hasattr(estimate, 'value') and not pd.isna(estimate.value):
+ # dowhy_estimate = estimate
+ # use_dowhy_estimate = True
+ # logger.info(f"Successfully used DoWhy's DiD estimator. Effect estimate: {estimate.value}")
+ # else:
+ # logger.warning(f"DoWhy's DiD estimator returned invalid estimate: {estimate}. Falling back to statsmodels.")
+ # else:
+ # logger.warning("DoWhy's DiD estimator returned None. Falling back to statsmodels.")
+ # except IndexError as idx_err:
+ # # Handle specific IndexError that's occurring
+ # logger.error(f"IndexError in DoWhy DiD estimator: {idx_err}. Check input data structure.")
+ # # Trace more details about the error
+ # import traceback
+ # logger.error(f"Error traceback: {traceback.format_exc()}")
+ # logger.warning("Falling back to statsmodels due to IndexError in DoWhy.")
+ # else:
+ # logger.warning("DoWhy model does not have estimate_effect method. Falling back to statsmodels.")
+
+ # except (ImportError, AttributeError) as e:
+ # logger.warning(f"DoWhy DiD estimator not available or not implemented: {e}. Falling back to statsmodels.")
+ # except ValueError as ve:
+ # logger.error(f"ValueError in DoWhy DiD estimator: {ve}. Likely issue with data formatting. Falling back to statsmodels.")
+ # except Exception as e:
+ # logger.error(f"Error using DoWhy's DiD estimator: {e}. Falling back to statsmodels.")
+ # # Add traceback for better debugging
+ # import traceback
+ # logger.error(f"Full error traceback: {traceback.format_exc()}")
+
+ # except Exception as e:
+ # logger.error(f"Failed to create DoWhy CausalModel: {e}", exc_info=True)
+ # # model = None # Set model to None if creation fails
+
+ # Create parameters dictionary for formatting results
+ parameters = {
+ "time_var": time_var,
+ "group_var": group_var, # Unit ID
+ "treatment_indicator": treated_group_col_for_formula, # Group indicator used in formula basis
+ "post_indicator": post_indicator_col,
+ "treatment_period_start": treatment_period,
+ "covariates": covariates,
+ }
+
+ # Group diagnostics for formatting
+ did_diagnostics = {
+ "parallel_trends": parallel_trends_validation,
+ # "placebo_test": run_placebo_test(...)
+ }
+
+ # If DoWhy estimation was successful, use those results (Section Commented Out)
+ # if use_dowhy_estimate and dowhy_estimate:
+ # logger.info("Using DoWhy DiD estimation results.")
+ # parameters["estimation_method"] = "DoWhy Difference-in-Differences"
+
+ # # Format the results
+ # formatted_results = format_dowhy_results(dowhy_estimate, did_diagnostics, parameters)
+ # else:
+
+ # --- Step 5: Use Statsmodels OLS ---
+ logger.info("Determining Statsmodels OLS formula based on number of time periods...")
+
+ num_time_periods = df_processed[time_var].nunique()
+
+ interaction_term_key_for_results: str
+ method_details_str: str
+ formula: str
+
+ if num_time_periods == 2:
+ logger.info(
+ f"Number of unique time periods is 2. Using 2x2 DiD formula: "
+ f"{outcome} ~ {treated_group_col_for_formula} * {post_indicator_col}"
+ )
+ # For 2x2 DiD: outcome ~ group * post_indicator
+ # The interaction term A:B in statsmodels gives the DiD estimate.
+ formula_core = f"{treated_group_col_for_formula} * {post_indicator_col}"
+ interaction_term_key_for_results = f"{treated_group_col_for_formula}:{post_indicator_col}"
+
+ formula_parts = [formula_core]
+ main_model_terms = {outcome, treated_group_col_for_formula, post_indicator_col}
+
+ if covariates:
+ filtered_covs = [
+ c for c in covariates if c not in main_model_terms
+ ]
+ if filtered_covs:
+ formula_parts.extend(filtered_covs)
+
+ formula = f"{outcome} ~ {' + '.join(formula_parts)}"
+ parameters["estimation_method"] = "Statsmodels OLS for 2x2 DiD (Group * Post interaction)"
+ method_details_str = "DiD via Statsmodels 2x2 (Group * Post interaction)"
+
+ else: # num_time_periods > 2
+ logger.info(
+ f"Number of unique time periods is {num_time_periods} (>2). "
+ f"Using TWFE DiD formula: {outcome} ~ {interaction_term_col} + C({group_var}) + C({time_var})"
+ )
+ # For TWFE: outcome ~ actual_treatment_variable + UnitFE + TimeFE
+ # actual_treatment_variable is interaction_term_col (e.g., treated_group * post_indicator)
+ # UnitFE is C(group_var), TimeFE is C(time_var)
+ formula_parts = [
+ interaction_term_col,
+ f"C({group_var})",
+ f"C({time_var})"
+ ]
+ interaction_term_key_for_results = interaction_term_col
+ main_model_terms = {outcome, interaction_term_col, group_var, time_var}
+
+ if covariates:
+ filtered_covs = [
+ c for c in covariates if c not in main_model_terms
+ ]
+ if filtered_covs:
+ formula_parts.extend(filtered_covs)
+
+ formula = f"{outcome} ~ {' + '.join(formula_parts)}"
+ parameters["estimation_method"] = "Statsmodels OLS with TWFE (C() Notation)"
+ method_details_str = "DiD via Statsmodels TWFE (C() Notation)"
+
+ try:
+ logger.info(f"Using formula: {formula}")
+ logger.debug(f"Data head for statsmodels:\n{df_processed.head().to_string()}")
+ logger.debug(f"Regression DataFrame shape: {df_processed.shape}, Columns: {df_processed.columns.tolist()}")
+
+ ols_model = smf.ols(formula=formula, data=df_processed)
+ if group_var not in df_processed.columns:
+ # This check is mainly for clustering but good to ensure group_var exists.
+ # For 2x2, group_var (unit ID) might not be in formula but needed for clustering.
+ raise ValueError(f"Clustering variable '{group_var}' (panel unit ID) not found in regression data.")
+ logger.debug(f"Clustering standard errors by: {group_var}")
+ results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_processed[group_var]})
+
+ logger.info("Statsmodels estimation complete.")
+ logger.info(f"Statsmodels Results Summary:\n{results.summary()}")
+
+ logger.debug(f"Extracting results using interaction term key: {interaction_term_key_for_results}")
+
+ parameters["final_formula"] = formula
+ parameters["interaction_term_coefficient_name"] = interaction_term_key_for_results
+
+ formatted_results = format_did_results(results, interaction_term_key_for_results,
+ did_diagnostics,
+ method_details=method_details_str,
+ parameters=parameters)
+ formatted_results["estimator"] = "statsmodels"
+
+ except Exception as e:
+ logger.error(f"Statsmodels OLS estimation failed: {e}", exc_info=True)
+ raise ValueError(f"DiD estimation failed (both DoWhy and Statsmodels): {e}")
+
+
+
+
+ # --- Add Interpretation --- (Now add interpretation to the formatted results)
+ try:
+ # Use the llm_instance fetched earlier
+ interpretation = interpret_did_results(formatted_results, did_diagnostics, dataset_description, llm=llm_instance)
+ formatted_results['interpretation'] = interpretation
+ except Exception as interp_e:
+ logger.error(f"DiD Interpretation failed: {interp_e}")
+ formatted_results['interpretation'] = "Interpretation failed."
+
+ return formatted_results
\ No newline at end of file
diff --git a/auto_causal/methods/difference_in_differences/llm_assist.py b/auto_causal/methods/difference_in_differences/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef3b660e052fb031080dc1783663e982b3d1b5a
--- /dev/null
+++ b/auto_causal/methods/difference_in_differences/llm_assist.py
@@ -0,0 +1,362 @@
+"""LLM Assist functions for Difference-in-Differences method."""
+
+import pandas as pd
+import numpy as np
+from typing import Optional, Any, Dict, Union
+import logging
+from pydantic import BaseModel, Field, ValidationError
+from langchain_core.messages import HumanMessage
+from langchain_core.exceptions import OutputParserException
+
+# Import shared types if needed
+from langchain_core.language_models import BaseChatModel
+
+# Import shared LLM helpers
+from auto_causal.utils.llm_helpers import call_llm_with_json_output
+
+logger = logging.getLogger(__name__)
+
+# Placeholder LLM/Helper Functions
+
+# --- Pydantic model for LLM time variable extraction ---
+class LLMTimeVar(BaseModel):
+ time_variable_name: Optional[str] = Field(None, description="The column name identified as the primary time variable.")
+
+
+def identify_time_variable(df: pd.DataFrame,
+ query: Optional[str] = None,
+ dataset_description: Optional[str] = None,
+ llm: Optional[BaseChatModel] = None) -> Optional[str]:
+ '''Identifies the most likely time variable.
+
+ Current Implementation: Heuristic based on column names, with LLM fallback.
+ Future: Refine LLM prompt and parsing.
+ '''
+ # 1. Heuristic based on common time-related keywords
+ time_patterns = ['time', 'year', 'date', 'period', 'month', 'day']
+ columns = df.columns.tolist()
+ for col in columns:
+ if any(pattern in col.lower() for pattern in time_patterns):
+ logger.info(f"Identified '{col}' as time variable (heuristic).")
+ return col
+
+ # 2. LLM Fallback if heuristic fails and LLM is provided
+ if llm and query:
+ logger.warning("Heuristic failed for time variable. Trying LLM fallback...")
+ # --- Example: Add dataset description context ---
+ context_str = ""
+ if dataset_description:
+ # col_types = dataset_description.get('column_types', {}) # Description is now a string
+ context_str += f"\nDataset Description: {dataset_description}"
+ # Add other relevant info like sample values if available
+ # ------------------------------------------------
+ prompt = f"""Given the user query and the available data columns, identify the single most likely column representing the primary time dimension (e.g., year, date, period).
+
+User Query: "{query}"
+Available Columns: {columns}{context_str}
+
+Respond ONLY with a JSON object containing the identified column name using the key 'time_variable_name'. If no suitable time variable is found, return null for the value.
+Example: {{"time_variable_name": "Year"}} or {{"time_variable_name": null}}"""
+
+ messages = [HumanMessage(content=prompt)]
+ structured_llm = llm.with_structured_output(LLMTimeVar)
+
+ try:
+ parsed_result = structured_llm.invoke(messages)
+ llm_identified_col = parsed_result.time_variable_name
+
+ if llm_identified_col and llm_identified_col in columns:
+ logger.info(f"Identified '{llm_identified_col}' as time variable (LLM fallback).")
+ return llm_identified_col
+ elif llm_identified_col:
+ logger.warning(f"LLM fallback identified '{llm_identified_col}' but it's not in the columns. Ignoring.")
+ else:
+ logger.info("LLM fallback did not identify a time variable.")
+
+ except (OutputParserException, ValidationError) as e:
+ logger.error(f"LLM fallback for time variable failed parsing/validation: {e}")
+ except Exception as e:
+ logger.error(f"LLM fallback for time variable failed unexpectedly: {e}", exc_info=True)
+
+ logger.warning("Could not identify time variable using heuristics or LLM fallback.")
+ return None
+
+# --- Pydantic model for LLM treatment period extraction ---
+class LLMTreatmentPeriod(BaseModel):
+ treatment_start_period: Optional[Union[str, int, float]] = Field(None, description="The time period value (as string) when treatment is believed to start based on the query.")
+
+def determine_treatment_period(df: pd.DataFrame, time_var: str, treatment: str,
+ query: Optional[str] = None,
+ dataset_description: Optional[str] = None,
+ llm: Optional[BaseChatModel] = None) -> Any:
+ '''Determines the period when treatment starts.
+
+ Tries LLM first if available, then falls back to heuristic.
+ '''
+ if time_var not in df.columns:
+ raise ValueError(f"Time variable '{time_var}' not found in DataFrame.")
+
+ unique_times_sorted = np.sort(df[time_var].dropna().unique())
+ if len(unique_times_sorted) < 2:
+ raise ValueError("Need at least two time periods for DiD")
+
+ # --- Try LLM First (if available) ---
+ llm_period = None
+ if llm and query:
+ logger.info("Attempting LLM call to determine treatment period start...")
+ # Provide sorted unique times for context
+ times_str = ", ".join(map(str, unique_times_sorted)) if len(unique_times_sorted) < 20 else f"{unique_times_sorted[0]}...{unique_times_sorted[-1]}"
+ # --- Example: Add dataset description context ---
+ context_str = ""
+ if dataset_description:
+ # Example: Show summary stats for time var if helpful
+ # time_stats = dataset_description.get('summary_stats', {}).get(time_var) # Cannot get from string
+ context_str += f"\nDataset Description: {dataset_description}"
+ # ------------------------------------------------
+ prompt = f"""Based on the user query and the observed time periods, determine the specific period value when the treatment ('{treatment}') likely started.
+
+User Query: "{query}"
+Time Variable Name: '{time_var}'
+Observed Time Periods (sorted): [{times_str}]{context_str}
+
+Respond ONLY with a JSON object containing the identified start period using the key 'treatment_start_period'. The value should be one of the observed periods if possible. If the query doesn't specify a start period, return null.
+Example: {{"treatment_start_period": 2015}} or {{"treatment_start_period": null}}"""
+
+ messages = [HumanMessage(content=prompt)]
+ structured_llm = llm.with_structured_output(LLMTreatmentPeriod)
+
+ try:
+ parsed_result = structured_llm.invoke(messages)
+ potential_period = parsed_result.treatment_start_period
+
+ # Validate if the period exists in the data (might need type conversion)
+ if potential_period is not None:
+ # Try converting LLM output type to match data type if needed
+ try:
+ series_dtype = df[time_var].dtype
+ converted_period = pd.Series([potential_period]).astype(series_dtype).iloc[0]
+ except Exception:
+ converted_period = potential_period # Use raw if conversion fails
+
+ if converted_period in unique_times_sorted:
+ llm_period = converted_period
+ logger.info(f"LLM identified treatment period start: {llm_period}")
+ else:
+ logger.warning(f"LLM identified period '{potential_period}' (converted: '{converted_period}'), but it's not in the observed time periods. Ignoring LLM result.")
+ else:
+ logger.info("LLM did not identify a specific treatment start period from the query.")
+
+ except (OutputParserException, ValidationError) as e:
+ logger.error(f"LLM fallback for treatment period failed parsing/validation: {e}")
+ except Exception as e:
+ logger.error(f"LLM fallback for treatment period failed unexpectedly: {e}", exc_info=True)
+
+ if llm_period is not None:
+ return llm_period
+
+ # --- Fallback to Heuristic ---
+ logger.warning("Using heuristic (median time) to determine treatment period start.")
+ treatment_period_start = None
+ try:
+ if pd.api.types.is_numeric_dtype(df[time_var]):
+ median_time = np.median(unique_times_sorted)
+ possible_starts = unique_times_sorted[unique_times_sorted > median_time]
+ if len(possible_starts) > 0:
+ treatment_period_start = possible_starts[0]
+ else:
+ treatment_period_start = unique_times_sorted[-1]
+ logger.warning(f"Could not determine treatment start > median time. Defaulting to last period: {treatment_period_start}")
+ else: # Assume sortable categories or dates
+ median_idx = len(unique_times_sorted) // 2
+ if median_idx < len(unique_times_sorted):
+ treatment_period_start = unique_times_sorted[median_idx]
+ else:
+ treatment_period_start = unique_times_sorted[0]
+
+ if treatment_period_start is not None:
+ logger.info(f"Determined treatment period start: {treatment_period_start} (heuristic: median time).")
+ return treatment_period_start
+ else:
+ raise ValueError("Could not determine treatment start period using heuristic.")
+
+ except Exception as e:
+ logger.error(f"Error in heuristic for treatment period: {e}")
+ raise ValueError(f"Could not determine treatment start period using heuristic: {e}")
+
+# --- Pydantic model for LLM group variable extraction ---
+class LLMGroupVar(BaseModel):
+ group_variable_name: Optional[str] = Field(None, description="The column name identifying the panel unit (e.g., state, individual, firm).")
+
+def identify_treatment_group(df: pd.DataFrame, treatment_var: str,
+ query: Optional[str] = None,
+ dataset_description: Optional[str] = None,
+ llm: Optional[BaseChatModel] = None) -> Optional[str]:
+ '''Identifies the variable indicating the treated group/unit ID.
+
+ Tries heuristic check for non-binary treatment_var first, then LLM,
+ then falls back to assuming treatment_var is the group/unit identifier.
+ '''
+ columns = df.columns.tolist()
+ if treatment_var not in columns:
+ logger.error(f"Treatment variable '{treatment_var}' provided to identify_treatment_group not found in DataFrame.")
+ # Fallback: Look for common ID names if specified treatment is missing
+ id_keywords = ['id', 'unit', 'group', 'entity', 'state', 'firm']
+ for col in columns:
+ if any(keyword in col.lower() for keyword in id_keywords):
+ logger.warning(f"Specified treatment '{treatment_var}' not found. Falling back to potential ID column '{col}' as group identifier.")
+ return col
+ return None # Give up if no likely ID column found
+
+ # --- Heuristic: Check if treatment_var is non-binary, if so, look for ID columns ---
+ is_potentially_binary = False
+ if pd.api.types.is_numeric_dtype(df[treatment_var]):
+ unique_vals = set(df[treatment_var].dropna().unique())
+ if unique_vals.issubset({0, 1}):
+ is_potentially_binary = True
+
+ if not is_potentially_binary:
+ logger.info(f"Provided treatment variable '{treatment_var}' is not binary (0/1). Searching for a separate group/unit ID column heuristically.")
+ id_keywords = ['id', 'unit', 'group', 'entity', 'state', 'firm']
+ # Prioritize 'group' or 'unit' if available
+ for keyword in ['group', 'unit']:
+ for col in columns:
+ if keyword == col.lower():
+ logger.info(f"Heuristically identified '{col}' as group/unit ID (treatment '{treatment_var}' was non-binary)." )
+ return col
+ # Then check other keywords
+ for col in columns:
+ if col != treatment_var and any(keyword in col.lower() for keyword in id_keywords):
+ logger.info(f"Heuristically identified '{col}' as group/unit ID (treatment '{treatment_var}' was non-binary)." )
+ return col
+ logger.warning("Heuristic search for group/unit ID failed when treatment was non-binary.")
+
+ # --- LLM Attempt (if heuristic didn't find an alternative or wasn't needed) ---
+ # Useful if query context helps disambiguate (e.g., "effect across states")
+ if llm and query:
+ logger.info("Attempting LLM call to identify group/unit variable...")
+ # --- Example: Add dataset description context ---
+ context_str = ""
+ if dataset_description:
+ # col_types = dataset_description.get('column_types', {}) # Description is now a string
+ context_str += f"\nDataset Description: {dataset_description}"
+ # ------------------------------------------------
+ prompt = f"""Given the user query and data columns, identify the single column that most likely represents the unique identifier for the panel units (e.g., state, individual, firm, unit ID), distinct from the treatment status indicator ('{treatment_var}').
+
+User Query: "{query}"
+Treatment Variable Mentioned: '{treatment_var}'
+Available Columns: {columns}{context_str}
+
+Respond ONLY with a JSON object containing the identified unit identifier column name using the key 'group_variable_name'. If the best identifier seems to be the treatment variable itself or none is suitable, return null.
+Example: {{"group_variable_name": "state_id"}} or {{"group_variable_name": null}}"""
+
+ messages = [HumanMessage(content=prompt)]
+ structured_llm = llm.with_structured_output(LLMGroupVar)
+
+ try:
+ parsed_result = structured_llm.invoke(messages)
+ llm_identified_col = parsed_result.group_variable_name
+
+ if llm_identified_col and llm_identified_col in columns:
+ logger.info(f"Identified '{llm_identified_col}' as group/unit variable (LLM).")
+ return llm_identified_col
+ elif llm_identified_col:
+ logger.warning(f"LLM identified '{llm_identified_col}' but it's not in the columns. Ignoring.")
+ else:
+ logger.info("LLM did not identify a separate group/unit variable.")
+
+ except (OutputParserException, ValidationError) as e:
+ logger.error(f"LLM call for group/unit variable failed parsing/validation: {e}")
+ except Exception as e:
+ logger.error(f"LLM call for group/unit variable failed unexpectedly: {e}", exc_info=True)
+
+ # --- Final Fallback ---
+ logger.info(f"Defaulting to using provided treatment variable '{treatment_var}' as the group/unit identifier.")
+ return treatment_var
+
+# --- Add interpret_did_results function ---
+
+def interpret_did_results(
+ results: Dict[str, Any],
+ diagnostics: Optional[Dict[str, Any]],
+ dataset_description: Optional[str] = None,
+ llm: Optional[BaseChatModel] = None
+) -> str:
+ """Use LLM to interpret Difference-in-Differences results."""
+ default_interpretation = "LLM interpretation not available for DiD."
+ if llm is None:
+ logger.info("LLM not provided for DiD interpretation.")
+ return default_interpretation
+
+ try:
+ # --- Prepare summary for LLM ---
+ results_summary = {}
+ params = results.get('parameters', {})
+ diag_details = diagnostics.get('details', {}) if diagnostics else {}
+ parallel_trends = diag_details.get('parallel_trends', {})
+
+ effect = results.get('effect_estimate')
+ pval = results.get('p_value')
+ ci = results.get('confidence_interval')
+
+ results_summary['Method Used'] = results.get('method_details', 'Difference-in-Differences')
+ results_summary['Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
+ results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
+ if isinstance(ci, (list, tuple)) and len(ci) == 2:
+ results_summary['Confidence Interval'] = f"[{ci[0]:.3f}, {ci[1]:.3f}]"
+ else:
+ results_summary['Confidence Interval'] = str(ci) if ci is not None else "N/A"
+
+ results_summary['Time Variable'] = params.get('time_var', 'N/A')
+ results_summary['Group/Unit Variable'] = params.get('group_var', 'N/A')
+ results_summary['Treatment Indicator Used'] = params.get('treatment_indicator', 'N/A')
+ results_summary['Treatment Start Period'] = params.get('treatment_period_start', 'N/A')
+ results_summary['Covariates Included'] = params.get('covariates', [])
+
+ diag_summary = {}
+ diag_summary['Parallel Trends Assumption Status'] = "Passed (Placeholder)" if parallel_trends.get('valid', False) else "Failed/Unknown (Placeholder)"
+ if not parallel_trends.get('valid', False) and parallel_trends.get('details') != "Placeholder validation":
+ diag_summary['Parallel Trends Details'] = parallel_trends.get('details', 'N/A')
+
+ # --- Example: Add dataset description context ---
+ context_str = ""
+ if dataset_description:
+ # context_str += f"\nDataset Context: {dataset_description.get('summary', 'N/A')}" # Use string directly
+ context_str += f"\n\nDataset Context Provided:\n{dataset_description}"
+ # ------------------------------------------------
+
+ # --- Construct Prompt ---
+ prompt = f"""
+ You are assisting with interpreting Difference-in-Differences (DiD) results.
+ {context_str} # Add context here
+
+ Estimation Results Summary:
+ {results_summary}
+
+ Diagnostics Summary:
+ {diag_summary}
+
+ Explain these DiD results in 2-4 concise sentences. Focus on:
+ 1. The estimated average treatment effect on the treated (magnitude, direction, statistical significance based on p-value < 0.05).
+ 2. The status of the parallel trends assumption (mentioning it's a key assumption for DiD).
+ 3. Note that the estimation controlled for unit and time fixed effects, and potentially covariates {results_summary['Covariates Included']}
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "interpretation": ""
+ }}
+ """
+
+ # --- Call LLM ---
+ response = call_llm_with_json_output(llm, prompt)
+
+ # --- Process Response ---
+ if response and isinstance(response, dict) and \
+ "interpretation" in response and isinstance(response["interpretation"], str):
+ return response["interpretation"]
+ else:
+ logger.warning(f"Failed to get valid interpretation from LLM for DiD. Response: {response}")
+ return default_interpretation
+
+ except Exception as e:
+ logger.error(f"Error during LLM interpretation for DiD: {e}")
+ return f"Error generating interpretation: {e}"
\ No newline at end of file
diff --git a/auto_causal/methods/difference_in_differences/utils.py b/auto_causal/methods/difference_in_differences/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac942e059b21c099045836500fb956ff693e40d9
--- /dev/null
+++ b/auto_causal/methods/difference_in_differences/utils.py
@@ -0,0 +1,65 @@
+# Utility functions for Difference-in-Differences
+import pandas as pd
+import logging
+
+logger = logging.getLogger(__name__)
+
+def create_post_indicator(df: pd.DataFrame, time_var: str, treatment_period_start: any) -> pd.Series:
+ """Creates the post-treatment indicator variable.
+ Checks if time_var is already a 0/1 indicator; otherwise, compares to treatment_period_start.
+ """
+ try:
+ time_var_series = df[time_var]
+ # Ensure numeric for checks and direct comparison
+ if pd.api.types.is_bool_dtype(time_var_series):
+ time_var_series = time_var_series.astype(int)
+
+ # Check if it's already a binary 0/1 indicator
+ if pd.api.types.is_numeric_dtype(time_var_series):
+ unique_vals = set(time_var_series.dropna().unique())
+ if unique_vals == {0, 1}:
+ logger.info(f"Time variable '{time_var}' is already a binary 0/1 indicator. Using it directly as post indicator.")
+ return time_var_series.astype(int)
+ else:
+ # Numeric, but not 0/1, so compare with treatment_period_start
+ logger.info(f"Time variable '{time_var}' is numeric. Comparing with treatment_period_start: {treatment_period_start}")
+ return (time_var_series >= treatment_period_start).astype(int)
+ else:
+ # Non-numeric and not boolean, will likely fall into TypeError for datetime conversion
+ # This else block might not be strictly necessary if TypeError is caught below
+ # but added for logical completeness before attempting datetime conversion.
+ pass # Let it fall through to TypeError if not numeric here
+
+ # If we reached here, it means it wasn't numeric or bool, try direct comparison which will likely raise TypeError
+ # and be caught by the except block for datetime conversion if applicable.
+ # This line is kept to ensure non-numeric non-datetime-like strings also trigger the except.
+ return (df[time_var] >= treatment_period_start).astype(int)
+
+ except TypeError:
+ # If direct comparison fails (e.g., comparing datetime with int/str, or non-numeric string with number),
+ # attempt to convert both to datetime objects for comparison.
+ logger.info(f"Direct comparison/numeric check failed for time_var '{time_var}'. Attempting datetime conversion.")
+ try:
+ time_series_dt = pd.to_datetime(df[time_var], errors='coerce')
+ # Try to convert treatment_period_start to datetime if it's not already
+ # This handles cases where treatment_period_start might be a date string
+ try:
+ treatment_start_dt = pd.to_datetime(treatment_period_start)
+ except Exception as e_conv:
+ logger.error(f"Could not convert treatment_period_start '{treatment_period_start}' to datetime: {e_conv}")
+ raise TypeError(f"treatment_period_start '{treatment_period_start}' could not be converted to a comparable datetime format.")
+
+ if time_series_dt.isna().all(): # if all values are NaT after conversion
+ raise ValueError(f"Time variable '{time_var}' could not be converted to datetime (all values NaT).")
+ if pd.isna(treatment_start_dt):
+ raise ValueError(f"Treatment start period '{treatment_period_start}' converted to NaT.")
+
+ logger.info(f"Comparing time_var '{time_var}' (as datetime) with treatment_start_dt '{treatment_start_dt}' (as datetime).")
+ return (time_series_dt >= treatment_start_dt).astype(int)
+ except Exception as e:
+ logger.error(f"Failed to compare time variable '{time_var}' with treatment start '{treatment_period_start}' using datetime logic: {e}", exc_info=True)
+ raise TypeError(f"Could not compare time variable '{time_var}' with treatment start '{treatment_period_start}'. Ensure they are comparable or convertible to datetime. Error: {e}")
+ except Exception as ex:
+ # Catch any other unexpected errors during the initial numeric processing
+ logger.error(f"Unexpected error processing time_var '{time_var}' for post indicator: {ex}", exc_info=True)
+ raise TypeError(f"Unexpected error processing time_var '{time_var}': {ex}")
\ No newline at end of file
diff --git a/auto_causal/methods/generalized_propensity_score/__init__.py b/auto_causal/methods/generalized_propensity_score/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f801df677ba38be8c36a1531067b06ddfd7af2c
--- /dev/null
+++ b/auto_causal/methods/generalized_propensity_score/__init__.py
@@ -0,0 +1,3 @@
+"""
+Generalized Propensity Score (GPS) method for continuous treatments.
+"""
\ No newline at end of file
diff --git a/auto_causal/methods/generalized_propensity_score/diagnostics.py b/auto_causal/methods/generalized_propensity_score/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b1d8717aa3113af3f879e6af734c7beacb34eb
--- /dev/null
+++ b/auto_causal/methods/generalized_propensity_score/diagnostics.py
@@ -0,0 +1,196 @@
+"""
+Diagnostic checks for the Generalized Propensity Score (GPS) method.
+"""
+from typing import Dict, List, Any
+import pandas as pd
+import logging
+import numpy as np
+import statsmodels.api as sm
+
+logger = logging.getLogger(__name__)
+
+def assess_gps_balance(
+ df_with_gps: pd.DataFrame,
+ treatment_var: str,
+ covariate_vars: List[str],
+ gps_col_name: str,
+ **kwargs: Any
+) -> Dict[str, Any]:
+ """
+ Assesses the balance of covariates conditional on the estimated GPS.
+
+ This function is typically called after GPS estimation to validate the
+ assumption that covariates are independent of treatment conditional on GPS.
+
+ Args:
+ df_with_gps: DataFrame containing the original data plus the estimated GPS column.
+ treatment_var: The name of the continuous treatment variable column.
+ covariate_vars: A list of covariate column names to check for balance.
+ gps_col_name: The name of the column containing the estimated GPS values.
+ **kwargs: Additional arguments (e.g., number of strata for checking balance).
+
+ Returns:
+ A dictionary containing balance statistics and summaries. For example:
+ {
+ "overall_balance_metric": 0.05,
+ "covariate_balance": {
+ "cov1": {"statistic": 0.03, "p_value": 0.5, "balanced": True},
+ "cov2": {"statistic": 0.12, "p_value": 0.02, "balanced": False}
+ },
+ "summary": "Balance assessment complete."
+ }
+ """
+ logger.info(f"Assessing GPS balance for covariates: {covariate_vars}")
+
+ # Default to 5 strata (quintiles) if not specified
+ num_strata = kwargs.get('num_strata', 5)
+ if not isinstance(num_strata, int) or num_strata <= 1:
+ logger.warning(f"Invalid num_strata ({num_strata}), defaulting to 5.")
+ num_strata = 5
+
+ balance_results = {}
+ overall_summary = {
+ "num_strata_used": num_strata,
+ "covariates_tested": len(covariate_vars),
+ "warnings": [],
+ "all_strata_coefficients": {cov: [] for cov in covariate_vars},
+ "all_strata_p_values": {cov: [] for cov in covariate_vars}
+ }
+
+ if df_with_gps[gps_col_name].isnull().all():
+ logger.error(f"All GPS scores in column '{gps_col_name}' are NaN. Cannot perform balance assessment.")
+ overall_summary["error"] = "All GPS scores are NaN."
+ return {
+ "error": "All GPS scores are NaN.",
+ "summary": "Balance assessment failed."
+ }
+
+ try:
+ # Create GPS strata (e.g., quintiles)
+ # Ensure unique bin edges for qcut, duplicates='drop' will handle cases with sparse GPS values
+ # but might result in fewer than num_strata if GPS distribution is highly skewed or has few unique values.
+ try:
+ df_with_gps['gps_stratum'] = pd.qcut(df_with_gps[gps_col_name], num_strata, labels=False, duplicates='drop')
+ actual_num_strata = df_with_gps['gps_stratum'].nunique()
+ if actual_num_strata < num_strata and actual_num_strata > 0:
+ logger.warning(f"Requested {num_strata} strata, but due to GPS distribution, only {actual_num_strata} could be formed.")
+ overall_summary["warnings"].append(f"Only {actual_num_strata} strata formed out of {num_strata} requested.")
+ overall_summary["actual_num_strata_formed"] = actual_num_strata
+ except ValueError as ve:
+ logger.error(f"Could not create strata using pd.qcut due to: {ve}. This might happen if GPS has too few unique values.")
+ logger.info("Attempting to use unique GPS values as strata if count is low.")
+ unique_gps_count = df_with_gps[gps_col_name].nunique()
+ if unique_gps_count <= num_strata * 2 and unique_gps_count > 1: # Arbitrary threshold to try unique values as strata
+ strata_map = {val: i for i, val in enumerate(df_with_gps[gps_col_name].unique())}
+ df_with_gps['gps_stratum'] = df_with_gps[gps_col_name].map(strata_map)
+ actual_num_strata = df_with_gps['gps_stratum'].nunique()
+ overall_summary["actual_num_strata_formed"] = actual_num_strata
+ overall_summary["warnings"].append(f"Used {actual_num_strata} unique GPS values as strata due to qcut error.")
+ else:
+ overall_summary["error"] = f"Failed to create GPS strata: {ve}. GPS may have too few unique values."
+ return {
+ "error": overall_summary["error"],
+ "summary": "Balance assessment failed due to strata creation issues."
+ }
+
+ if df_with_gps['gps_stratum'].isnull().all():
+ logger.error("Stratum assignment resulted in all NaNs.")
+ overall_summary["error"] = "Stratum assignment resulted in all NaNs."
+ return {"error": overall_summary["error"], "summary": "Balance assessment failed."}
+
+
+ for cov in covariate_vars:
+ balance_results[cov] = {
+ "strata_details": [],
+ "mean_abs_coefficient": None,
+ "num_significant_strata_p005": 0,
+ "balanced_heuristic": True # Assume balanced until proven otherwise
+ }
+ coeffs_for_cov = []
+ p_values_for_cov = []
+
+ for stratum_idx in sorted(df_with_gps['gps_stratum'].dropna().unique()):
+ stratum_data = df_with_gps[df_with_gps['gps_stratum'] == stratum_idx]
+ stratum_detail = {"stratum_index": int(stratum_idx), "n_obs": len(stratum_data)}
+
+ if len(stratum_data) < 10: # Need a minimum number of observations for stable regression
+ stratum_detail["status"] = "Skipped (too few observations)"
+ stratum_detail["coefficient_on_treatment"] = np.nan
+ stratum_detail["p_value_on_treatment"] = np.nan
+ balance_results[cov]["strata_details"].append(stratum_detail)
+ continue
+
+ # Ensure covariate and treatment have variance within the stratum
+ if stratum_data[cov].nunique() < 2 or stratum_data[treatment_var].nunique() < 2:
+ stratum_detail["status"] = "Skipped (no variance in cov or treatment)"
+ stratum_detail["coefficient_on_treatment"] = np.nan
+ stratum_detail["p_value_on_treatment"] = np.nan
+ balance_results[cov]["strata_details"].append(stratum_detail)
+ continue
+
+ try:
+ X_balance = sm.add_constant(stratum_data[[treatment_var]])
+ y_balance = stratum_data[cov]
+
+ # Drop NaNs for this specific regression within stratum
+ temp_df = pd.concat([y_balance, X_balance], axis=1).dropna()
+ if len(temp_df) < X_balance.shape[1] +1: # Check for enough data points after NaNs for regression
+ stratum_detail["status"] = "Skipped (too few non-NaN obs for regression)"
+ stratum_detail["coefficient_on_treatment"] = np.nan
+ stratum_detail["p_value_on_treatment"] = np.nan
+ balance_results[cov]["strata_details"].append(stratum_detail)
+ continue
+
+ y_balance_fit = temp_df[cov]
+ X_balance_fit = temp_df[[col for col in temp_df.columns if col != cov]]
+
+ balance_model = sm.OLS(y_balance_fit, X_balance_fit).fit()
+ coeff = balance_model.params.get(treatment_var, np.nan)
+ p_value = balance_model.pvalues.get(treatment_var, np.nan)
+
+ coeffs_for_cov.append(coeff)
+ p_values_for_cov.append(p_value)
+ overall_summary["all_strata_coefficients"][cov].append(coeff)
+ overall_summary["all_strata_p_values"][cov].append(p_value)
+
+ stratum_detail["status"] = "Analyzed"
+ stratum_detail["coefficient_on_treatment"] = coeff
+ stratum_detail["p_value_on_treatment"] = p_value
+ if not pd.isna(p_value) and p_value < 0.05:
+ balance_results[cov]["num_significant_strata_p005"] += 1
+ balance_results[cov]["balanced_heuristic"] = False # If any stratum is unbalanced
+
+ except Exception as e_bal:
+ logger.debug(f"Balance check regression failed for {cov} in stratum {stratum_idx}: {e_bal}")
+ stratum_detail["status"] = f"Error: {str(e_bal)}"
+ stratum_detail["coefficient_on_treatment"] = np.nan
+ stratum_detail["p_value_on_treatment"] = np.nan
+
+ balance_results[cov]["strata_details"].append(stratum_detail)
+
+ if coeffs_for_cov:
+ balance_results[cov]["mean_abs_coefficient"] = np.nanmean(np.abs(coeffs_for_cov))
+ else:
+ balance_results[cov]["mean_abs_coefficient"] = np.nan # No strata were analyzable
+
+ overall_summary["num_covariates_potentially_imbalanced_p005"] = sum(
+ 1 for cov_data in balance_results.values() if not cov_data["balanced_heuristic"]
+ )
+
+ except Exception as e:
+ logger.error(f"Error during GPS balance assessment: {e}", exc_info=True)
+ overall_summary["error"] = f"Overall assessment error: {str(e)}"
+ return {
+ "error": str(e),
+ "balance_results": balance_results,
+ "summary_stats": overall_summary,
+ "summary": "Balance assessment failed due to an unexpected error."
+ }
+
+ logger.info("GPS balance assessment complete.")
+
+ return {
+ "balance_results_per_covariate": balance_results,
+ "summary_stats": overall_summary,
+ "summary": "GPS balance assessment finished. Review strata details and mean absolute coefficients."
+ }
\ No newline at end of file
diff --git a/auto_causal/methods/generalized_propensity_score/estimator.py b/auto_causal/methods/generalized_propensity_score/estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..311c5642b6d0919c61677ef898dbb536d56f1c01
--- /dev/null
+++ b/auto_causal/methods/generalized_propensity_score/estimator.py
@@ -0,0 +1,386 @@
+"""
+Core estimation logic for the Generalized Propensity Score (GPS) method.
+"""
+from typing import Dict, List, Any
+import pandas as pd
+import logging
+import numpy as np
+import statsmodels.api as sm
+
+from .diagnostics import assess_gps_balance # Import for balance check
+
+logger = logging.getLogger(__name__)
+
+def estimate_effect_gps(
+ df: pd.DataFrame,
+ treatment: str,
+ outcome: str,
+ covariates: List[str],
+ **kwargs: Any
+) -> Dict[str, Any]:
+ """
+ Estimates the causal effect using the Generalized Propensity Score method
+ for continuous treatments.
+
+ This function will be called by the method_executor_tool.
+
+ Args:
+ df: The input DataFrame.
+ treatment: The name of the continuous treatment variable column.
+ outcome: The name of the outcome variable column.
+ covariates: A list of covariate column names.
+ **kwargs: Additional arguments for controlling the estimation, including:
+ - gps_model_spec (dict): Specification for the GPS model (T ~ X).
+ - outcome_model_spec (dict): Specification for the outcome model (Y ~ T, GPS).
+ - t_values_range (list or dict): Specification for treatment levels for ADRF.
+ - n_bootstraps (int): Number of bootstrap replications for SEs.
+
+ Returns:
+ A dictionary containing the estimation results, including:
+ - "effect_estimate": Typically the ADRF or a specific contrast.
+ - "standard_error": Standard error for the primary effect estimate.
+ - "confidence_interval": Confidence interval for the primary estimate.
+ - "adrf_curve": Data representing the Average Dose-Response Function.
+ - "specific_contrasts": Any calculated specific contrasts.
+ - "diagnostics": Results from diagnostic checks (e.g., balance).
+ - "method_details": Description of the method and models used.
+ - "parameters_used": Dictionary of parameters used.
+ """
+ logger.info(f"Starting GPS estimation for treatment '{treatment}', outcome '{outcome}'.")
+
+ # --- Parameter Extraction and Defaults ---
+ gps_model_spec = kwargs.get('gps_model_spec', {"type": "linear"})
+ outcome_model_spec = kwargs.get('outcome_model_spec', {"type": "polynomial", "degree": 2, "interaction": True})
+
+ # Get t_values for ADRF from llm_assist or kwargs, default to 10 points over observed range
+ # For simplicity, we'll use a simple range here. In a full impl, this might call llm_assist.
+ t_values_for_adrf = kwargs.get('t_values_for_adrf')
+ if t_values_for_adrf is None:
+ min_t_obs = df[treatment].min()
+ max_t_obs = df[treatment].max()
+ if pd.isna(min_t_obs) or pd.isna(max_t_obs) or min_t_obs == max_t_obs:
+ logger.warning(f"Cannot determine a valid range for treatment '{treatment}' for ADRF. Using limited points.")
+ t_values_for_adrf = sorted(list(df[treatment].dropna().unique()))[:10] # Fallback
+ else:
+ t_values_for_adrf = np.linspace(min_t_obs, max_t_obs, 10).tolist()
+
+ n_bootstraps = kwargs.get('n_bootstraps', 0) # Default to 0, meaning no bootstrap for now
+
+ logger.info(f"Using GPS model spec: {gps_model_spec}")
+ logger.info(f"Using outcome model spec: {outcome_model_spec}")
+ logger.info(f"Evaluating ADRF at t-values: {t_values_for_adrf}")
+
+ try:
+ # 2. Estimate GPS Values
+ df_with_gps, gps_estimation_diagnostics = _estimate_gps_values(
+ df.copy(), treatment, covariates, gps_model_spec
+ )
+ if 'gps_score' not in df_with_gps.columns or df_with_gps['gps_score'].isnull().all():
+ logger.error("GPS estimation failed or resulted in all NaNs.")
+ return {
+ "error": "GPS estimation failed.",
+ "diagnostics": gps_estimation_diagnostics,
+ "method_details": "GPS (Failed)",
+ "parameters_used": kwargs
+ }
+
+ # Drop rows where GPS or outcome or necessary modeling variables are NaN before proceeding
+ modeling_cols = [outcome, treatment, 'gps_score'] + covariates
+ df_with_gps.dropna(subset=modeling_cols, inplace=True)
+ if df_with_gps.empty:
+ logger.error("DataFrame is empty after GPS estimation and NaN removal.")
+ return {"error": "No data available after GPS estimation and NaN removal.", "method_details": "GPS (Failed)", "parameters_used": kwargs}
+
+
+ # 3. Assess GPS Balance (call diagnostics.assess_gps_balance)
+ balance_diagnostics = assess_gps_balance(
+ df_with_gps, treatment, covariates, 'gps_score' # kwargs for assess_gps_balance can be passed if needed
+ )
+
+ # 4. Estimate Outcome Model
+ fitted_outcome_model = _estimate_outcome_model(
+ df_with_gps, outcome, treatment, 'gps_score', outcome_model_spec
+ )
+
+ # 5. Generate Dose-Response Function
+ adrf_results = _generate_dose_response_function(
+ df_with_gps, fitted_outcome_model, treatment, 'gps_score', outcome_model_spec, t_values_for_adrf
+ )
+ adrf_curve_data = {"t_levels": t_values_for_adrf, "expected_outcomes": adrf_results}
+
+ # 6. Calculate specific contrasts if requested (Placeholder)
+ specific_contrasts = {"info": "Specific contrasts not implemented in this version."}
+
+ # 7. Perform bootstrapping for SEs if requested (Placeholder for now)
+ standard_error_info = {"info": "Bootstrap SEs not implemented in this version."}
+ confidence_interval_info = {"info": "Bootstrap CIs not implemented in this version."}
+ if n_bootstraps > 0:
+ logger.info(f"Bootstrapping with {n_bootstraps} replications (placeholder).")
+ # Actual bootstrapping logic would go here.
+ # For now, we'll just note that it's not implemented.
+
+ logger.info("GPS estimation steps completed.")
+
+ # Consolidate diagnostics
+ all_diagnostics = {
+ "gps_estimation_diagnostics": gps_estimation_diagnostics,
+ "balance_check": balance_diagnostics, # Now using the actual balance check results
+ "outcome_model_summary": str(fitted_outcome_model.summary()) if fitted_outcome_model else "Outcome model not fitted.",
+ "warnings": [], # Populate with any warnings during the process
+ "summary": "GPS estimation complete."
+ }
+
+ return {
+ "effect_estimate": adrf_curve_data, # The ADRF is the primary "effect"
+ "standard_error_info": standard_error_info, # Placeholder
+ "confidence_interval_info": confidence_interval_info, # Placeholder
+ "adrf_curve": adrf_curve_data,
+ "specific_contrasts": specific_contrasts, # Placeholder
+ "diagnostics": all_diagnostics,
+ "method_details": f"Generalized Propensity Score (GPS) with {gps_model_spec.get('type', 'N/A')} GPS model and {outcome_model_spec.get('type', 'N/A')} outcome model.",
+ "parameters_used": {
+ "treatment_var": treatment,
+ "outcome_var": outcome,
+ "covariate_vars": covariates,
+ "gps_model_spec": gps_model_spec,
+ "outcome_model_spec": outcome_model_spec,
+ "t_values_for_adrf": t_values_for_adrf,
+ "n_bootstraps": n_bootstraps,
+ **kwargs
+ }
+ }
+ except Exception as e:
+ logger.error(f"Error during GPS estimation pipeline: {e}", exc_info=True)
+ return {
+ "error": f"Pipeline failed: {str(e)}",
+ "method_details": "GPS (Failed)",
+ "diagnostics": {"error": f"Pipeline failed during GPS estimation: {str(e)}"}, # Add diagnostics here too
+ "parameters_used": kwargs
+ }
+
+
+# Placeholder for internal helper functions
+def _estimate_gps_values(
+ df: pd.DataFrame,
+ treatment: str,
+ covariates: List[str],
+ gps_model_spec: Dict
+) -> tuple[pd.DataFrame, Dict]:
+ """
+ Estimates Generalized Propensity Scores.
+ Assumes T | X ~ N(X*beta, sigma^2), so GPS is the conditional density.
+ """
+ logger.info(f"Estimating GPS for treatment '{treatment}' using covariates: {covariates}")
+ diagnostics = {}
+
+ if not covariates:
+ logger.error("No covariates provided for GPS estimation.")
+ diagnostics["error"] = "No covariates provided."
+ df['gps_score'] = np.nan # Ensure gps_score column is added
+ return df, diagnostics
+
+ X_df = df[covariates]
+ T_series = df[treatment]
+
+ # Handle potential NaN values in covariates or treatment before modeling
+ valid_indices = X_df.dropna().index.intersection(T_series.dropna().index)
+ if len(valid_indices) < len(df):
+ logger.warning(f"Dropped {len(df) - len(valid_indices)} rows due to NaNs in treatment/covariates before GPS estimation.")
+ diagnostics["pre_estimation_nan_rows_dropped"] = len(df) - len(valid_indices)
+
+ X = X_df.loc[valid_indices]
+ T = T_series.loc[valid_indices]
+
+ if X.empty or T.empty:
+ logger.error("Covariate or treatment data is empty after NaN handling.")
+ diagnostics["error"] = "Covariate or treatment data is empty after NaN handling."
+ return df, diagnostics
+
+ X_sm = sm.add_constant(X, has_constant='add')
+
+ try:
+ if gps_model_spec.get("type") == 'linear':
+ model = sm.OLS(T, X_sm).fit()
+ t_hat = model.predict(X_sm)
+ residuals = T - t_hat
+ # MSE: sum of squared residuals / (n - k) where k is number of regressors (including const)
+ if len(T) <= X_sm.shape[1]:
+ logger.error("Not enough degrees of freedom to estimate sigma_sq_hat.")
+ diagnostics["error"] = "Not enough degrees of freedom for GPS variance."
+ df['gps_score'] = np.nan
+ return df, diagnostics
+
+ sigma_sq_hat = np.sum(residuals**2) / (len(T) - X_sm.shape[1])
+
+ if sigma_sq_hat <= 1e-9: # Check for effectively zero or very small variance
+ logger.warning(f"Estimated residual variance (sigma_sq_hat) is very close to zero ({sigma_sq_hat}). GPS will be set to NaN.")
+ diagnostics["warning_sigma_sq_hat_near_zero"] = sigma_sq_hat
+ df['gps_score'] = np.nan # Set GPS to NaN as density is ill-defined
+ if sigma_sq_hat == 0: # if it is exactly zero, add specific error
+ diagnostics["error_sigma_sq_hat_is_zero"] = "Residual variance is exactly zero."
+ return df, diagnostics
+
+
+ # Calculate GPS: (1 / sqrt(2*pi*sigma_hat^2)) * exp(-(T_i - T_hat_i)^2 / (2*sigma_hat^2))
+ # Ensure calculation is done on the original T values (T_series.loc[valid_indices])
+ # and corresponding t_hat for those valid_indices
+ gps_values_calculated = (1 / np.sqrt(2 * np.pi * sigma_sq_hat)) * np.exp(-((T - t_hat)**2) / (2 * sigma_sq_hat))
+
+ # Assign back to the original DataFrame using .loc to ensure alignment
+ df['gps_score'] = np.nan # Initialize column
+ df.loc[valid_indices, 'gps_score'] = gps_values_calculated
+
+ diagnostics["gps_model_type"] = "linear_ols"
+ diagnostics["gps_model_rsquared"] = model.rsquared
+ diagnostics["gps_residual_variance_mse"] = sigma_sq_hat
+ diagnostics["num_observations_for_gps_model"] = len(T)
+
+ else:
+ logger.error(f"GPS model type '{gps_model_spec.get('type')}' not implemented.")
+ diagnostics["error"] = f"GPS model type '{gps_model_spec.get('type')}' not implemented."
+ df['gps_score'] = np.nan
+
+ except Exception as e:
+ logger.error(f"Error during GPS model estimation: {e}", exc_info=True)
+ diagnostics["error"] = f"Exception during GPS estimation: {str(e)}"
+ df['gps_score'] = np.nan
+
+ # Ensure the original df is not modified if no valid indices for GPS estimation
+ if 'gps_score' not in df.columns:
+ df['gps_score'] = np.nan
+
+ return df, diagnostics
+
+def _estimate_outcome_model(
+ df_with_gps: pd.DataFrame,
+ outcome: str,
+ treatment: str,
+ gps_col_name: str,
+ outcome_model_spec: Dict
+) -> Any: # Returns a fitted statsmodels model
+ """
+ Estimates the outcome model Y ~ f(T, GPS).
+ """
+ logger.info(f"Estimating outcome model for '{outcome}' using T='{treatment}', GPS='{gps_col_name}'")
+
+ Y = df_with_gps[outcome]
+ T_val = pd.Series(df_with_gps[treatment].values, index=df_with_gps.index)
+ GPS_val = pd.Series(df_with_gps[gps_col_name].values, index=df_with_gps.index)
+
+ X_outcome_dict = {'intercept': np.ones(len(df_with_gps))}
+
+ model_type = outcome_model_spec.get("type", "polynomial")
+ degree = outcome_model_spec.get("degree", 2)
+ interaction = outcome_model_spec.get("interaction", True)
+
+ if model_type == "polynomial":
+ X_outcome_dict['T'] = T_val
+ X_outcome_dict['GPS'] = GPS_val
+ if degree >= 2:
+ X_outcome_dict['T_sq'] = T_val**2
+ X_outcome_dict['GPS_sq'] = GPS_val**2
+ if degree >=3: # Example for higher order, can be made more general
+ X_outcome_dict['T_cub'] = T_val**3
+ X_outcome_dict['GPS_cub'] = GPS_val**3
+ if interaction:
+ X_outcome_dict['T_x_GPS'] = T_val * GPS_val
+ if degree >=2: # Interaction with squared terms if degree allows
+ X_outcome_dict['T_sq_x_GPS'] = (T_val**2) * GPS_val
+ X_outcome_dict['T_x_GPS_sq'] = T_val * (GPS_val**2)
+
+ # Add more model types as needed (e.g., splines)
+ else:
+ logger.warning(f"Outcome model type '{model_type}' not fully recognized. Defaulting to T + GPS.")
+ X_outcome_dict['T'] = T_val
+ X_outcome_dict['GPS'] = GPS_val
+ # Fallback to linear if spec is unknown or simple
+
+ X_outcome_df = pd.DataFrame(X_outcome_dict, index=df_with_gps.index)
+
+ # Drop rows with NaNs that might have been introduced by transformations if T or GPS were NaN
+ # (though earlier dropna should handle most of this for input T/GPS)
+ valid_outcome_model_indices = Y.dropna().index.intersection(X_outcome_df.dropna().index)
+ if len(valid_outcome_model_indices) < len(df_with_gps):
+ logger.warning(f"Dropped {len(df_with_gps) - len(valid_outcome_model_indices)} rows due to NaNs before outcome model fitting.")
+
+ Y_fit = Y.loc[valid_outcome_model_indices]
+ X_outcome_df_fit = X_outcome_df.loc[valid_outcome_model_indices]
+
+ if Y_fit.empty or X_outcome_df_fit.empty:
+ logger.error("Not enough data to fit outcome model after NaN handling.")
+ raise ValueError("Empty data for outcome model fitting.")
+
+ try:
+ model = sm.OLS(Y_fit, X_outcome_df_fit).fit()
+ logger.info("Outcome model estimated successfully.")
+ return model
+ except Exception as e:
+ logger.error(f"Error during outcome model estimation: {e}", exc_info=True)
+ raise # Re-raise the exception to be caught by the main try-except block
+
+def _generate_dose_response_function(
+ df_with_gps: pd.DataFrame,
+ fitted_outcome_model: Any,
+ treatment: str,
+ gps_col_name: str,
+ outcome_model_spec: Dict, # To know how to construct X_pred features
+ t_values_to_evaluate: List[float]
+) -> List[float]:
+ """
+ Calculates the Average Dose-Response Function (ADRF).
+ E[Y(t)] = integral over E[Y | T=t, GPS=g] * f(g) dg
+ ~= (1/N) * sum_i E[Y | T=t, GPS=g_i] (using observed GPS values)
+ """
+ logger.info(f"Calculating ADRF for treatment levels: {t_values_to_evaluate}")
+ adrf_estimates = []
+
+ if not t_values_to_evaluate: # Handle empty list case
+ logger.warning("t_values_to_evaluate is empty. ADRF calculation will be skipped.")
+ return []
+
+ model_exog_names = fitted_outcome_model.model.exog_names
+
+ # Original GPS values from the dataframe
+ original_gps_values = pd.Series(df_with_gps[gps_col_name].values, index=df_with_gps.index)
+
+ for t_level in t_values_to_evaluate:
+ # Create a new DataFrame for prediction at this t_level
+ # Each row corresponds to an original observation's GPS, but with T set to t_level
+ X_pred_dict = {'intercept': np.ones(len(df_with_gps))}
+
+ # Reconstruct features based on outcome_model_spec and model_exog_names
+ # This mirrors the construction in _estimate_outcome_model
+ degree = outcome_model_spec.get("degree", 2)
+ interaction = outcome_model_spec.get("interaction", True)
+
+ if 'T' in model_exog_names: X_pred_dict['T'] = t_level
+ if 'GPS' in model_exog_names: X_pred_dict['GPS'] = original_gps_values
+
+ if 'T_sq' in model_exog_names: X_pred_dict['T_sq'] = t_level**2
+ if 'GPS_sq' in model_exog_names: X_pred_dict['GPS_sq'] = original_gps_values**2
+
+ if 'T_cub' in model_exog_names: X_pred_dict['T_cub'] = t_level**3 # Example
+ if 'GPS_cub' in model_exog_names: X_pred_dict['GPS_cub'] = original_gps_values**3 # Example
+
+ if 'T_x_GPS' in model_exog_names and interaction:
+ X_pred_dict['T_x_GPS'] = t_level * original_gps_values
+ if 'T_sq_x_GPS' in model_exog_names and interaction and degree >=2:
+ X_pred_dict['T_sq_x_GPS'] = (t_level**2) * original_gps_values
+ if 'T_x_GPS_sq' in model_exog_names and interaction and degree >=2:
+ X_pred_dict['T_x_GPS_sq'] = t_level * (original_gps_values**2)
+
+ X_pred_df = pd.DataFrame(X_pred_dict, index=df_with_gps.index)
+
+ # Ensure all required columns are present and in the correct order
+ # Drop any rows that might have NaNs if original_gps_values had NaNs (though they should be filtered before this)
+ X_pred_df_fit = X_pred_df[model_exog_names].dropna()
+
+ if X_pred_df_fit.empty:
+ logger.warning(f"Prediction data for t_level={t_level} is empty after NaN drop. Assigning NaN to ADRF point.")
+ adrf_estimates.append(np.nan)
+ continue
+
+ predicted_outcomes_at_t = fitted_outcome_model.predict(X_pred_df_fit)
+ adrf_estimates.append(np.mean(predicted_outcomes_at_t))
+
+ return adrf_estimates
\ No newline at end of file
diff --git a/auto_causal/methods/generalized_propensity_score/llm_assist.py b/auto_causal/methods/generalized_propensity_score/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a07eae6b708a77986be04f426bc24cc843e201
--- /dev/null
+++ b/auto_causal/methods/generalized_propensity_score/llm_assist.py
@@ -0,0 +1,208 @@
+"""
+LLM-assisted components for the Generalized Propensity Score (GPS) method.
+
+These functions help in suggesting model specifications or parameters
+by leveraging an LLM, providing intelligent defaults when not specified by the user.
+"""
+from typing import Dict, List, Any, Optional
+import pandas as pd
+import logging
+from auto_causal.utils.llm_helpers import call_llm_with_json_output # Hypothetical import
+
+logger = logging.getLogger(__name__)
+
+def suggest_treatment_model_spec(
+ df: pd.DataFrame,
+ treatment_var: str,
+ covariate_vars: List[str],
+ query: Optional[str] = None,
+ llm_client: Optional[Any] = None
+) -> Dict[str, Any]:
+ """
+ Suggests a model specification for the treatment mechanism (T ~ X) in GPS.
+
+ Args:
+ df: The input DataFrame.
+ treatment_var: The name of the continuous treatment variable.
+ covariate_vars: A list of covariate names.
+ query: Optional user query for context.
+ llm_client: Optional LLM client for making a call.
+
+ Returns:
+ A dictionary representing the suggested model specification.
+ E.g., {"type": "linear", "formula": "T ~ X1 + X2"} or
+ {"type": "random_forest", "params": {...}}
+ """
+ logger.info(f"Suggesting treatment model spec for: {treatment_var}")
+
+ # Example of constructing a more detailed prompt for an LLM
+ prompt_parts = [
+ f"You are an expert econometrician. The user wants to estimate a Generalized Propensity Score (GPS) for a continuous treatment variable '{treatment_var}'.",
+ f"The available covariates are: {covariate_vars}.",
+ f"The user's research query is: '{query if query else 'Not specified'}'.",
+ "Based on this information and general best practices for GPS estimation:",
+ "1. Suggest a suitable model type for estimating the treatment (T) given covariates (X). Common choices include 'linear' (OLS), or flexible models like 'random_forest' or 'gradient_boosting' if non-linearities are suspected.",
+ "2. If suggesting a regression model like OLS, provide a Patsy-style formula string (e.g., 'treatment ~ cov1 + cov2 + cov1*cov2').",
+ "3. If suggesting a machine learning model, list key hyperparameters and reasonable starting values (e.g., n_estimators, max_depth).",
+ "Return your suggestion as a JSON object with the following structure:",
+ '''
+ {
+ "model_type": "",
+ "formula": "",
+ "parameters": { // if applicable for ML models
+ "": "",
+ "": ""
+ },
+ "reasoning": ""
+ }
+ '''
+ ]
+ full_prompt = "\n".join(prompt_parts)
+
+ if llm_client:
+ logger.info("LLM client provided. Sending constructed prompt (actual call is hypothetical).")
+ logger.debug(f"LLM Prompt for treatment model spec:\n{full_prompt}")
+ # In a real implementation:
+ # response_json = call_llm_with_json_output(llm_client, full_prompt)
+ # if response_json and isinstance(response_json, dict):
+ # return response_json
+ # else:
+ # logger.warning("LLM did not return a valid JSON dict for treatment model spec.")
+ pass # Pass for now as it's a hypothetical call
+
+ # Default suggestion if no LLM or LLM fails
+ return {
+ "model_type": "linear",
+ "formula": f"{treatment_var} ~ {' + '.join(covariate_vars) if covariate_vars else '1'}",
+ "parameters": None,
+ "reasoning": "Defaulting to a linear model for T ~ X. Consider a more flexible model if non-linearities are expected.",
+ "comment": "This is a default suggestion."
+ }
+
+def suggest_outcome_model_spec(
+ df: pd.DataFrame,
+ outcome_var: str,
+ treatment_var: str,
+ gps_col_name: str,
+ query: Optional[str] = None,
+ llm_client: Optional[Any] = None
+) -> Dict[str, Any]:
+ """
+ Suggests a model specification for the outcome mechanism (Y ~ T, GPS) in GPS.
+
+ Args:
+ df: The input DataFrame.
+ outcome_var: The name of the outcome variable.
+ treatment_var: The name of the continuous treatment variable.
+ gps_col_name: The name of the GPS column.
+ query: Optional user query for context.
+ llm_client: Optional LLM client for making a call.
+
+ Returns:
+ A dictionary representing the suggested model specification.
+ E.g., {"type": "polynomial", "degree": 2, "interaction": True,
+ "formula": "Y ~ T + T^2 + GPS + GPS^2 + T*GPS"}
+ """
+ logger.info(f"Suggesting outcome model spec for: {outcome_var}")
+
+ prompt_parts = [
+ f"You are an expert econometrician. For a Generalized Propensity Score (GPS) analysis, the user needs to model the outcome '{outcome_var}' conditional on the continuous treatment '{treatment_var}' and the estimated GPS (column name '{gps_col_name}').",
+ "The goal is to flexibly capture the relationship E[Y | T, GPS]. A common approach is to use a polynomial specification for T and GPS, including interaction terms.",
+ f"The user's research query is: '{query if query else 'Not specified'}'.",
+ "Suggest a specification for this outcome model. Consider:",
+ "1. The functional form for T (e.g., linear, quadratic, cubic).",
+ "2. The functional form for GPS (e.g., linear, quadratic, cubic).",
+ "3. Whether to include interaction terms between T and GPS (e.g., T*GPS, T^2*GPS, T*GPS^2).",
+ "Return your suggestion as a JSON object with the following structure:",
+ '''
+ {
+ "model_type": "polynomial", // Or other types like "splines"
+ "treatment_terms": ["T", "T_sq"], // e.g., ["T"] for linear, ["T", "T_sq"] for quadratic
+ "gps_terms": ["GPS", "GPS_sq"], // e.g., ["GPS"] for linear, ["GPS", "GPS_sq"] for quadratic
+ "interaction_terms": ["T_x_GPS", "T_sq_x_GPS", "T_x_GPS_sq"], // Interactions to include, or empty list
+ "reasoning": ""
+ }
+ '''
+ ]
+ full_prompt = "\n".join(prompt_parts)
+
+ if llm_client:
+ logger.info("LLM client provided. Sending constructed prompt for outcome model (hypothetical call).")
+ logger.debug(f"LLM Prompt for outcome model spec:\n{full_prompt}")
+ # In a real implementation:
+ # response_json = call_llm_with_json_output(llm_client, full_prompt)
+ # if response_json and isinstance(response_json, dict):
+ # # Basic validation of expected keys for outcome model could go here
+ # return response_json
+ # else:
+ # logger.warning("LLM did not return a valid JSON dict for outcome model spec.")
+ pass # Pass for now
+
+ # Default suggestion
+ return {
+ "model_type": "polynomial",
+ "treatment_terms": ["T", "T_sq"],
+ "gps_terms": ["GPS", "GPS_sq"],
+ "interaction_terms": ["T_x_GPS"],
+ "reasoning": "Defaulting to a quadratic specification for T and GPS with a simple T*GPS interaction. This is a common starting point.",
+ "comment": "This is a default suggestion."
+ }
+
+def suggest_dose_response_t_values(
+ df: pd.DataFrame,
+ treatment_var: str,
+ num_points: int = 20,
+ llm_client: Optional[Any] = None
+) -> List[float]:
+ """
+ Suggests a relevant range and number of points for estimating the ADRF.
+
+ Args:
+ df: The input DataFrame.
+ treatment_var: The name of the continuous treatment variable.
+ num_points: Desired number of points for the ADRF curve.
+ llm_client: Optional LLM client for making a call.
+
+ Returns:
+ A list of treatment values at which to evaluate the ADRF.
+ """
+ logger.info(f"Suggesting dose response t-values for: {treatment_var}")
+
+ prompt_parts = [
+ f"For a Generalized Propensity Score (GPS) analysis with continuous treatment '{treatment_var}', the user needs to estimate an Average Dose-Response Function (ADRF).",
+ f"The observed range of '{treatment_var}' is from {df[treatment_var].min():.2f} to {df[treatment_var].max():.2f}.",
+ f"The user desires approximately {num_points} points for the ADRF curve.",
+ f"The user's research query is: '{query if query else 'Not specified'}'.",
+ "Suggest a list of specific treatment values (t_values) at which to evaluate the ADRF. Consider:",
+ "1. Covering the observed range of the treatment.",
+ "2. Potentially including specific points of policy interest if deducible from the query (though this is advanced).",
+ "3. Ensuring a reasonable distribution of points (e.g., equally spaced, or based on quantiles).",
+ "Return your suggestion as a JSON object with a single key 't_values' holding a list of floats:",
+ '''
+ {
+ "t_values": [, , ..., ],
+ "reasoning": ""
+ }
+ '''
+ ]
+ full_prompt = "\n".join(prompt_parts)
+
+ if llm_client:
+ logger.info("LLM client provided. Sending prompt for t-values (hypothetical call).")
+ logger.debug(f"LLM Prompt for t-values:\n{full_prompt}")
+ # In a real implementation:
+ # response_json = call_llm_with_json_output(llm_client, full_prompt)
+ # if response_json and isinstance(response_json, dict) and 't_values' in response_json and isinstance(response_json['t_values'], list):
+ # return response_json['t_values'] # Assuming it returns the list directly based on current function signature
+ # else:
+ # logger.warning("LLM did not return a valid JSON with 't_values' list for ADRF points.")
+ pass # Pass for now
+
+ # Default: Linearly spaced points
+ min_t = df[treatment_var].min()
+ max_t = df[treatment_var].max()
+ if pd.isna(min_t) or pd.isna(max_t) or min_t == max_t:
+ logger.warning(f"Could not determine a valid range for treatment '{treatment_var}'. Returning empty list.")
+ return []
+
+ return list(pd.Series.linspace(min_t, max_t, num_points))
\ No newline at end of file
diff --git a/auto_causal/methods/instrumental_variable/__init__.py b/auto_causal/methods/instrumental_variable/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b105bc6452b27b0a548d217652b23c0506668e3
--- /dev/null
+++ b/auto_causal/methods/instrumental_variable/__init__.py
@@ -0,0 +1 @@
+from .estimator import estimate_effect
\ No newline at end of file
diff --git a/auto_causal/methods/instrumental_variable/diagnostics.py b/auto_causal/methods/instrumental_variable/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ffaee08c49ee8a55d40ce12aa935b1d1eb41f68
--- /dev/null
+++ b/auto_causal/methods/instrumental_variable/diagnostics.py
@@ -0,0 +1,218 @@
+# Placeholder for IV-specific diagnostic functions
+import pandas as pd
+import statsmodels.api as sm
+from statsmodels.regression.linear_model import OLS
+# from statsmodels.sandbox.regression.gmm import IV2SLSResults # Removed problematic import
+from typing import Dict, Any, List, Tuple, Optional
+import logging # Import logging
+import numpy as np # Import numpy for np.zeros
+
+# Configure logger
+logger = logging.getLogger(__name__)
+
+def calculate_first_stage_f_statistic(df: pd.DataFrame, treatment: str, instruments: List[str], covariates: List[str]) -> Tuple[Optional[float], Optional[float]]:
+ """
+ Calculates the F-statistic for instrument relevance in the first stage regression.
+
+ Regresses treatment ~ instruments + covariates.
+ Tests the joint significance of the instrument coefficients.
+
+ Args:
+ df: Input DataFrame.
+ treatment: Name of the treatment variable.
+ instruments: List of instrument variable names.
+ covariates: List of covariate names.
+
+ Returns:
+ A tuple containing (F-statistic, p-value). Returns (None, None) on error.
+ """
+ logger.info("Diagnostics: Calculating First-Stage F-statistic...")
+ try:
+ df_copy = df.copy()
+ df_copy['intercept'] = 1
+ exog_vars = ['intercept'] + covariates
+ all_first_stage_exog = list(dict.fromkeys(exog_vars + instruments)) # Ensure unique columns
+
+ endog = df_copy[treatment]
+ exog = df_copy[all_first_stage_exog]
+
+ # Check for perfect multicollinearity before fitting
+ if exog.shape[1] > 1:
+ corr_matrix = exog.corr()
+ # Check if correlation matrix calculation failed (e.g., constant columns) or high correlation
+ if corr_matrix.isnull().values.any() or (corr_matrix.abs() > 0.9999).sum().sum() > exog.shape[1]: # Check off-diagonal elements
+ logger.warning("High multicollinearity or constant column detected in first stage exogenous variables.")
+ # Note: statsmodels OLS might handle perfect collinearity by dropping columns, but F-test might be unreliable.
+
+ first_stage_model = OLS(endog, exog).fit()
+
+ # Construct the restriction matrix (R) to test H0: instrument coeffs = 0
+ num_instruments = len(instruments)
+ if num_instruments == 0:
+ logger.warning("No instruments provided for F-statistic calculation.")
+ return None, None
+ num_exog_total = len(all_first_stage_exog)
+
+ # Ensure instruments are actually in the fitted model's exog names (in case statsmodels dropped some)
+ fitted_exog_names = first_stage_model.model.exog_names
+ valid_instruments = [inst for inst in instruments if inst in fitted_exog_names]
+ if not valid_instruments:
+ logger.error("None of the provided instruments were included in the first-stage regression model (possibly due to collinearity).")
+ return None, None
+ if len(valid_instruments) < len(instruments):
+ logger.warning(f"Instruments dropped by OLS: {set(instruments) - set(valid_instruments)}")
+
+ instrument_indices = [fitted_exog_names.index(inst) for inst in valid_instruments]
+
+ # Need to adjust R matrix size based on fitted model's exog
+ R = np.zeros((len(valid_instruments), len(fitted_exog_names)))
+ for i, idx in enumerate(instrument_indices):
+ R[i, idx] = 1
+
+ # Perform F-test
+ f_test_result = first_stage_model.f_test(R)
+
+ f_statistic = float(f_test_result.fvalue)
+ p_value = float(f_test_result.pvalue)
+
+ logger.info(f" F-statistic: {f_statistic:.4f}, p-value: {p_value:.4f}")
+ return f_statistic, p_value
+
+ except Exception as e:
+ logger.error(f"Error calculating first-stage F-statistic: {e}", exc_info=True)
+ return None, None
+
+def run_overidentification_test(sm_results: Optional[Any], df: pd.DataFrame, treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> Tuple[Optional[float], Optional[float], Optional[str]]:
+ """
+ Runs an overidentification test (Sargan-Hansen) if applicable.
+
+ This test is only valid if the number of instruments exceeds the number
+ of endogenous regressors (typically 1, the treatment variable).
+
+ Requires results from a statsmodels IV estimation.
+
+ Args:
+ sm_results: The fitted results object from statsmodels IV2SLS.fit().
+ df: Input DataFrame.
+ treatment: Name of the treatment variable.
+ outcome: Name of the outcome variable.
+ instruments: List of instrument variable names.
+ covariates: List of covariate names.
+
+ Returns:
+ Tuple: (test_statistic, p_value, status_message) or (None, None, error_message)
+ """
+ logger.info("Diagnostics: Running Overidentification Test...")
+ num_instruments = len(instruments)
+ num_endog = 1 # Assuming only one treatment variable is endogenous
+
+ if num_instruments <= num_endog:
+ logger.info(" Over-ID test not applicable (model is exactly identified or underidentified).")
+ return None, None, "Test not applicable (Need more instruments than endogenous regressors)"
+
+ if sm_results is None or not hasattr(sm_results, 'resid'):
+ logger.warning(" Over-ID test requires valid statsmodels results object with residuals.")
+ return None, None, "Statsmodels results object not available or invalid for test."
+
+ try:
+ # Statsmodels IV2SLSResults does not seem to have a direct method for this test (as of common versions).
+ # We need to calculate it manually using residuals and instruments.
+ # Formula: N * R^2 from regressing residuals (u_hat) on all exogenous variables (instruments + covariates).
+ # Degrees of freedom = num_instruments - num_endogenous_vars
+
+ residuals = sm_results.resid
+ df_copy = df.copy()
+ df_copy['intercept'] = 1
+ exog_vars = ['intercept'] + covariates
+ all_exog_instruments = list(dict.fromkeys(exog_vars + instruments))
+
+ # Ensure columns exist in the dataframe before selecting
+ missing_cols = [col for col in all_exog_instruments if col not in df_copy.columns]
+ if missing_cols:
+ raise ValueError(f"Missing columns required for Over-ID test: {missing_cols}")
+
+ exog_for_test = df_copy[all_exog_instruments]
+
+ # Check shapes match after potential NA handling in main estimator
+ if len(residuals) != exog_for_test.shape[0]:
+ # Attempt to align based on index if lengths differ (might happen if NAs were dropped)
+ logger.warning(f"Residual length ({len(residuals)}) differs from exog_for_test rows ({exog_for_test.shape[0]}). Trying to align indices.")
+ common_index = residuals.index.intersection(exog_for_test.index)
+ if len(common_index) == 0:
+ raise ValueError("Cannot align residuals and exogenous variables for Over-ID test after NA handling.")
+ residuals = residuals.loc[common_index]
+ exog_for_test = exog_for_test.loc[common_index]
+ logger.warning(f"Aligned to {len(common_index)} common observations.")
+
+
+ # Regress residuals on all exogenous instruments
+ aux_model = OLS(residuals, exog_for_test).fit()
+ r_squared = aux_model.rsquared
+ n_obs = len(residuals) # Use length of residuals after potential alignment
+
+ test_statistic = n_obs * r_squared
+
+ # Calculate p-value from Chi-squared distribution
+ from scipy.stats import chi2
+ degrees_of_freedom = num_instruments - num_endog
+ if degrees_of_freedom < 0:
+ # This shouldn't happen if the initial check passed, but as a safeguard
+ raise ValueError("Degrees of freedom for Sargan test are negative.")
+ elif degrees_of_freedom == 0:
+ # R-squared should be 0 if exactly identified, but handle edge case
+ p_value = 1.0 if np.isclose(test_statistic, 0) else 0.0
+ else:
+ p_value = chi2.sf(test_statistic, degrees_of_freedom)
+
+ logger.info(f" Sargan Test Statistic: {test_statistic:.4f}, p-value: {p_value:.4f}, df: {degrees_of_freedom}")
+ return test_statistic, p_value, "Test successful"
+
+ except Exception as e:
+ logger.error(f"Error running overidentification test: {e}", exc_info=True)
+ return None, None, f"Error during test: {e}"
+
+def run_iv_diagnostics(df: pd.DataFrame, treatment: str, outcome: str, instruments: List[str], covariates: List[str], sm_results: Optional[Any] = None, dw_results: Optional[Any] = None) -> Dict[str, Any]:
+ """
+ Runs standard IV diagnostic checks.
+
+ Args:
+ df: Input DataFrame.
+ treatment: Name of the treatment variable.
+ outcome: Name of the outcome variable.
+ instruments: List of instrument variable names.
+ covariates: List of covariate names.
+ sm_results: Optional fitted results object from statsmodels IV2SLS.fit().
+ dw_results: Optional results object from DoWhy (structure may vary).
+
+ Returns:
+ Dictionary containing diagnostic results.
+ """
+ diagnostics = {}
+
+ # 1. Instrument Relevance / Weak Instrument Test (First-Stage F-statistic)
+ f_stat, f_p_val = calculate_first_stage_f_statistic(df, treatment, instruments, covariates)
+ diagnostics['first_stage_f_statistic'] = f_stat
+ diagnostics['first_stage_p_value'] = f_p_val
+ diagnostics['is_instrument_weak'] = (f_stat < 10) if f_stat is not None else None # Common rule of thumb
+ if f_stat is None:
+ diagnostics['weak_instrument_test_status'] = "Error during calculation"
+ elif diagnostics['is_instrument_weak']:
+ diagnostics['weak_instrument_test_status'] = "Warning: Instrument(s) may be weak (F < 10)"
+ else:
+ diagnostics['weak_instrument_test_status'] = "Instrument(s) appear sufficiently strong (F >= 10)"
+
+
+ # 2. Overidentification Test (e.g., Sargan-Hansen)
+ overid_stat, overid_p_val, overid_status = run_overidentification_test(sm_results, df, treatment, outcome, instruments, covariates)
+ diagnostics['overid_test_statistic'] = overid_stat
+ diagnostics['overid_test_p_value'] = overid_p_val
+ diagnostics['overid_test_status'] = overid_status
+ diagnostics['overid_test_applicable'] = not ("not applicable" in overid_status.lower() if overid_status else True)
+
+ # 3. Exogeneity/Exclusion Restriction (Conceptual Check)
+ diagnostics['exclusion_restriction_assumption'] = "Assumed based on graph/input; cannot be statistically tested directly. Qualitative LLM check recommended."
+
+ # Potential future additions:
+ # - Endogeneity tests (e.g., Hausman test - requires comparing OLS and IV estimates)
+
+ return diagnostics
\ No newline at end of file
diff --git a/auto_causal/methods/instrumental_variable/estimator.py b/auto_causal/methods/instrumental_variable/estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b065af5a63d6630d40e5b9506750f18f2f65fcb
--- /dev/null
+++ b/auto_causal/methods/instrumental_variable/estimator.py
@@ -0,0 +1,370 @@
+import pandas as pd
+import statsmodels.api as sm
+from statsmodels.sandbox.regression.gmm import IV2SLS
+from dowhy import CausalModel # Primary path
+from typing import Dict, Any, List, Union, Optional
+import logging
+from langchain.chat_models.base import BaseChatModel
+
+from .diagnostics import run_iv_diagnostics
+from .llm_assist import identify_instrument_variable, validate_instrument_assumptions_qualitative, interpret_iv_results
+
+logger = logging.getLogger(__name__)
+
+def build_iv_graph_gml(treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> str:
+ """
+ Constructs a GML string representing the causal graph for IV.
+
+ Assumptions:
+ - Instruments cause Treatment
+ - Covariates cause Treatment and Outcome
+ - Treatment causes Outcome
+ - Instruments do NOT directly cause Outcome (Exclusion)
+ - Instruments are NOT caused by Covariates (can be relaxed if needed)
+ - Unobserved Confounder (U) affects Treatment and Outcome
+
+ Args:
+ treatment: Name of the treatment variable.
+ outcome: Name of the outcome variable.
+ instruments: List of instrument variable names.
+ covariates: List of covariate names.
+
+ Returns:
+ A GML graph string.
+ """
+ nodes = []
+ edges = []
+
+ # Define nodes - ensure no duplicates if a variable is both instrument and covariate (SHOULD NOT HAPPEN)
+ # Use a set to ensure unique variable names
+ all_vars_set = set([treatment, outcome] + instruments + covariates + ['U'])
+ all_vars = list(all_vars_set)
+
+ for var in all_vars:
+ nodes.append(f'node [ id "{var}" label "{var}" ]')
+
+ # Define edges
+ # Instruments -> Treatment
+ for inst in instruments:
+ edges.append(f'edge [ source "{inst}" target "{treatment}" ]')
+
+ # Covariates -> Treatment
+ for cov in covariates:
+ # Ensure we don't add self-loops or duplicate edges if cov == treatment (shouldn't happen)
+ if cov != treatment:
+ edges.append(f'edge [ source "{cov}" target "{treatment}" ]')
+
+ # Covariates -> Outcome
+ for cov in covariates:
+ if cov != outcome:
+ edges.append(f'edge [ source "{cov}" target "{outcome}" ]')
+
+ # Treatment -> Outcome
+ edges.append(f'edge [ source "{treatment}" target "{outcome}" ]')
+
+ # Unobserved Confounder -> Treatment and Outcome
+ edges.append(f'edge [ source "U" target "{treatment}" ]')
+ edges.append(f'edge [ source "U" target "{outcome}" ]')
+
+ # Core IV Assumption: Instruments are NOT caused by U (implicitly handled by not adding edge)
+ # Core IV Assumption: Instruments do NOT directly cause Outcome (handled by not adding edge)
+
+ # Format nodes and edges with indentation before inserting into f-string
+ formatted_nodes = '\n '.join(nodes)
+ formatted_edges = '\n '.join(edges)
+
+ gml_string = f"""
+graph [
+ directed 1
+ {formatted_nodes}
+ {formatted_edges}
+]
+"""
+ # Convert print to logger
+ logger.debug("\n--- Generated GML Graph ---")
+ logger.debug(gml_string)
+ logger.debug("-------------------------\n")
+ return gml_string
+
+def format_iv_results(estimate: Optional[float], raw_results: Dict, diagnostics: Dict, treatment: str, outcome: str, instrument: List[str], method_used: str, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
+ """
+ Formats the results from IV estimation into a standardized dictionary.
+
+ Args:
+ estimate: The point estimate of the causal effect.
+ raw_results: Dictionary containing raw outputs from DoWhy/statsmodels.
+ diagnostics: Dictionary containing diagnostic results.
+ treatment: Name of the treatment variable.
+ outcome: Name of the outcome variable.
+ instrument: List of instrument variable names.
+ method_used: 'dowhy' or 'statsmodels'.
+ llm: Optional LLM instance for interpretation.
+
+ Returns:
+ Standardized results dictionary.
+ """
+ formatted = {
+ "effect_estimate": estimate,
+ "treatment_variable": treatment,
+ "outcome_variable": outcome,
+ "instrument_variables": instrument,
+ "method_used": method_used,
+ "diagnostics": diagnostics,
+ "raw_results": {k: str(v) for k, v in raw_results.items() if "object" not in k}, # Avoid serializing large objects
+ "confidence_interval": None,
+ "standard_error": None,
+ "p_value": None,
+ "interpretation": "Placeholder"
+ }
+
+ # Extract details from statsmodels results if available
+ sm_results = raw_results.get('statsmodels_results_object')
+ if method_used == 'statsmodels' and sm_results:
+ try:
+ # Use .bse for standard error in statsmodels results
+ formatted["standard_error"] = float(sm_results.bse[treatment])
+ formatted["p_value"] = float(sm_results.pvalues[treatment])
+ conf_int = sm_results.conf_int().loc[treatment].tolist()
+ formatted["confidence_interval"] = [float(ci) for ci in conf_int]
+ except AttributeError as e:
+ logger.warning(f"Could not extract all details from statsmodels results object (likely missing attribute): {e}")
+ except Exception as e:
+ logger.warning(f"Error extracting details from statsmodels results: {e}")
+
+ # Extract details from DoWhy results if available
+ # Note: DoWhy's CausalEstimate object structure needs inspection
+ dw_results = raw_results.get('dowhy_results_object')
+ if method_used == 'dowhy' and dw_results:
+ try:
+ # Attempt common attributes, may need adjustment based on DoWhy version/output
+ if hasattr(dw_results, 'stderr'):
+ formatted["standard_error"] = float(dw_results.stderr)
+ if hasattr(dw_results, 'p_value'):
+ formatted["p_value"] = float(dw_results.p_value)
+ if hasattr(dw_results, 'conf_intervals'):
+ # Assuming it's stored similarly to statsmodels, might need adjustment
+ ci = dw_results.conf_intervals().loc[treatment].tolist() # Fictional attribute/method - check DoWhy docs!
+ formatted["confidence_interval"] = [float(c) for c in ci]
+ elif hasattr(dw_results, 'get_confidence_intervals'):
+ ci = dw_results.get_confidence_intervals() # Check DoWhy docs for format
+ # Check format of ci before converting
+ if isinstance(ci, (list, tuple)) and len(ci) == 2:
+ formatted["confidence_interval"] = [float(c) for c in ci] # Adapt parsing
+ else:
+ logger.warning(f"Could not parse confidence intervals from DoWhy object: {ci}")
+
+ except Exception as e:
+ logger.warning(f"Could not extract all details from DoWhy results: {e}. Structure might be different.", exc_info=True)
+ # Avoid printing dir in production code, use logger.debug if needed for dev
+ # logger.debug(f"DoWhy result object dir(): {dir(dw_results)}")
+
+ # Generate LLM interpretation - pass llm object
+ if estimate is not None:
+ formatted["interpretation"] = interpret_iv_results(formatted, diagnostics, llm=llm)
+ else:
+ formatted["interpretation"] = "Estimation failed, cannot interpret results."
+
+
+ return formatted
+
+def estimate_effect(
+ df: pd.DataFrame,
+ treatment: str,
+ outcome: str,
+ covariates: List[str],
+ query: Optional[str] = None,
+ dataset_description: Optional[str] = None,
+ llm: Optional[BaseChatModel] = None,
+ **kwargs
+) -> Dict[str, Any]:
+
+ instrument = kwargs.get('instrument_variable')
+ if not instrument:
+ return {"error": "Instrument variable ('instrument_variable') not found in kwargs.", "method_used": "none", "diagnostics": {}}
+
+ instrument_list = [instrument] if isinstance(instrument, str) else instrument
+ valid_instruments = [inst for inst in instrument_list if isinstance(inst, str)]
+ clean_covariates = [cov for cov in covariates if cov not in valid_instruments]
+
+ logger.info(f"\n--- Starting Instrumental Variable Estimation ---")
+ logger.info(f"Treatment: {treatment}, Outcome: {outcome}, Instrument(s): {valid_instruments}, Original Covariates: {covariates}, Cleaned Covariates: {clean_covariates}")
+ results = {}
+ method_used = "none"
+ sm_results_obj = None
+ dw_results_obj = None
+ identified_estimand = None # Initialize
+ model = None # Initialize
+ refutation_results = {} # Initialize
+
+ # --- Input Validation ---
+ required_cols = [treatment, outcome] + valid_instruments + clean_covariates
+ missing_cols = [col for col in required_cols if col not in df.columns]
+ if missing_cols:
+ return {"error": f"Missing required columns in DataFrame: {missing_cols}", "method_used": method_used, "diagnostics": {}}
+ if not valid_instruments:
+ return {"error": "Instrument variable(s) must be provided and valid.", "method_used": method_used, "diagnostics": {}}
+
+ # --- LLM Pre-Checks ---
+ if query and llm:
+ qualitative_check = validate_instrument_assumptions_qualitative(treatment, outcome, valid_instruments, clean_covariates, query, llm=llm)
+ results['llm_assumption_check'] = qualitative_check
+ logger.info(f"LLM Qualitative Assumption Check: {qualitative_check}")
+
+ # --- Build Graph and Instantiate CausalModel (Do this before estimation attempts) ---
+ # This allows using identify_effect and refute_estimate even if DoWhy estimation fails
+ try:
+ graph = build_iv_graph_gml(treatment, outcome, valid_instruments, clean_covariates)
+ if not graph:
+ raise ValueError("Failed to build GML graph for DoWhy.")
+
+ model = CausalModel(data=df, treatment=treatment, outcome=outcome, graph=graph)
+
+ # Identify Effect (essential for refutation later)
+ identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
+ logger.debug("\nDoWhy Identified Estimand:")
+ logger.debug(identified_estimand)
+ if not identified_estimand:
+ raise ValueError("DoWhy could not identify a valid estimand.")
+
+ except Exception as model_init_e:
+ logger.error(f"Failed to initialize CausalModel or identify effect: {model_init_e}", exc_info=True)
+ # Cannot proceed without model/estimand for DoWhy or refutation
+ results['error'] = f"Failed to initialize CausalModel: {model_init_e}"
+ # Attempt statsmodels anyway? Or return error? Let's try statsmodels.
+ pass # Allow falling through to statsmodels if desired
+
+ # --- Primary Path: DoWhy Estimation ---
+ if model and identified_estimand and not kwargs.get('force_statsmodels', False):
+ logger.info("\nAttempting estimation with DoWhy...")
+ try:
+ dw_results_obj = model.estimate_effect(
+ identified_estimand,
+ method_name="iv.instrumental_variable",
+ method_params={'iv_instrument_name': valid_instruments}
+ )
+ logger.debug("\nDoWhy Estimation Result:")
+ logger.debug(dw_results_obj)
+ results['dowhy_estimate'] = dw_results_obj.value
+ results['dowhy_results_object'] = dw_results_obj
+ method_used = 'dowhy'
+ logger.info("DoWhy estimation successful.")
+ except Exception as e:
+ logger.error(f"DoWhy IV estimation failed: {e}", exc_info=True)
+ results['dowhy_error'] = str(e)
+ if not kwargs.get('allow_fallback', True):
+ logger.warning("Fallback to statsmodels disabled. Estimation failed.")
+ method_used = "dowhy_failed"
+ # Still run diagnostics and format output
+ else:
+ logger.info("Proceeding to statsmodels fallback.")
+ elif not model or not identified_estimand:
+ logger.warning("Skipping DoWhy estimation due to CausalModel initialization/identification failure.")
+ # Ensure we proceed to statsmodels if fallback is allowed
+ if not kwargs.get('allow_fallback', True):
+ logger.error("Cannot estimate effect: CausalModel failed and fallback disabled.")
+ method_used = "dowhy_failed"
+ else:
+ logger.info("Proceeding to statsmodels fallback.")
+
+ # --- Fallback Path: statsmodels IV2SLS ---
+ if method_used not in ['dowhy', 'dowhy_failed']:
+ logger.info("\nAttempting estimation with statsmodels IV2SLS...")
+ try:
+ df_copy = df.copy().dropna(subset=required_cols)
+ if df_copy.empty:
+ raise ValueError("DataFrame becomes empty after dropping NAs in required columns.")
+ df_copy['intercept'] = 1
+ exog_regressors = ['intercept'] + clean_covariates
+ endog_var = treatment
+ all_instruments_for_sm = list(dict.fromkeys(exog_regressors + valid_instruments))
+ endog_data = df_copy[outcome]
+ exog_data_sm_cols = list(dict.fromkeys(exog_regressors + [endog_var]))
+ exog_data_sm = df_copy[exog_data_sm_cols]
+ instrument_data_sm = df_copy[all_instruments_for_sm]
+ num_endog = 1
+ num_external_iv = len(valid_instruments)
+ if num_endog > num_external_iv:
+ raise ValueError(f"Model underidentified: More endogenous regressors ({num_endog}) than unique external instruments ({num_external_iv}).")
+ iv_model = IV2SLS(endog=endog_data, exog=exog_data_sm, instrument=instrument_data_sm)
+ sm_results_obj = iv_model.fit()
+ logger.info("\nStatsmodels Estimation Summary:")
+ logger.info(f" Estimate for {treatment}: {sm_results_obj.params[treatment]}")
+ logger.info(f" Std Error: {sm_results_obj.bse[treatment]}")
+ logger.info(f" P-value: {sm_results_obj.pvalues[treatment]}")
+ results['statsmodels_estimate'] = sm_results_obj.params[treatment]
+ results['statsmodels_results_object'] = sm_results_obj
+ method_used = 'statsmodels'
+ logger.info("Statsmodels estimation successful.")
+ except Exception as sm_e:
+ logger.error(f"Statsmodels IV estimation also failed: {sm_e}", exc_info=True)
+ results['statsmodels_error'] = str(sm_e)
+ method_used = 'statsmodels_failed' if method_used == "none" else "dowhy_failed_sm_failed"
+
+ # --- Diagnostics ---
+ logger.info("\nRunning diagnostics...")
+ diagnostics = run_iv_diagnostics(df, treatment, outcome, valid_instruments, clean_covariates, sm_results_obj, dw_results_obj)
+ results['diagnostics'] = diagnostics
+
+ # --- Refutation Step ---
+ final_estimate_value = results.get('dowhy_estimate') if method_used == 'dowhy' else results.get('statsmodels_estimate')
+
+ # Only run permute refuter if estimate is valid AND came from DoWhy
+ if method_used == 'dowhy' and dw_results_obj and final_estimate_value is not None:
+ logger.info("\nRunning refutation test (Placebo Treatment - Permute - requires DoWhy estimate object)...")
+ try:
+ # Pass the actual DoWhy estimate object
+ refuter_result = model.refute_estimate(
+ identified_estimand,
+ dw_results_obj, # Pass the original DoWhy result object
+ method_name="placebo_treatment_refuter",
+ placebo_type="permute" # Necessary for IV according to docs/examples
+ )
+ logger.info("Refutation test completed.")
+ logger.debug(f"Refuter Result:\n{refuter_result}")
+ # Store relevant info from refuter_result (check its structure)
+ refutation_results = {
+ "refuter": "placebo_treatment_refuter",
+ "new_effect": getattr(refuter_result, 'new_effect', 'N/A'),
+ "p_value": getattr(refuter_result, 'refutation_result', {}).get('p_value', 'N/A') if hasattr(refuter_result, 'refutation_result') else 'N/A',
+ # Passed if p-value > 0.05 (or not statistically significant)
+ "passed": getattr(refuter_result, 'refutation_result', {}).get('is_statistically_significant', None) == False if hasattr(refuter_result, 'refutation_result') else None
+ }
+ except Exception as refute_e:
+ logger.error(f"Refutation test failed: {refute_e}", exc_info=True)
+ refutation_results = {"error": f"Refutation failed: {refute_e}"}
+
+ elif final_estimate_value is not None and method_used == 'statsmodels':
+ logger.warning("Skipping placebo permutation refuter: Estimate was generated by statsmodels, not DoWhy's IV estimator.")
+ refutation_results = {"status": "skipped_wrong_estimator_for_permute"}
+
+ elif final_estimate_value is None:
+ logger.warning("Skipping refutation test because estimation failed.")
+ refutation_results = {"status": "skipped_due_to_failed_estimation"}
+
+ else: # Model or estimand failed earlier, or unknown method_used
+ logger.warning(f"Skipping refutation test due to earlier failure (method_used: {method_used}).")
+ refutation_results = {"status": "skipped_due_to_model_failure_or_unknown"}
+
+ results['refutation_results'] = refutation_results # Add to main results
+
+ # --- Formatting Results ---
+ if final_estimate_value is None and method_used not in ['dowhy', 'statsmodels']:
+ logger.error("ERROR: Both estimation methods failed.")
+ # Ensure error key exists if not set earlier
+ if 'error' not in results:
+ results['error'] = "Both DoWhy and statsmodels IV estimation failed."
+
+ logger.info("\n--- Formatting Final Results ---")
+ formatted_results = format_iv_results(
+ final_estimate_value, # Pass the numeric value
+ results, # Pass the dict containing estimate objects and refutation results
+ diagnostics,
+ treatment,
+ outcome,
+ valid_instruments,
+ method_used,
+ llm=llm
+ )
+
+ logger.info("--- Instrumental Variable Estimation Complete ---\n")
+ return formatted_results
\ No newline at end of file
diff --git a/auto_causal/methods/instrumental_variable/llm_assist.py b/auto_causal/methods/instrumental_variable/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..abbf56196fae336990e9bfcfa232b0d5c2ea44c8
--- /dev/null
+++ b/auto_causal/methods/instrumental_variable/llm_assist.py
@@ -0,0 +1,240 @@
+"""
+LLM assistance functions for Instrumental Variable (IV) analysis.
+
+This module provides functions for LLM-based assistance in instrumental variable analysis,
+including identifying potential instruments, validating IV assumptions, and interpreting results.
+"""
+
+from typing import List, Dict, Any, Optional
+import logging
+
+# Imported for type hinting
+from langchain.chat_models.base import BaseChatModel
+
+# Import shared LLM helpers
+from auto_causal.utils.llm_helpers import call_llm_with_json_output
+
+logger = logging.getLogger(__name__)
+
+def identify_instrument_variable(
+ df_cols: List[str],
+ query: str,
+ llm: Optional[BaseChatModel] = None
+) -> List[str]:
+ """
+ Use LLM to identify potential instrumental variables from available columns.
+
+ Args:
+ df_cols: List of column names from the dataset
+ query: User's causal query text
+ llm: Optional LLM model instance
+
+ Returns:
+ List of column names identified as potential instruments
+ """
+ if llm is None:
+ logger.warning("No LLM provided for instrument identification")
+ return []
+
+ prompt = f"""
+ You are assisting with an instrumental variable analysis.
+
+ Available columns in the dataset: {df_cols}
+ User query: {query}
+
+ Identify potential instrumental variable(s) from the available columns based on the query.
+ The treatment and outcome should NOT be included as instruments.
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "potential_instruments": ["column_name1", "column_name2", ...]
+ }}
+ """
+
+ response = call_llm_with_json_output(llm, prompt)
+
+ if response and "potential_instruments" in response and isinstance(response["potential_instruments"], list):
+ # Basic validation: ensure items are strings (column names)
+ valid_instruments = [item for item in response["potential_instruments"] if isinstance(item, str)]
+ if len(valid_instruments) != len(response["potential_instruments"]):
+ logger.warning("LLM returned non-string items in potential_instruments list.")
+ return valid_instruments
+
+ logger.warning(f"Failed to get valid instrument recommendations from LLM. Response: {response}")
+ return []
+
+def validate_instrument_assumptions_qualitative(
+ treatment: str,
+ outcome: str,
+ instrument: List[str],
+ covariates: List[str],
+ query: str,
+ llm: Optional[BaseChatModel] = None
+) -> Dict[str, str]:
+ """
+ Use LLM to provide qualitative assessment of IV assumptions.
+
+ Args:
+ treatment: Treatment variable name
+ outcome: Outcome variable name
+ instrument: List of instrumental variable names
+ covariates: List of covariate variable names
+ query: User's causal query text
+ llm: Optional LLM model instance
+
+ Returns:
+ Dictionary with qualitative assessments of exclusion and exogeneity assumptions
+ """
+ default_fail = {
+ "exclusion_assessment": "LLM Check Failed",
+ "exogeneity_assessment": "LLM Check Failed"
+ }
+
+ if llm is None:
+ return {
+ "exclusion_assessment": "LLM Not Provided",
+ "exogeneity_assessment": "LLM Not Provided"
+ }
+
+ prompt = f"""
+ You are assisting with assessing the validity of instrumental variable assumptions.
+
+ Treatment variable: {treatment}
+ Outcome variable: {outcome}
+ Instrumental variable(s): {instrument}
+ Covariates: {covariates}
+ User query: {query}
+
+ Assess the core Instrumental Variable (IV) assumptions based *only* on the provided variable names and query context:
+ 1. Exclusion restriction: Plausibility that the instrument(s) affect the outcome ONLY through the treatment.
+ 2. Exogeneity (also called Independence): Plausibility that the instrument(s) are not correlated with unobserved confounders that also affect the outcome.
+
+ Provide a brief, qualitative assessment (e.g., 'Plausible', 'Unlikely', 'Requires Domain Knowledge', 'Potentially Violated').
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "exclusion_assessment": "",
+ "exogeneity_assessment": ""
+ }}
+ """
+
+ response = call_llm_with_json_output(llm, prompt)
+
+ if response and isinstance(response, dict) and \
+ "exclusion_assessment" in response and isinstance(response["exclusion_assessment"], str) and \
+ "exogeneity_assessment" in response and isinstance(response["exogeneity_assessment"], str):
+ return response
+
+ logger.warning(f"Failed to get valid assumption assessment from LLM. Response: {response}")
+ return default_fail
+
+def interpret_iv_results(
+ results: Dict[str, Any],
+ diagnostics: Dict[str, Any],
+ llm: Optional[BaseChatModel] = None
+) -> str:
+ """
+ Use LLM to interpret IV results in natural language.
+
+ Args:
+ results: Dictionary of estimation results (e.g., effect_estimate, p_value, confidence_interval)
+ diagnostics: Dictionary of diagnostic test results (e.g., first_stage_f_statistic, overid_test)
+ llm: Optional LLM model instance
+
+ Returns:
+ String containing natural language interpretation of results
+ """
+ if llm is None:
+ return "LLM was not available to provide interpretation. Please review the numeric results manually."
+
+ # Construct a concise summary of inputs for the prompt
+ results_summary = {}
+
+ effect = results.get('effect_estimate')
+ if effect is not None:
+ try:
+ results_summary['Effect Estimate'] = f"{float(effect):.3f}"
+ except (ValueError, TypeError):
+ results_summary['Effect Estimate'] = 'N/A (Invalid Format)'
+ else:
+ results_summary['Effect Estimate'] = 'N/A'
+
+ p_value = results.get('p_value')
+ if p_value is not None:
+ try:
+ results_summary['P-value'] = f"{float(p_value):.3f}"
+ except (ValueError, TypeError):
+ results_summary['P-value'] = 'N/A (Invalid Format)'
+ else:
+ results_summary['P-value'] = 'N/A'
+
+ ci = results.get('confidence_interval')
+ if ci is not None and isinstance(ci, (list, tuple)) and len(ci) == 2:
+ try:
+ results_summary['Confidence Interval'] = f"[{float(ci[0]):.3f}, {float(ci[1]):.3f}]"
+ except (ValueError, TypeError):
+ results_summary['Confidence Interval'] = 'N/A (Invalid Format)'
+ else:
+ # Handle cases where CI is None or not a 2-element list/tuple
+ results_summary['Confidence Interval'] = str(ci) if ci is not None else 'N/A'
+
+ if 'treatment_variable' in results:
+ results_summary['Treatment'] = results['treatment_variable']
+ if 'outcome_variable' in results:
+ results_summary['Outcome'] = results['outcome_variable']
+
+ diagnostics_summary = {}
+ f_stat = diagnostics.get('first_stage_f_statistic')
+ if f_stat is not None:
+ try:
+ diagnostics_summary['First-Stage F-statistic'] = f"{float(f_stat):.2f}"
+ except (ValueError, TypeError):
+ diagnostics_summary['First-Stage F-statistic'] = 'N/A (Invalid Format)'
+ else:
+ diagnostics_summary['First-Stage F-statistic'] = 'N/A'
+
+ if 'weak_instrument_test_status' in diagnostics:
+ diagnostics_summary['Weak Instrument Test'] = diagnostics['weak_instrument_test_status']
+
+ overid_p = diagnostics.get('overid_test_p_value')
+ if overid_p is not None:
+ try:
+ diagnostics_summary['Overidentification Test P-value'] = f"{float(overid_p):.3f}"
+ diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
+ except (ValueError, TypeError):
+ diagnostics_summary['Overidentification Test P-value'] = 'N/A (Invalid Format)'
+ diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
+ else:
+ # Explicitly state if not applicable or not available
+ if diagnostics.get('overid_test_applicable') == False:
+ diagnostics_summary['Overidentification Test'] = 'Not Applicable'
+ else:
+ diagnostics_summary['Overidentification Test P-value'] = 'N/A'
+ diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
+
+ prompt = f"""
+ You are assisting with interpreting instrumental variable (IV) analysis results.
+
+ Estimation results summary: {results_summary}
+ Diagnostic test results summary: {diagnostics_summary}
+
+ Explain these Instrumental Variable (IV) results in clear, concise language (2-4 sentences).
+ Focus on:
+ 1. The estimated causal effect (magnitude, direction, statistical significance based on p-value < 0.05).
+ 2. The strength of the instrument(s) (based on F-statistic, typically > 10 indicates strength).
+ 3. Any implications from other diagnostic tests (e.g., overidentification test suggesting instrument validity issues if p < 0.05).
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "interpretation": ""
+ }}
+ """
+
+ response = call_llm_with_json_output(llm, prompt)
+
+ if response and isinstance(response, dict) and \
+ "interpretation" in response and isinstance(response["interpretation"], str):
+ return response["interpretation"]
+
+ logger.warning(f"Failed to get valid interpretation from LLM. Response: {response}")
+ return "LLM interpretation could not be generated. Please review the numeric results manually."
\ No newline at end of file
diff --git a/auto_causal/methods/linear_regression/__init__.py b/auto_causal/methods/linear_regression/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/auto_causal/methods/linear_regression/diagnostics.py b/auto_causal/methods/linear_regression/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..80de57653b8c943021553126b32d386f0f3f5d99
--- /dev/null
+++ b/auto_causal/methods/linear_regression/diagnostics.py
@@ -0,0 +1,76 @@
+"""
+Diagnostic checks for Linear Regression models.
+"""
+
+from typing import Dict, Any
+import statsmodels.api as sm
+from statsmodels.stats.diagnostic import het_breuschpagan, normal_ad
+from statsmodels.stats.stattools import jarque_bera
+from statsmodels.regression.linear_model import RegressionResultsWrapper
+import pandas as pd
+import logging
+
+logger = logging.getLogger(__name__)
+
+def run_lr_diagnostics(results: RegressionResultsWrapper, X: pd.DataFrame) -> Dict[str, Any]:
+ """
+ Runs diagnostic checks on a fitted OLS model.
+
+ Args:
+ results: A fitted statsmodels OLS results object.
+ X: The design matrix (including constant) used for the regression.
+ Needed for heteroskedasticity tests.
+
+ Returns:
+ Dictionary containing diagnostic metrics.
+ """
+
+ diagnostics = {}
+
+ try:
+ diagnostics['r_squared'] = results.rsquared
+ diagnostics['adj_r_squared'] = results.rsquared_adj
+ diagnostics['f_statistic'] = results.fvalue
+ diagnostics['f_p_value'] = results.f_pvalue
+ diagnostics['n_observations'] = int(results.nobs)
+ diagnostics['degrees_of_freedom_resid'] = int(results.df_resid)
+
+ # --- Normality of Residuals (Jarque-Bera) ---
+ try:
+ jb_value, jb_p_value, skew, kurtosis = jarque_bera(results.resid)
+ diagnostics['residuals_normality_jb_stat'] = jb_value
+ diagnostics['residuals_normality_jb_p_value'] = jb_p_value
+ diagnostics['residuals_skewness'] = skew
+ diagnostics['residuals_kurtosis'] = kurtosis
+ diagnostics['residuals_normality_status'] = "Normal" if jb_p_value > 0.05 else "Non-Normal"
+ except Exception as e:
+ logger.warning(f"Could not run Jarque-Bera test: {e}")
+ diagnostics['residuals_normality_status'] = "Test Failed"
+
+ # --- Homoscedasticity (Breusch-Pagan) ---
+ # Requires the design matrix X used in the model fitting
+ try:
+ lm_stat, lm_p_value, f_stat, f_p_value = het_breuschpagan(results.resid, X)
+ diagnostics['homoscedasticity_bp_lm_stat'] = lm_stat
+ diagnostics['homoscedasticity_bp_lm_p_value'] = lm_p_value
+ diagnostics['homoscedasticity_bp_f_stat'] = f_stat
+ diagnostics['homoscedasticity_bp_f_p_value'] = f_p_value
+ diagnostics['homoscedasticity_status'] = "Homoscedastic" if lm_p_value > 0.05 else "Heteroscedastic"
+ except Exception as e:
+ logger.warning(f"Could not run Breusch-Pagan test: {e}")
+ diagnostics['homoscedasticity_status'] = "Test Failed"
+
+ # --- Linearity (Basic check - often requires visual inspection) ---
+ # No standard quantitative test implemented here. Usually assessed via residual plots.
+ diagnostics['linearity_check'] = "Requires visual inspection (e.g., residual vs fitted plot)"
+
+ # --- Multicollinearity (Placeholder - requires VIF calculation) ---
+ # VIF requires iterating through predictors, more involved
+ diagnostics['multicollinearity_check'] = "Not Implemented (Requires VIF)"
+
+ return {"status": "Success", "details": diagnostics}
+
+ except Exception as e:
+ logger.error(f"Error running LR diagnostics: {e}")
+ return {"status": "Failed", "error": str(e), "details": diagnostics} # Return partial results if possible
+
diff --git a/auto_causal/methods/linear_regression/estimator.py b/auto_causal/methods/linear_regression/estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc25602bcafc594c6e91c75fe15cdd33c9ef747
--- /dev/null
+++ b/auto_causal/methods/linear_regression/estimator.py
@@ -0,0 +1,355 @@
+"""
+Linear Regression Estimator for Causal Inference.
+
+Uses Ordinary Least Squares (OLS) to estimate the treatment effect, potentially
+adjusting for covariates.
+"""
+import pandas as pd
+import statsmodels.api as sm
+import statsmodels.formula.api as smf
+from typing import Dict, Any, List, Optional, Union
+import logging
+from langchain.chat_models.base import BaseChatModel
+import re
+import json
+from pydantic import BaseModel, ValidationError
+from langchain_core.messages import HumanMessage
+from langchain_core.exceptions import OutputParserException
+
+
+from auto_causal.models import LLMIdentifiedRelevantParams
+from auto_causal.prompts.regression_prompts import STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE
+from auto_causal.config import get_llm_client
+
+# Placeholder for potential future LLM assistance integration
+# from .llm_assist import interpret_lr_results, suggest_lr_covariates
+# Placeholder for potential future diagnostics integration
+# from .diagnostics import run_lr_diagnostics
+
+logger = logging.getLogger(__name__)
+
+def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
+ """Helper to call LLM with structured output and handle errors."""
+ try:
+ messages = [HumanMessage(content=prompt)]
+ structured_llm = llm.with_structured_output(pydantic_model)
+ parsed_result = structured_llm.invoke(messages)
+ return parsed_result
+ except (OutputParserException, ValidationError) as e:
+ logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {e}")
+ except Exception as e:
+ logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {e}", exc_info=True)
+ return None
+
+# Define module-level helper function
+def _clean_variable_name_for_patsy_local(name: str) -> str:
+ if not isinstance(name, str):
+ name = str(name)
+ name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
+ if not re.match(r'^[a-zA-Z_]', name):
+ name = 'var_' + name
+ return name
+
+
+def estimate_effect(
+ df: pd.DataFrame,
+ treatment: str,
+ outcome: str,
+ covariates: Optional[List[str]] = None,
+ query_str: Optional[str] = None, # For potential LLM use
+ llm: Optional[BaseChatModel] = None, # For potential LLM use
+ **kwargs # To capture any other potential arguments
+) -> Dict[str, Any]:
+ """
+ Estimates the causal effect using Linear Regression (OLS).
+
+ Args:
+ df: Input DataFrame.
+ treatment: Name of the treatment variable column.
+ outcome: Name of the outcome variable column.
+ covariates: Optional list of covariate names.
+ query_str: Optional user query for context (e.g., for LLM).
+ llm: Optional Language Model instance.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ Dictionary containing estimation results:
+ - 'effect_estimate': The estimated coefficient for the treatment variable.
+ - 'p_value': The p-value associated with the treatment coefficient.
+ - 'confidence_interval': The 95% confidence interval for the effect.
+ - 'standard_error': The standard error of the treatment coefficient.
+ - 'formula': The regression formula used.
+ - 'model_summary': Summary object from statsmodels.
+ - 'diagnostics': Placeholder for diagnostic results.
+ - 'interpretation': Placeholder for LLM interpretation.
+ """
+ if covariates is None:
+ covariates = []
+
+ # Retrieve additional args from kwargs
+ interaction_term_suggested = kwargs.get('interaction_term_suggested', False)
+ # interaction_variable_candidate is the *original* name from query_interpreter
+ interaction_variable_candidate_orig_name = kwargs.get('interaction_variable_candidate')
+ treatment_reference_level = kwargs.get('treatment_reference_level')
+ column_mappings = kwargs.get('column_mappings', {})
+
+ required_cols = [treatment, outcome] + covariates
+ # If interaction variable is suggested, ensure it (or its processed form) is in df for analysis
+ # This check is complex here as interaction_variable_candidate_orig_name needs mapping to processed column(s)
+ # We'll rely on df_analysis.dropna() and formula construction to handle missing interaction var columns later
+
+ missing_cols = [col for col in required_cols if col not in df.columns]
+ if missing_cols:
+ raise ValueError(f"Missing required columns: {missing_cols}")
+
+ # Prepare data for statsmodels (add constant, handle potential NaNs)
+ df_analysis = df[required_cols].dropna()
+ if df_analysis.empty:
+ raise ValueError("No data remaining after dropping NaNs for required columns.")
+
+ X = df_analysis[[treatment] + covariates]
+ X = sm.add_constant(X) # Add intercept
+ y = df_analysis[outcome]
+
+ # --- Formula Construction ---
+ outcome_col_name = outcome # Name in processed df
+ treatment_col_name = treatment # Name in processed df
+ processed_covariate_col_names = covariates # List of names in processed df
+
+ rhs_terms = []
+
+ # 1. Treatment Term
+ treatment_patsy_term = treatment_col_name # Default
+ original_treatment_info = column_mappings.get(treatment_col_name, {}) # Info from preprocess_data
+
+ is_binary_encoded = original_treatment_info.get('transformed_as') == 'label_encoded_binary'
+ is_still_categorical_in_df = df_analysis[treatment_col_name].dtype.name in ['object', 'category']
+
+ if is_still_categorical_in_df and not is_binary_encoded: # Covers multi-level and binary categoricals not yet numeric
+ if treatment_reference_level:
+ treatment_patsy_term = f"C({treatment_col_name}, Treatment(reference='{treatment_reference_level}'))"
+ logger.info(f"Treating '{treatment_col_name}' as multi-level categorical with reference '{treatment_reference_level}'.")
+ else:
+ # Default C() wrapping for categoricals if no specific reference is given.
+ # This applies to multi-level or binary categoricals that were not label_encoded to 0/1 by preprocess_data.
+ treatment_patsy_term = f"C({treatment_col_name})"
+ logger.info(f"Treating '{treatment_col_name}' as categorical (Patsy will pick reference).")
+ elif is_binary_encoded: # Was binary and explicitly label encoded to 0/1 by preprocess_data
+ # Even if it's now numeric 0/1, C() ensures Patsy treats it categorically for parameter naming consistency.
+ treatment_patsy_term = f"C({treatment_col_name})"
+ logger.info(f"Treating label-encoded binary '{treatment_col_name}' as categorical for Patsy.")
+ else: # Assumed to be already numeric (continuous or discrete numeric not needing C() for main effect)
+ # treatment_patsy_term remains treatment_col_name (default)
+ logger.info(f"Treating '{treatment_col_name}' as numeric for Patsy formula.")
+
+ rhs_terms.append(treatment_patsy_term)
+
+ # 2. Covariate Terms
+ for cov_col_name in processed_covariate_col_names:
+ if cov_col_name == treatment_col_name: # Should not happen if covariates list is clean
+ continue
+ # Assume covariates are already numeric/dummy. If one was object/category in df_analysis (unlikely), C() it.
+ if df_analysis[cov_col_name].dtype.name in ['object', 'category']:
+ rhs_terms.append(f"C({cov_col_name})")
+ else:
+ rhs_terms.append(cov_col_name)
+
+ # 3. Interaction Term (Simplified: interaction_variable_candidate_orig_name must map to a single column in df_analysis)
+ actual_interaction_term_added_to_formula = None
+ if interaction_term_suggested and interaction_variable_candidate_orig_name:
+ processed_interaction_col_name = None
+ interaction_var_info = column_mappings.get(interaction_variable_candidate_orig_name, {})
+
+ if interaction_var_info.get('transformed_as') == 'one_hot_encoded':
+ logger.warning(f"Interaction with one-hot encoded variable '{interaction_variable_candidate_orig_name}' is complex. Currently skipping this interaction for Linear Regression.")
+ elif interaction_var_info.get('new_column_name') and interaction_var_info['new_column_name'] in df_analysis.columns:
+ processed_interaction_col_name = interaction_var_info['new_column_name']
+ elif interaction_variable_candidate_orig_name in df_analysis.columns: # Was not in mappings, or mapping didn't change name (e.g. numeric)
+ processed_interaction_col_name = interaction_variable_candidate_orig_name
+
+ if processed_interaction_col_name:
+ interaction_var_patsy_term = processed_interaction_col_name
+ # If the processed interaction column itself is categorical (e.g. label encoded binary)
+ if df_analysis[processed_interaction_col_name].dtype.name in ['object', 'category', 'bool'] or \
+ interaction_var_info.get('original_dtype') in ['bool', 'category']:
+ interaction_var_patsy_term = f"C({processed_interaction_col_name})"
+
+ actual_interaction_term_added_to_formula = f"{treatment_patsy_term}:{interaction_var_patsy_term}"
+ rhs_terms.append(actual_interaction_term_added_to_formula)
+ logger.info(f"Adding interaction term to formula: {actual_interaction_term_added_to_formula}")
+ elif interaction_variable_candidate_orig_name: # Log if it was suggested but couldn't be mapped/found
+ logger.warning(f"Could not resolve interaction variable candidate '{interaction_variable_candidate_orig_name}' to a single usable column in processed data. Skipping interaction term.")
+
+ # Build the formula string for reporting and fitting
+ if not rhs_terms: # Should always have at least treatment
+ formula = f"{outcome_col_name} ~ 1"
+ else:
+ formula = f"{outcome_col_name} ~ {' + '.join(rhs_terms)}"
+ logger.info(f"Using formula for Linear Regression: {formula}")
+
+ try:
+ model = smf.ols(formula=formula, data=df_analysis)
+ results = model.fit()
+ logger.info("OLS model fitted successfully.")
+ logger.info(results.summary()) # Changed to debug level for less verbose default logging
+
+ # --- Result Extraction: LLM attempt first, then Regex fallback ---
+ effect_estimates_by_level = {}
+ all_params_extracted = False # Default to False
+ llm_extraction_successful = False
+
+ # Attempt LLM-based extraction if llm client and query are available
+ llm = get_llm_client()
+ if llm and query_str:
+ logger.info(f"Attempting LLM-based result extraction (informed by query: '{query_str[:50]}...').")
+ try:
+ param_names_list = results.params.index.tolist()
+ param_estimates_list = results.params.tolist()
+ param_p_values_list = results.pvalues.tolist()
+ param_std_errs_list = results.bse.tolist()
+
+ conf_int_df = results.conf_int(alpha=0.05)
+ param_conf_ints_low_list = []
+ param_conf_ints_high_list = []
+
+ if not conf_int_df.empty and len(conf_int_df.columns) == 2:
+ aligned_conf_int_df = conf_int_df.reindex(results.params.index)
+ param_conf_ints_low_list = aligned_conf_int_df.iloc[:, 0].fillna(float('nan')).tolist()
+ param_conf_ints_high_list = aligned_conf_int_df.iloc[:, 1].fillna(float('nan')).tolist()
+ else:
+ nan_list_ci = [float('nan')] * len(param_names_list)
+ param_conf_ints_low_list = nan_list_ci
+ param_conf_ints_high_list = nan_list_ci
+
+ # Placeholder for the new prompt template tailored for this extraction task
+ # MOVED TO causalscientist/auto_causal/prompts/regression_prompts.py
+
+ is_multilevel_case_for_prompt = bool(treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded)
+ reference_level_for_prompt_str = str(treatment_reference_level) if is_multilevel_case_for_prompt else "N/A"
+
+ indexed_param_names_for_prompt = [f"{idx}: '{name}'" for idx, name in enumerate(param_names_list)]
+ indexed_param_names_str_for_prompt = "\n".join(indexed_param_names_for_prompt)
+
+ prompt_text_for_identification = STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE.format(
+ user_query=query_str,
+ treatment_patsy_term=treatment_patsy_term,
+ treatment_col_name=treatment_col_name,
+ is_multilevel_case=is_multilevel_case_for_prompt,
+ reference_level_for_prompt=reference_level_for_prompt_str,
+ indexed_param_names_str=indexed_param_names_str_for_prompt, # Pass the indexed list as a string
+ llm_response_schema_json=json.dumps(LLMIdentifiedRelevantParams.model_json_schema(), indent=2)
+ )
+
+ llm_identification_response = _call_llm_for_var(llm, prompt_text_for_identification, LLMIdentifiedRelevantParams)
+
+ if llm_identification_response and llm_identification_response.identified_params:
+ logger.info("LLM identified relevant parameters. Proceeding with programmatic extraction.")
+ for item in llm_identification_response.identified_params:
+ param_idx = item.param_index
+ # Validate index against actual list length
+ if 0 <= param_idx < len(results.params.index):
+ actual_param_name = results.params.index[param_idx]
+ # Sanity check if LLM returned name matches actual name at index
+ if item.param_name != actual_param_name:
+ logger.warning(f"LLM returned param_name '{item.param_name}' but name at index {param_idx} is '{actual_param_name}'. Using actual name from results.")
+
+ current_effect_stats = {
+ 'estimate': results.params.iloc[param_idx],
+ 'p_value': results.pvalues.iloc[param_idx],
+ 'conf_int': results.conf_int(alpha=0.05).iloc[param_idx].tolist(),
+ 'std_err': results.bse.iloc[param_idx]
+ }
+
+ key_for_effect_dict = 'treatment_effect' # Default for single/binary
+ if is_multilevel_case_for_prompt: # If it was a multi-level case
+ match = re.search(r'\[T\.([^]]+)]', actual_param_name) # Use actual_param_name
+ if match:
+ level = match.group(1)
+ if level != reference_level_for_prompt_str: # Ensure it's not the ref level itself
+ key_for_effect_dict = level
+ else:
+ logger.warning(f"Could not parse level from LLM-identified param: {actual_param_name}. Storing with raw name.")
+ key_for_effect_dict = actual_param_name # Fallback key
+
+ effect_estimates_by_level[key_for_effect_dict] = current_effect_stats
+ else:
+ logger.warning(f"LLM returned an invalid parameter index: {param_idx}. Skipping.")
+
+ if effect_estimates_by_level: # If any effects were successfully processed
+ all_params_extracted = llm_identification_response.all_parameters_successfully_identified
+ llm_extraction_successful = True
+ logger.info(f"Successfully processed LLM-identified parameters. all_parameters_successfully_identified={all_params_extracted}")
+ print(f"effect_estimates_by_level: {effect_estimates_by_level}")
+ else:
+ logger.warning("LLM identified parameters, but none could be processed into effects_estimates_by_level. Falling back to regex.")
+ else:
+ logger.warning("LLM parameter identification did not yield usable parameters. Falling back to regex.")
+
+ except Exception as e_llm:
+ logger.warning(f"LLM-based result extraction failed: {e_llm}. Falling back to regex.", exc_info=True)
+
+
+ # --- End of Existing Regex Logic Block ---
+
+ # Primary effect_estimate for simple reporting (e.g. first level or the only one)
+ # For multi-level, this is ambiguous. For now, let's report None or the first one.
+ # The full details are in effect_estimates_by_level.
+ main_effect_estimate = None
+ main_p_value = None
+ main_conf_int = [None, None] # Default for single or if no effects
+ main_std_err = None
+
+ if effect_estimates_by_level:
+ if 'treatment_effect' in effect_estimates_by_level: # Single effect case
+ single_effect_data = effect_estimates_by_level['treatment_effect']
+ main_effect_estimate = single_effect_data['estimate']
+ main_p_value = single_effect_data['p_value']
+ main_conf_int = single_effect_data['conf_int']
+ main_std_err = single_effect_data['std_err']
+ else: # Multi-level case
+ logger.info("Multi-level treatment effects extracted. Populating dicts for main estimate fields.")
+ effect_estimate_dict = {}
+ p_value_dict = {}
+ conf_int_dict = {}
+ std_err_dict = {}
+ for level, stats in effect_estimates_by_level.items():
+ effect_estimate_dict[level] = stats.get('estimate')
+ p_value_dict[level] = stats.get('p_value')
+ conf_int_dict[level] = stats.get('conf_int') # This is already a list [low, high]
+ std_err_dict[level] = stats.get('std_err')
+
+ main_effect_estimate = effect_estimate_dict
+ main_p_value = p_value_dict
+ main_conf_int = conf_int_dict
+ main_std_err = std_err_dict
+
+ interpretation_details = {}
+ if actual_interaction_term_added_to_formula and actual_interaction_term_added_to_formula in results.params.index:
+ interpretation_details['interaction_term_coefficient'] = results.params[actual_interaction_term_added_to_formula]
+ interpretation_details['interaction_term_p_value'] = results.pvalues[actual_interaction_term_added_to_formula]
+ logger.info(f"Interaction term '{actual_interaction_term_added_to_formula}' coeff: {interpretation_details['interaction_term_coefficient']}")
+
+ diag_results = {}
+ interpretation = "Interpretation not available."
+
+ output_dict = {
+ 'effect_estimate': main_effect_estimate,
+ 'p_value': main_p_value,
+ 'confidence_interval': main_conf_int,
+ 'standard_error': main_std_err,
+ 'estimated_effects_by_level': effect_estimates_by_level if (treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded and effect_estimates_by_level) else None,
+ 'reference_level_used': treatment_reference_level if (treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded) else None,
+ 'formula': formula,
+ 'model_summary_text': results.summary().as_text(), # Store as text for easier serialization
+ 'diagnostics': diag_results,
+ 'interpretation_details': interpretation_details, # Added interaction details
+ 'interpretation': interpretation,
+ 'method_used': 'Linear Regression (OLS)'
+ }
+ if not all_params_extracted:
+ output_dict['warnings'] = ["Could not reliably extract all requested parameters from model results. Please check model_summary_text."]
+ return output_dict
+
+ except Exception as e:
+ logger.error(f"Linear Regression failed: {e}")
+ raise # Re-raise the exception after logging
\ No newline at end of file
diff --git a/auto_causal/methods/linear_regression/llm_assist.py b/auto_causal/methods/linear_regression/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..25639db9e2114a369b6222c7b769001f9ef0afee
--- /dev/null
+++ b/auto_causal/methods/linear_regression/llm_assist.py
@@ -0,0 +1,146 @@
+"""
+LLM assistance functions for Linear Regression analysis.
+"""
+
+from typing import List, Dict, Any, Optional
+import logging
+
+# Imported for type hinting
+from langchain.chat_models.base import BaseChatModel
+from statsmodels.regression.linear_model import RegressionResultsWrapper
+
+# Import shared LLM helpers
+from auto_causal.utils.llm_helpers import call_llm_with_json_output
+
+logger = logging.getLogger(__name__)
+
+def suggest_lr_covariates(
+ df_cols: List[str],
+ treatment: str,
+ outcome: str,
+ query: str,
+ llm: Optional[BaseChatModel] = None
+) -> List[str]:
+ """
+ (Placeholder) Use LLM to suggest relevant covariates for linear regression.
+
+ Args:
+ df_cols: List of available column names.
+ treatment: Treatment variable name.
+ outcome: Outcome variable name.
+ query: User's causal query text.
+ llm: Optional LLM model instance.
+
+ Returns:
+ List of suggested covariate names.
+ """
+ logger.info("LLM covariate suggestion for LR is not implemented yet.")
+ if llm:
+ # Placeholder: Call LLM here in future
+ pass
+ return []
+
+def interpret_lr_results(
+ results: RegressionResultsWrapper,
+ diagnostics: Dict[str, Any],
+ treatment_var: str, # Need treatment variable name to extract coefficient
+ llm: Optional[BaseChatModel] = None
+) -> str:
+ """
+ Use LLM to interpret Linear Regression results.
+
+ Args:
+ results: Fitted statsmodels OLS results object.
+ diagnostics: Dictionary of diagnostic test results.
+ treatment_var: Name of the treatment variable.
+ llm: Optional LLM model instance.
+
+ Returns:
+ String containing natural language interpretation.
+ """
+ default_interpretation = "LLM interpretation not available for Linear Regression."
+ if llm is None:
+ logger.info("LLM not provided for LR interpretation.")
+ return default_interpretation
+
+ try:
+ # --- Prepare summary for LLM ---
+ results_summary = {}
+ treatment_val = results.params.get(treatment_var)
+ pval_val = results.pvalues.get(treatment_var)
+
+ if treatment_val is not None:
+ results_summary['Treatment Effect Estimate'] = f"{treatment_val:.3f}"
+ else:
+ logger.warning(f"Treatment variable '{treatment_var}' not found in regression parameters.")
+ results_summary['Treatment Effect Estimate'] = "Not Found"
+
+ if pval_val is not None:
+ results_summary['Treatment P-value'] = f"{pval_val:.3f}"
+ else:
+ logger.warning(f"P-value for treatment variable '{treatment_var}' not found in regression results.")
+ results_summary['Treatment P-value'] = "Not Found"
+
+ try:
+ conf_int = results.conf_int().loc[treatment_var]
+ results_summary['Treatment 95% CI'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
+ except KeyError:
+ logger.warning(f"Confidence interval for treatment variable '{treatment_var}' not found.")
+ results_summary['Treatment 95% CI'] = "Not Found"
+ except Exception as ci_e:
+ logger.warning(f"Could not extract confidence interval for '{treatment_var}': {ci_e}")
+ results_summary['Treatment 95% CI'] = "Error"
+
+ results_summary['R-squared'] = f"{results.rsquared:.3f}"
+ results_summary['Adj. R-squared'] = f"{results.rsquared_adj:.3f}"
+
+ diag_summary = {}
+ if diagnostics.get("status") == "Success":
+ diag_details = diagnostics.get("details", {})
+ # Format p-values only if they are numbers
+ jb_p = diag_details.get('residuals_normality_jb_p_value')
+ bp_p = diag_details.get('homoscedasticity_bp_lm_p_value')
+ diag_summary['Residuals Normality (Jarque-Bera P-value)'] = f"{jb_p:.3f}" if isinstance(jb_p, (int, float)) else str(jb_p)
+ diag_summary['Homoscedasticity (Breusch-Pagan P-value)'] = f"{bp_p:.3f}" if isinstance(bp_p, (int, float)) else str(bp_p)
+ diag_summary['Homoscedasticity Status'] = diag_details.get('homoscedasticity_status', 'N/A')
+ diag_summary['Residuals Normality Status'] = diag_details.get('residuals_normality_status', 'N/A')
+ else:
+ diag_summary['Status'] = diagnostics.get("status", "Unknown")
+ if "error" in diagnostics:
+ diag_summary['Error'] = diagnostics["error"]
+
+ # --- Construct Prompt ---
+ prompt = f"""
+ You are assisting with interpreting Linear Regression (OLS) results for causal inference.
+
+ Model Results Summary:
+ {results_summary}
+
+ Model Diagnostics Summary:
+ {diag_summary}
+
+ Explain these results in 2-4 concise sentences. Focus on:
+ 1. The estimated causal effect of the treatment variable '{treatment_var}' (magnitude, direction, statistical significance based on p-value < 0.05).
+ 2. Overall model fit (using R-squared as a rough guide).
+ 3. Key diagnostic findings (specifically, mention if residuals are non-normal or if heteroscedasticity is detected, as these violate OLS assumptions and can affect inference).
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "interpretation": ""
+ }}
+ """
+
+ # --- Call LLM ---
+ response = call_llm_with_json_output(llm, prompt)
+
+ # --- Process Response ---
+ if response and isinstance(response, dict) and \
+ "interpretation" in response and isinstance(response["interpretation"], str):
+ return response["interpretation"]
+ else:
+ logger.warning(f"Failed to get valid interpretation from LLM. Response: {response}")
+ return default_interpretation
+
+ except Exception as e:
+ logger.error(f"Error during LLM interpretation for LR: {e}")
+ return f"Error generating interpretation: {e}"
diff --git a/auto_causal/methods/propensity_score/__init__.py b/auto_causal/methods/propensity_score/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2245ac37636e8987b8c7537ec47a2121c6f6cd
--- /dev/null
+++ b/auto_causal/methods/propensity_score/__init__.py
@@ -0,0 +1,13 @@
+from .base import estimate_propensity_scores
+from .matching import estimate_effect as estimate_matching_effect
+from .weighting import estimate_effect as estimate_weighting_effect
+from .diagnostics import assess_balance, plot_overlap, plot_balance
+
+__all__ = [
+ "estimate_propensity_scores",
+ "estimate_matching_effect",
+ "estimate_weighting_effect",
+ "assess_balance",
+ "plot_overlap",
+ "plot_balance"
+]
\ No newline at end of file
diff --git a/auto_causal/methods/propensity_score/base.py b/auto_causal/methods/propensity_score/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..be2fa15bba02a9e9eaaff880fbb507d657449a1a
--- /dev/null
+++ b/auto_causal/methods/propensity_score/base.py
@@ -0,0 +1,80 @@
+# Base functionality for Propensity Score methods
+import pandas as pd
+import numpy as np
+from sklearn.linear_model import LogisticRegression
+from sklearn.preprocessing import StandardScaler
+from typing import List, Optional, Dict, Any
+
+# Placeholder for LLM interaction to select model type
+def select_propensity_model(df: pd.DataFrame, treatment: str, covariates: List[str],
+ query: Optional[str] = None) -> str:
+ '''Selects the appropriate propensity score model type (e.g., logistic, GBM).
+
+ Placeholder: Currently defaults to Logistic Regression.
+ '''
+ # TODO: Implement LLM call or heuristic to select model based on data characteristics
+ return "logistic"
+
+def estimate_propensity_scores(df: pd.DataFrame, treatment: str,
+ covariates: List[str], model_type: str = 'logistic',
+ **kwargs) -> np.ndarray:
+ '''Estimate propensity scores using a specified model.
+
+ Args:
+ df: DataFrame containing the data
+ treatment: Name of the treatment variable
+ covariates: List of covariate variable names
+ model_type: Type of model to use ('logistic' supported for now)
+ **kwargs: Additional arguments for the model
+
+ Returns:
+ Array of propensity scores
+ '''
+
+ X = df[covariates]
+ y = df[treatment]
+
+ # Standardize covariates for logistic regression
+ scaler = StandardScaler()
+ X_scaled = scaler.fit_transform(X)
+
+ if model_type.lower() == 'logistic':
+ # Fit logistic regression
+ model = LogisticRegression(max_iter=kwargs.get('max_iter', 1000),
+ solver=kwargs.get('solver', 'liblinear'), # Use liblinear for L1/L2
+ C=kwargs.get('C', 1.0),
+ penalty=kwargs.get('penalty', 'l2'))
+ model.fit(X_scaled, y)
+
+ # Predict probabilities
+ propensity_scores = model.predict_proba(X_scaled)[:, 1]
+ # TODO: Add other model types like Gradient Boosting, etc.
+ # elif model_type.lower() == 'gbm':
+ # from sklearn.ensemble import GradientBoostingClassifier
+ # model = GradientBoostingClassifier(...)
+ # model.fit(X, y)
+ # propensity_scores = model.predict_proba(X)[:, 1]
+ else:
+ raise ValueError(f"Unsupported propensity score model type: {model_type}")
+
+ # Clip scores to avoid extremes which can cause issues in weighting/matching
+ propensity_scores = np.clip(propensity_scores, 0.01, 0.99)
+
+ return propensity_scores
+
+# Common formatting function (can be expanded)
+def format_ps_results(effect_estimate: float, effect_se: float,
+ diagnostics: Dict[str, Any], method_details: str,
+ parameters: Dict[str, Any]) -> Dict[str, Any]:
+ '''Standard formatter for PS method results.'''
+ ci_lower = effect_estimate - 1.96 * effect_se
+ ci_upper = effect_estimate + 1.96 * effect_se
+ return {
+ "effect_estimate": float(effect_estimate),
+ "effect_se": float(effect_se),
+ "confidence_interval": [float(ci_lower), float(ci_upper)],
+ "diagnostics": diagnostics,
+ "method_details": method_details,
+ "parameters": parameters
+ # Add p-value if needed (can be calculated from estimate and SE)
+ }
\ No newline at end of file
diff --git a/auto_causal/methods/propensity_score/diagnostics.py b/auto_causal/methods/propensity_score/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2d94cd16b631be20678c47c849599854b69e340
--- /dev/null
+++ b/auto_causal/methods/propensity_score/diagnostics.py
@@ -0,0 +1,74 @@
+# Balance and sensitivity analysis diagnostics for Propensity Score methods
+
+import pandas as pd
+import numpy as np
+from typing import Dict, List, Optional, Any
+
+# Import necessary plotting libraries if visualizations are needed
+# import matplotlib.pyplot as plt
+# import seaborn as sns
+
+# Import utility for standardized differences if needed
+from auto_causal.methods.utils import calculate_standardized_differences
+
+def assess_balance(df_original: pd.DataFrame, df_matched_or_weighted: pd.DataFrame,
+ treatment: str, covariates: List[str],
+ method: str,
+ propensity_scores_original: Optional[np.ndarray] = None,
+ propensity_scores_matched: Optional[np.ndarray] = None,
+ weights: Optional[np.ndarray] = None) -> Dict[str, Any]:
+ '''Assesses covariate balance before and after matching/weighting.
+
+ Placeholder: Returns dummy diagnostic data.
+ '''
+ print(f"Assessing balance for {method}...")
+ # TODO: Implement actual balance checking using standardized differences,
+ # variance ratios, KS tests, etc.
+ # Example using standardized differences (needs calculate_standardized_differences):
+ # std_diff_before = calculate_standardized_differences(df_original, treatment, covariates)
+ # std_diff_after = calculate_standardized_differences(df_matched_or_weighted, treatment, covariates, weights=weights)
+
+ dummy_balance_metric = {cov: np.random.rand() * 0.1 for cov in covariates} # Simulate good balance
+
+ return {
+ "balance_metrics": dummy_balance_metric,
+ "balance_achieved": True, # Placeholder
+ "problematic_covariates": [], # Placeholder
+ # Add plots or paths to plots if generated
+ "plots": {
+ "balance_plot": "balance_plot.png",
+ "overlap_plot": "overlap_plot.png"
+ }
+ }
+
+def assess_weight_distribution(weights: np.ndarray, treatment_indicator: pd.Series) -> Dict[str, Any]:
+ '''Assesses the distribution of IPW weights.
+
+ Placeholder: Returns dummy diagnostic data.
+ '''
+ print("Assessing weight distribution...")
+ # TODO: Implement checks for extreme weights, effective sample size, etc.
+ return {
+ "min_weight": float(np.min(weights)),
+ "max_weight": float(np.max(weights)),
+ "mean_weight": float(np.mean(weights)),
+ "std_dev_weight": float(np.std(weights)),
+ "effective_sample_size": len(weights) / (1 + np.std(weights)**2 / np.mean(weights)**2), # Kish's ESS approx
+ "potential_issues": np.max(weights) > 20 # Example check
+ }
+
+def plot_overlap(df: pd.DataFrame, treatment: str, propensity_scores: np.ndarray, save_path: str = 'overlap_plot.png'):
+ '''Generates plot showing propensity score overlap.
+ Placeholder: Does nothing.
+ '''
+ print(f"Generating overlap plot (placeholder) -> {save_path}")
+ # TODO: Implement actual plotting (e.g., using seaborn histplot or kdeplot)
+ pass
+
+def plot_balance(balance_metrics_before: Dict[str, float], balance_metrics_after: Dict[str, float], save_path: str = 'balance_plot.png'):
+ '''Generates plot showing covariate balance before/after.
+ Placeholder: Does nothing.
+ '''
+ print(f"Generating balance plot (placeholder) -> {save_path}")
+ # TODO: Implement actual plotting (e.g., Love plot)
+ pass
\ No newline at end of file
diff --git a/auto_causal/methods/propensity_score/llm_assist.py b/auto_causal/methods/propensity_score/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe1a3e7f5ba1122dada7ff09603b071892012f7a
--- /dev/null
+++ b/auto_causal/methods/propensity_score/llm_assist.py
@@ -0,0 +1,45 @@
+# LLM Integration points for Propensity Score methods
+import pandas as pd
+from typing import List, Optional, Dict, Any
+
+def determine_optimal_caliper(df: pd.DataFrame, treatment: str,
+ covariates: List[str],
+ query: Optional[str] = None) -> float:
+ '''Determines optimal caliper for PSM using data or LLM.
+
+ Placeholder: Returns a default value.
+ '''
+ # TODO: Implement data-driven (e.g., based on PS distribution) or LLM-assisted caliper selection.
+ # Common rule of thumb is 0.2 * std dev of logit(PS), but that requires calculating PS first.
+ return 0.2
+
+def determine_optimal_weight_type(df: pd.DataFrame, treatment: str,
+ query: Optional[str] = None) -> str:
+ '''Determines the optimal type of IPW weights (ATE, ATT, etc.).
+
+ Placeholder: Defaults to ATE.
+ '''
+ # TODO: Implement LLM or rule-based selection.
+ return "ATE"
+
+def determine_optimal_trim_threshold(df: pd.DataFrame, treatment: str,
+ propensity_scores: Optional[pd.Series] = None,
+ query: Optional[str] = None) -> Optional[float]:
+ '''Determines optimal threshold for trimming extreme propensity scores.
+
+ Placeholder: Defaults to no trimming (None).
+ '''
+ # TODO: Implement data-driven or LLM-assisted threshold selection (e.g., based on score distribution).
+ return None # Corresponds to no trimming by default
+
+# Placeholder for calling LLM to get parameters (can use the one in utils if general enough)
+def get_llm_parameters(df: pd.DataFrame, query: str, method: str) -> Dict[str, Any]:
+ '''Placeholder to get parameters via LLM based on dataset and query.'''
+ # In reality, call something like analyze_dataset_for_method from utils.llm_helpers
+ print(f"Simulating LLM call to get parameters for {method}...")
+ if method == "PS.Matching":
+ return {"parameters": {"caliper": 0.15}, "validation": {"check_balance": True}}
+ elif method == "PS.Weighting":
+ return {"parameters": {"weight_type": "ATE", "trim_threshold": 0.05}, "validation": {"check_weights": True}}
+ else:
+ return {"parameters": {}, "validation": {}}
\ No newline at end of file
diff --git a/auto_causal/methods/propensity_score/matching.py b/auto_causal/methods/propensity_score/matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1329e78d03a9fc437d5cca82a49b1d283f344cd
--- /dev/null
+++ b/auto_causal/methods/propensity_score/matching.py
@@ -0,0 +1,341 @@
+# Propensity Score Matching Implementation
+import pandas as pd
+import numpy as np
+from sklearn.neighbors import NearestNeighbors
+import statsmodels.api as sm # For bias adjustment regression
+import logging # For logging fallback
+from typing import Dict, List, Optional, Any
+
+# Import DoWhy
+from dowhy import CausalModel
+
+from .base import estimate_propensity_scores, format_ps_results, select_propensity_model
+from .diagnostics import assess_balance #, plot_overlap, plot_balance # Import diagnostic functions
+# Remove determine_optimal_caliper, it will be replaced by a heuristic
+from .llm_assist import get_llm_parameters # Import LLM helpers
+
+logger = logging.getLogger(__name__)
+
+def _calculate_logit(pscore):
+ """Calculate logit of propensity score, clipping to avoid inf."""
+ # Clip pscore to prevent log(0) or log(1) issues which lead to inf
+ epsilon = 1e-7
+ pscore_clipped = np.clip(pscore, epsilon, 1 - epsilon)
+ return np.log(pscore_clipped / (1 - pscore_clipped))
+
+def _perform_matching_and_get_att(
+ df_sample: pd.DataFrame,
+ treatment: str,
+ outcome: str,
+ covariates: List[str],
+ propensity_model_type: str,
+ n_neighbors: int,
+ caliper: float,
+ perform_bias_adjustment: bool,
+ **kwargs
+) -> float:
+ """
+ Helper to perform Custom KNN PSM and calculate ATT, potentially with bias adjustment.
+ Returns the ATT estimate.
+ """
+ df_ps = df_sample.copy()
+ try:
+ propensity_scores = estimate_propensity_scores(
+ df_ps, treatment, covariates, model_type=propensity_model_type, **kwargs
+ )
+ except Exception as e:
+ logger.warning(f"Propensity score estimation failed in helper: {e}")
+ return np.nan # Cannot proceed without propensity scores
+
+ df_ps['propensity_score'] = propensity_scores
+
+ treated = df_ps[df_ps[treatment] == 1]
+ control = df_ps[df_ps[treatment] == 0]
+
+ if treated.empty or control.empty:
+ return np.nan
+
+ nn = NearestNeighbors(n_neighbors=n_neighbors, radius=caliper if caliper is not None else np.inf, metric='minkowski', p=2)
+ try:
+ # Ensure control PS are valid before fitting
+ control_ps_values = control[['propensity_score']].values
+ if np.isnan(control_ps_values).any():
+ logger.warning("NaN values found in control propensity scores before NN fitting.")
+ return np.nan
+ nn.fit(control_ps_values)
+
+ # Ensure treated PS are valid before querying
+ treated_ps_values = treated[['propensity_score']].values
+ if np.isnan(treated_ps_values).any():
+ logger.warning("NaN values found in treated propensity scores before NN query.")
+ return np.nan
+ distances, indices = nn.kneighbors(treated_ps_values)
+
+ except ValueError as e:
+ # Handles case where control group might be too small or have NaN PS scores
+ logger.warning(f"NearestNeighbors fitting/query failed: {e}")
+ return np.nan
+
+ matched_outcomes_treated = []
+ matched_outcomes_control_means = []
+ propensity_diffs = []
+
+ for i in range(len(treated)):
+ treated_unit = treated.iloc[[i]]
+ valid_neighbors_mask = distances[i] <= (caliper if caliper is not None else np.inf)
+ valid_neighbors_idx = indices[i][valid_neighbors_mask]
+
+ if len(valid_neighbors_idx) > 0:
+ matched_controls_for_this_treated = control.iloc[valid_neighbors_idx]
+ if matched_controls_for_this_treated.empty:
+ continue # Should not happen with valid_neighbors_idx check, but safety
+
+ matched_outcomes_treated.append(treated_unit[outcome].values[0])
+ matched_outcomes_control_means.append(matched_controls_for_this_treated[outcome].mean())
+
+ if perform_bias_adjustment:
+ # Ensure PS scores are valid before calculating difference
+ treated_ps = treated_unit['propensity_score'].values[0]
+ control_ps_mean = matched_controls_for_this_treated['propensity_score'].mean()
+ if np.isnan(treated_ps) or np.isnan(control_ps_mean):
+ logger.warning("NaN propensity score encountered during bias adjustment calculation.")
+ # Cannot perform bias adjustment for this unit, potentially skip or handle
+ # For now, let's skip adding to propensity_diffs if NaN found
+ continue
+ propensity_diff = treated_ps - control_ps_mean
+ propensity_diffs.append(propensity_diff)
+
+ if not matched_outcomes_treated:
+ return np.nan
+
+ raw_att_components = np.array(matched_outcomes_treated) - np.array(matched_outcomes_control_means)
+
+ if perform_bias_adjustment:
+ # Ensure lengths match *after* potential skips due to NaNs
+ if not propensity_diffs or len(raw_att_components) != len(propensity_diffs):
+ logger.warning("Bias adjustment skipped due to inconsistent data lengths after NaN checks.")
+ return np.mean(raw_att_components)
+
+ try:
+ X_bias_adj = sm.add_constant(np.array(propensity_diffs))
+ y_bias_adj = raw_att_components
+ # Add check for NaNs/Infs in inputs to OLS
+ if np.isnan(X_bias_adj).any() or np.isnan(y_bias_adj).any() or \
+ np.isinf(X_bias_adj).any() or np.isinf(y_bias_adj).any():
+ logger.warning("NaN/Inf values detected in OLS inputs for bias adjustment. Falling back.")
+ return np.mean(raw_att_components)
+
+ bias_model = sm.OLS(y_bias_adj, X_bias_adj).fit()
+ bias_adjusted_att = bias_model.params[0]
+ return bias_adjusted_att
+ except Exception as e:
+ logger.warning(f"OLS for bias adjustment failed: {e}. Falling back to raw ATT.")
+ return np.mean(raw_att_components)
+ else:
+ return np.mean(raw_att_components)
+
+def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
+ covariates: List[str], **kwargs) -> Dict[str, Any]:
+ '''Estimate ATT using Propensity Score Matching.
+ Tries DoWhy's PSM first, falls back to custom implementation if DoWhy fails.
+ Uses bootstrap SE based on the custom implementation regardless.
+ '''
+ query = kwargs.get('query')
+ n_bootstraps = kwargs.get('n_bootstraps', 100)
+
+ # --- Parameter Setup (as before) ---
+ llm_params = get_llm_parameters(df, query, "PS.Matching")
+ llm_suggested_params = llm_params.get("parameters", {})
+
+ caliper = kwargs.get('caliper', llm_suggested_params.get('caliper'))
+ temp_propensity_scores_for_caliper = None
+ try:
+ temp_propensity_scores_for_caliper = estimate_propensity_scores(
+ df, treatment, covariates,
+ model_type=llm_suggested_params.get('propensity_model_type', 'logistic'),
+ **kwargs
+ )
+ if caliper is None and temp_propensity_scores_for_caliper is not None:
+ logit_ps = _calculate_logit(temp_propensity_scores_for_caliper)
+ if not np.isnan(logit_ps).all(): # Check if logit calculation was successful
+ caliper = 0.2 * np.nanstd(logit_ps) # Use nanstd for robustness
+ else:
+ logger.warning("Logit of propensity scores resulted in NaNs, cannot calculate heuristic caliper.")
+ caliper = None
+ elif caliper is None:
+ logger.warning("Could not estimate propensity scores for caliper heuristic.")
+ caliper = None
+
+ except Exception as e:
+ logger.warning(f"Failed to estimate initial propensity scores for caliper heuristic: {e}. Caliper set to None.")
+ caliper = None # Proceed without caliper if heuristic fails
+
+ n_neighbors = kwargs.get('n_neighbors', llm_suggested_params.get('n_neighbors', 1))
+ propensity_model_type = kwargs.get('propensity_model_type',
+ llm_suggested_params.get('propensity_model_type',
+ select_propensity_model(df, treatment, covariates, query)))
+
+ # --- Attempt DoWhy PSM for Point Estimate ---
+ att_estimate = np.nan
+ method_used_for_att = "Fallback Custom PSM"
+ dowhy_model = None
+ identified_estimand = None
+
+ try:
+ logger.info("Attempting estimation using DoWhy Propensity Score Matching...")
+ dowhy_model = CausalModel(
+ data=df,
+ treatment=treatment,
+ outcome=outcome,
+ common_causes=covariates,
+ estimand_type='nonparametric-ate' # Provide list of names directly
+ )
+ # Identify estimand (optional step, but good practice)
+ identified_estimand = dowhy_model.identify_effect(proceed_when_unidentifiable=True)
+ logger.info(f"DoWhy identified estimand: {identified_estimand}")
+
+ # Estimate effect using DoWhy's PSM
+ estimate = dowhy_model.estimate_effect(
+ identified_estimand,
+ method_name="backdoor.propensity_score_matching",
+ target_units="att",
+ method_params={}
+ )
+ att_estimate = estimate.value
+ method_used_for_att = "DoWhy PSM"
+ logger.info(f"DoWhy PSM successful. ATT Estimate: {att_estimate}")
+
+ except Exception as e:
+ logger.warning(f"DoWhy PSM failed: {e}. Falling back to custom PSM implementation.")
+ # Fallback is triggered implicitly if att_estimate remains NaN
+
+ # --- Fallback or if DoWhy failed ---
+ if np.isnan(att_estimate):
+ logger.info("Calculating ATT estimate using fallback custom PSM...")
+ att_estimate = _perform_matching_and_get_att(
+ df, treatment, outcome, covariates,
+ propensity_model_type, n_neighbors, caliper,
+ perform_bias_adjustment=True, **kwargs # Bias adjust the fallback
+ )
+ method_used_for_att = "Fallback Custom PSM" # Confirm it's fallback
+ if np.isnan(att_estimate):
+ raise ValueError("Fallback custom PSM estimation also failed. Cannot proceed.")
+ logger.info(f"Fallback Custom PSM successful. ATT Estimate: {att_estimate}")
+
+ # --- Bootstrap SE (using custom helper for consistency) ---
+ logger.info(f"Calculating Bootstrap SE using custom helper ({n_bootstraps} iterations)...")
+ bootstrap_atts = []
+ for i in range(n_bootstraps):
+ try:
+ # Ensure bootstrap samples are drawn correctly
+ df_boot = df.sample(n=len(df), replace=True, random_state=np.random.randint(1000000) + i)
+ # Bias adjustment in bootstrap can be slow, optionally disable it
+ boot_att = _perform_matching_and_get_att(
+ df_boot, treatment, outcome, covariates,
+ propensity_model_type, n_neighbors, caliper,
+ perform_bias_adjustment=False, **kwargs # Set bias adjustment to False for speed in bootstrap
+ )
+ if not np.isnan(boot_att):
+ bootstrap_atts.append(boot_att)
+ except Exception as boot_e:
+ logger.warning(f"Bootstrap iteration {i+1} failed: {boot_e}")
+ continue # Skip failed bootstrap iteration
+
+ att_se = np.nanstd(bootstrap_atts) if bootstrap_atts else np.nan # Use nanstd
+ actual_bootstrap_iterations = len(bootstrap_atts)
+ logger.info(f"Bootstrap SE calculated: {att_se} from {actual_bootstrap_iterations} successful iterations.")
+
+ # --- Diagnostics (using custom matching logic for consistency) ---
+ logger.info("Performing diagnostic checks using custom matching logic...")
+ diagnostics = {"error": "Diagnostics failed to run."}
+ propensity_scores_orig = temp_propensity_scores_for_caliper # Reuse if available and not None
+
+ if propensity_scores_orig is None:
+ try:
+ propensity_scores_orig = estimate_propensity_scores(
+ df, treatment, covariates, model_type=propensity_model_type, **kwargs
+ )
+ except Exception as e:
+ logger.error(f"Failed to estimate propensity scores for diagnostics: {e}")
+ propensity_scores_orig = None
+
+ if propensity_scores_orig is not None and not np.isnan(propensity_scores_orig).all():
+ df_ps_orig = df.copy()
+ df_ps_orig['propensity_score'] = propensity_scores_orig
+ treated_orig = df_ps_orig[df_ps_orig[treatment] == 1]
+ control_orig = df_ps_orig[df_ps_orig[treatment] == 0]
+ unmatched_treated_count = 0
+
+ # Drop rows with NaN propensity scores before diagnostics
+ treated_orig = treated_orig.dropna(subset=['propensity_score'])
+ control_orig = control_orig.dropna(subset=['propensity_score'])
+
+ if not treated_orig.empty and not control_orig.empty:
+ try:
+ nn_diag = NearestNeighbors(n_neighbors=n_neighbors, radius=caliper if caliper is not None else np.inf, metric='minkowski', p=2)
+ nn_diag.fit(control_orig[['propensity_score']].values)
+ distances_diag, indices_diag = nn_diag.kneighbors(treated_orig[['propensity_score']].values)
+
+ matched_treated_indices_diag = []
+ matched_control_indices_diag = []
+
+ for i in range(len(treated_orig)):
+ valid_neighbors_mask_diag = distances_diag[i] <= (caliper if caliper is not None else np.inf)
+ valid_neighbors_idx_diag = indices_diag[i][valid_neighbors_mask_diag]
+ if len(valid_neighbors_idx_diag) > 0:
+ # Get original DataFrame indices from control_orig based on iloc indices
+ selected_control_original_indices = control_orig.index[valid_neighbors_idx_diag]
+ matched_treated_indices_diag.extend([treated_orig.index[i]] * len(selected_control_original_indices))
+ matched_control_indices_diag.extend(selected_control_original_indices)
+ else:
+ unmatched_treated_count += 1
+
+ if matched_control_indices_diag:
+ # Use unique indices for creating the diagnostic dataframe
+ unique_matched_control_indices = list(set(matched_control_indices_diag))
+ unique_matched_treated_indices = list(set(matched_treated_indices_diag))
+
+ matched_control_df_diag = df.loc[unique_matched_control_indices]
+ matched_treated_df_for_diag = df.loc[unique_matched_treated_indices]
+ matched_df_diag = pd.concat([matched_treated_df_for_diag, matched_control_df_diag]).drop_duplicates()
+
+ # Retrieve propensity scores for the specific units in matched_df_diag
+ ps_matched_for_diag = propensity_scores_orig.loc[matched_df_diag.index]
+
+ diagnostics = assess_balance(df, matched_df_diag, treatment, covariates,
+ method="PSM",
+ propensity_scores_original=propensity_scores_orig,
+ propensity_scores_matched=ps_matched_for_diag)
+ else:
+ diagnostics = {"message": "No units could be matched for diagnostic assessment."}
+ # If no controls were matched, all treated were unmatched
+ unmatched_treated_count = len(treated_orig) if not treated_orig.empty else 0
+ except Exception as diag_e:
+ logger.error(f"Error during diagnostic matching/balance assessment: {diag_e}")
+ diagnostics = {"error": f"Diagnostics failed: {diag_e}"}
+ else:
+ diagnostics = {"message": "Treatment or control group empty after dropping NaN PS, diagnostics skipped."}
+ unmatched_treated_count = len(treated_orig) if not treated_orig.empty else 0
+
+ # Ensure unmatched count calculation is safe
+ if 'unmatched_treated_count' not in locals():
+ unmatched_treated_count = 0 # Initialize if loop didn't run
+ diagnostics["unmatched_treated_count"] = unmatched_treated_count
+ diagnostics["percent_treated_matched"] = (len(treated_orig) - unmatched_treated_count) / len(treated_orig) * 100 if len(treated_orig) > 0 else 0
+ else:
+ diagnostics = {"error": "Propensity scores could not be estimated for diagnostics."}
+
+ # Add final details to diagnostics
+ diagnostics["att_estimation_method"] = method_used_for_att
+ diagnostics["propensity_score_model"] = propensity_model_type
+ diagnostics["bootstrap_iterations_for_se"] = actual_bootstrap_iterations
+ diagnostics["final_caliper_used"] = caliper
+
+ # --- Format and return results ---
+ logger.info(f"Formatting results. ATT Estimate: {att_estimate}, SE: {att_se}, Method: {method_used_for_att}")
+ return format_ps_results(att_estimate, att_se, diagnostics,
+ method_details=f"PSM ({method_used_for_att})",
+ parameters={"caliper": caliper,
+ "n_neighbors": n_neighbors, # n_neighbors used in fallback/bootstrap/diag
+ "propensity_model": propensity_model_type,
+ "n_bootstraps_config": n_bootstraps})
\ No newline at end of file
diff --git a/auto_causal/methods/propensity_score/weighting.py b/auto_causal/methods/propensity_score/weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..28aba248e27e394756d0c5d32bed3d40373e21d2
--- /dev/null
+++ b/auto_causal/methods/propensity_score/weighting.py
@@ -0,0 +1,124 @@
+# Propensity Score Weighting (IPW) Implementation
+
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from typing import Dict, List, Optional, Any
+
+from .base import estimate_propensity_scores, format_ps_results, select_propensity_model
+from .diagnostics import assess_weight_distribution, plot_overlap, plot_balance # Import diagnostic functions
+from .llm_assist import determine_optimal_weight_type, determine_optimal_trim_threshold, get_llm_parameters # Import LLM helpers
+
+def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
+ covariates: List[str], **kwargs) -> Dict[str, Any]:
+ '''Generic propensity score weighting (IPW) implementation.
+
+ Args:
+ df: Dataset containing causal variables
+ treatment: Name of treatment variable
+ outcome: Name of outcome variable
+ covariates: List of covariate names
+ **kwargs: Method-specific parameters (e.g., weight_type, trim_threshold, query)
+
+ Returns:
+ Dictionary with effect estimate and diagnostics
+ '''
+ query = kwargs.get('query')
+
+ # --- LLM-Assisted Parameter Optimization / Defaults ---
+ llm_params = get_llm_parameters(df, query, "PS.Weighting")
+ llm_suggested_params = llm_params.get("parameters", {})
+
+ # Explicitly check LLM suggestion before falling back to default helper
+ llm_weight_type = llm_suggested_params.get('weight_type')
+ default_weight_type = determine_optimal_weight_type(df, treatment, query) if llm_weight_type is None else llm_weight_type
+ weight_type = kwargs.get('weight_type', default_weight_type)
+
+ # Similar explicit check for trim_threshold
+ llm_trim_thresh = llm_suggested_params.get('trim_threshold')
+ default_trim_thresh = determine_optimal_trim_threshold(df, treatment, query=query) if llm_trim_thresh is None else llm_trim_thresh
+ trim_threshold = kwargs.get('trim_threshold', default_trim_thresh)
+
+ propensity_model_type = kwargs.get('propensity_model_type',
+ llm_suggested_params.get('propensity_model_type',
+ select_propensity_model(df, treatment, covariates, query)))
+ robust_se = kwargs.get('robust_se', True)
+
+ # --- Step 1: Estimate propensity scores ---
+ propensity_scores = estimate_propensity_scores(df, treatment, covariates,
+ model_type=propensity_model_type,
+ **kwargs) # Pass other kwargs like C, penalty etc.
+ df_ps = df.copy()
+ df_ps['propensity_score'] = propensity_scores
+
+ # --- Step 2: Calculate weights ---
+ if weight_type.upper() == 'ATE':
+ weights = np.where(df_ps[treatment] == 1,
+ 1 / df_ps['propensity_score'],
+ 1 / (1 - df_ps['propensity_score']))
+ elif weight_type.upper() == 'ATT':
+ weights = np.where(df_ps[treatment] == 1,
+ 1,
+ df_ps['propensity_score'] / (1 - df_ps['propensity_score']))
+ # TODO: Add other weight types like ATC if needed
+ else:
+ raise ValueError(f"Unsupported weight type: {weight_type}")
+
+ df_ps['ipw'] = weights
+
+ # --- Step 3: Apply trimming if needed ---
+ if trim_threshold is not None and trim_threshold > 0:
+ # Trim based on propensity score percentile
+ min_ps_thresh = np.percentile(propensity_scores, trim_threshold * 100)
+ max_ps_thresh = np.percentile(propensity_scores, (1 - trim_threshold) * 100)
+
+ keep_indices = (df_ps['propensity_score'] >= min_ps_thresh) & (df_ps['propensity_score'] <= max_ps_thresh)
+ df_trimmed = df_ps[keep_indices].copy()
+ print(f"Trimming {len(df_ps) - len(df_trimmed)} units ({trim_threshold*100:.1f}% percentile trim)")
+ if df_trimmed.empty:
+ raise ValueError("All units removed after trimming. Try a smaller trim_threshold.")
+ df_analysis = df_trimmed
+ else:
+ # Trim based on weight percentile (alternative approach)
+ # q_low, q_high = np.percentile(weights, [trim_threshold*100, (1-trim_threshold)*100])
+ # df_ps['ipw'] = np.clip(df_ps['ipw'], q_low, q_high)
+ df_analysis = df_ps.copy()
+ trim_threshold = 0 # Explicitly set for parameters output
+
+ # --- Step 4: Normalize weights (optional but common) ---
+ # Normalize weights to sum to sample size within treated/control groups if ATT
+ if weight_type.upper() == 'ATT':
+ sum_weights_treated = df_analysis.loc[df_analysis[treatment] == 1, 'ipw'].sum()
+ sum_weights_control = df_analysis.loc[df_analysis[treatment] == 0, 'ipw'].sum()
+ n_treated = (df_analysis[treatment] == 1).sum()
+ n_control = (df_analysis[treatment] == 0).sum()
+
+ if sum_weights_treated > 0:
+ df_analysis.loc[df_analysis[treatment] == 1, 'ipw'] *= n_treated / sum_weights_treated
+ if sum_weights_control > 0:
+ df_analysis.loc[df_analysis[treatment] == 0, 'ipw'] *= n_control / sum_weights_control
+ else: # ATE normalization
+ df_analysis['ipw'] *= len(df_analysis) / df_analysis['ipw'].sum()
+
+ # --- Step 5: Estimate weighted treatment effect ---
+ X_treat = sm.add_constant(df_analysis[[treatment]]) # Use only treatment variable for direct effect
+ wls_model = sm.WLS(df_analysis[outcome], X_treat, weights=df_analysis['ipw'])
+ results = wls_model.fit(cov_type='HC1' if robust_se else 'nonrobust')
+
+ effect = results.params[treatment]
+ effect_se = results.bse[treatment]
+
+ # --- Step 6: Validate weight quality / Diagnostics ---
+ diagnostics = assess_weight_distribution(df_analysis['ipw'], df_analysis[treatment])
+ # Could also add balance assessment on the weighted sample
+ # weighted_diagnostics = assess_balance(df, df_analysis, treatment, covariates, method="PSW", weights=df_analysis['ipw'])
+ # diagnostics.update(weighted_diagnostics)
+ diagnostics["propensity_score_model"] = propensity_model_type
+
+ # --- Step 7: Format and return results ---
+ return format_ps_results(effect, effect_se, diagnostics,
+ method_details="PS.Weighting",
+ parameters={"weight_type": weight_type,
+ "trim_threshold": trim_threshold,
+ "propensity_model": propensity_model_type,
+ "robust_se": robust_se})
\ No newline at end of file
diff --git a/auto_causal/methods/regression_discontinuity/__init__.py b/auto_causal/methods/regression_discontinuity/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/auto_causal/methods/regression_discontinuity/diagnostics.py b/auto_causal/methods/regression_discontinuity/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..294ffc37f4b081172199192e42b62e27d11cedac
--- /dev/null
+++ b/auto_causal/methods/regression_discontinuity/diagnostics.py
@@ -0,0 +1,97 @@
+"""
+Diagnostic checks for Regression Discontinuity Design (RDD).
+"""
+
+from typing import Dict, Any, List, Optional
+import pandas as pd
+import numpy as np
+from scipy import stats
+import logging
+
+logger = logging.getLogger(__name__)
+
+def run_rdd_diagnostics(
+ df: pd.DataFrame,
+ outcome: str,
+ running_variable: str,
+ cutoff: float,
+ covariates: Optional[List[str]] = None,
+ bandwidth: Optional[float] = None
+) -> Dict[str, Any]:
+ """
+ Runs diagnostic checks for RDD analysis.
+
+ Currently includes:
+ - Covariate Balance Check (t-tests)
+ Placeholders for:
+ - Density Test (McCrary)
+ - Placebo Cutoff Tests
+ - Bandwidth Sensitivity
+
+ Args:
+ df: Input DataFrame.
+ outcome: Name of the outcome variable.
+ running_variable: Name of the running variable.
+ cutoff: The threshold value.
+ covariates: Optional list of covariate names to check for balance.
+ bandwidth: Optional bandwidth to restrict the analysis. If None, a default is used.
+
+ Returns:
+ Dictionary containing diagnostic results.
+ """
+ diagnostics = {}
+ details = {}
+
+ if bandwidth is None:
+ # Use the same default as estimator for consistency
+ range_rv = df[running_variable].max() - df[running_variable].min()
+ bandwidth = 0.1 * range_rv
+ logger.warning(f"No bandwidth provided for diagnostics, using basic default: {bandwidth:.3f}")
+
+ # --- Filter data within bandwidth ---
+ df_bw = df[(df[running_variable] >= cutoff - bandwidth) & (df[running_variable] <= cutoff + bandwidth)].copy()
+ if df_bw.empty:
+ logger.warning("No data within bandwidth for diagnostics.")
+ return {"status": "Skipped", "reason": "No data in bandwidth", "details": details}
+
+ df_below = df_bw[df_bw[running_variable] < cutoff]
+ df_above = df_bw[df_bw[running_variable] >= cutoff]
+
+ if df_below.empty or df_above.empty:
+ logger.warning("Insufficient data above or below cutoff within bandwidth for diagnostics.")
+ return {"status": "Skipped", "reason": "Insufficient data near cutoff", "details": details}
+
+ # --- Covariate Balance Check ---
+ if covariates:
+ balance_results = {}
+ details['covariate_balance'] = balance_results
+ for cov in covariates:
+ if cov in df_bw.columns:
+ try:
+ # Perform t-test for difference in means
+ t_stat, p_val = stats.ttest_ind(
+ df_below[cov].dropna(),
+ df_above[cov].dropna(),
+ equal_var=False # Welch's t-test
+ )
+ balance_results[cov] = {
+ 't_statistic': t_stat,
+ 'p_value': p_val,
+ 'balanced': "Yes" if p_val > 0.05 else "No (p <= 0.05)"
+ }
+ except Exception as e:
+ logger.warning(f"Could not perform t-test for covariate '{cov}': {e}")
+ balance_results[cov] = {"status": "Test Failed", "error": str(e)}
+ else:
+ balance_results[cov] = {"status": "Column Not Found"}
+ else:
+ details['covariate_balance'] = "No covariates provided to check."
+
+ # --- Placeholders for other common RDD diagnostics ---
+ details['continuity_density_test'] = "Not Implemented (Requires specialized libraries like rdd)"
+ details['placebo_cutoff_test'] = "Not Implemented (Requires re-running estimation)"
+ details['bandwidth_sensitivity'] = "Not Implemented (Requires re-running estimation)"
+ details['visual_inspection'] = "Recommended (Plot outcome vs running variable with fits)"
+
+ return {"status": "Success (Partial Implementation)", "details": details}
+
diff --git a/auto_causal/methods/regression_discontinuity/estimator.py b/auto_causal/methods/regression_discontinuity/estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..623fcd91095c47e1aee3dd5de4ba24b8edc1f1d1
--- /dev/null
+++ b/auto_causal/methods/regression_discontinuity/estimator.py
@@ -0,0 +1,395 @@
+"""
+Regression Discontinuity Design (RDD) Estimator.
+
+Tries to use DoWhy's RDD implementation first, falling back to a basic
+comparison of linear fits around the cutoff if DoWhy fails.
+"""
+
+import pandas as pd
+import statsmodels.api as sm
+from dowhy import CausalModel
+from typing import Dict, Any, List, Optional
+import logging
+from langchain.chat_models.base import BaseChatModel # For type hinting llm
+
+from .diagnostics import run_rdd_diagnostics
+from .llm_assist import interpret_rdd_results
+
+logger = logging.getLogger(__name__)
+
+# Attempt to import specific functions from the evan-magnusson/rdd package
+_rdd_estimator_func_em = None
+_rdd_optimal_bw_func_em = None
+_rdd_em_import_error_message = ""
+try:
+ import rdd
+ from rdd import rdd
+
+ logger.info("Successfully imported 'rdd' and 'optimal_bandwidth' from evan-magnusson/rdd package.")
+except ImportError as e:
+ _rdd_em_import_error_message = f"ImportError for evan-magnusson/rdd: {e}. This package is needed for 'effect_estimate_rdd'."
+ logger.warning(_rdd_em_import_error_message)
+except Exception as e: # Catch other potential errors during import
+ _rdd_em_import_error_message = f"An unexpected error occurred during import from evan-magnusson/rdd: {e}"
+ logger.warning(_rdd_em_import_error_message)
+
+def estimate_effect_dowhy(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], **kwargs) -> Dict[str, Any]:
+ """Estimate RDD effect using DoWhy."""
+ logger.info("Attempting RDD estimation using DoWhy.")
+ if covariates:
+ logger.warning("Covariates provided but may not be used by the DoWhy RDD method_name='rdd'. Support varies.")
+ # For DoWhy RDD, we don't typically specify common causes in the model
+ # constructor in the same way as backdoor. The running variable is handled
+ # via method_params. Covariates might be used by specific underlying estimators
+ # if supported, but the basic RDD identification doesn't use them directly.
+ model = CausalModel(
+ data=df,
+ treatment=treatment,
+ outcome=outcome,
+ # No explicit graph needed for iv.regression_discontinuity method
+ )
+
+ # Identify the effect (DoWhy internally identifies RDD as IV)
+ # Although potentially redundant if method_name implies identification,
+ # the API requires identified_estimand as the first argument.
+ identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
+
+ # Estimate using RDD method
+ # Note: DoWhy's RDD often has limited direct support for covariates.
+ # Bandwidth selection is crucial and often done internally or specified.
+ bandwidth = kwargs.get('bandwidth') # Get user-specified bandwidth if provided
+ if bandwidth is None:
+ # Very basic default bandwidth if none provided - consider better methods
+ range_rv = df[running_variable].max() - df[running_variable].min()
+ bandwidth = 0.1 * range_rv
+ logger.warning(f"No bandwidth specified, using basic default: {bandwidth:.3f}")
+
+ estimate = model.estimate_effect(
+ identified_estimand, # ADD identified_estimand argument
+ method_name="iv.regression_discontinuity",
+ method_params={
+ 'rd_variable_name': running_variable,
+ 'rd_threshold_value': cutoff_value,
+ 'rd_bandwidth': bandwidth,
+ # 'covariates': covariates # Support depends on DoWhy version/estimator
+ },
+ test_significance=True # Ask DoWhy to calculate p-values if possible
+ )
+
+ # Extract results - DoWhy's RDD estimate structure might vary
+ effect = estimate.value
+ # DoWhy's RDD significance testing might be limited/indirect
+ # Try to get p-value if estimate object supports it, else None
+ p_value = getattr(estimate, 'test_significance_pvalue', None)
+ if isinstance(p_value, (list, tuple)):
+ p_value = p_value[0] # Handle cases where it might be wrapped
+
+ # Confidence intervals might not be directly available from this method easily
+ conf_int = getattr(estimate, 'confidence_interval', None)
+ std_err = getattr(estimate, 'standard_error', None)
+
+ return {
+ 'effect_estimate': effect,
+ 'p_value': p_value,
+ 'confidence_interval': conf_int,
+ 'standard_error': std_err,
+ 'method_details': f"DoWhy RDD (Bandwidth: {bandwidth:.3f})",
+ }
+
+def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], **kwargs) -> Dict[str, Any]:
+ """Estimate RDD effect using simple linear regression comparison fallback."""
+ logger.warning("DoWhy RDD failed or not used. Falling back to simple linear regression comparison.")
+ if covariates:
+ logger.warning("Covariates provided but are ignored in the fallback RDD linear regression estimation.")
+
+ bandwidth = kwargs.get('bandwidth')
+ if bandwidth is None:
+ range_rv = df[running_variable].max() - df[running_variable].min()
+ bandwidth = 0.1 * range_rv
+ logger.warning(f"No bandwidth specified for fallback, using basic default: {bandwidth:.3f}")
+
+ # Filter data within bandwidth
+ df_bw = df[(df[running_variable] >= cutoff_value - bandwidth) & (df[running_variable] <= cutoff_value + bandwidth)].copy()
+ if df_bw.empty:
+ raise ValueError("No data within the specified bandwidth.")
+
+ df_bw['above_cutoff'] = (df_bw[running_variable] >= cutoff_value).astype(int)
+
+ # Define predictors for the regression
+ # Interaction term allows different slopes above and below the cutoff
+ df_bw['running_centered'] = df_bw[running_variable] - cutoff_value
+ df_bw['running_x_above'] = df_bw['running_centered'] * df_bw['above_cutoff']
+ predictors = ['above_cutoff', 'running_centered', 'running_x_above']
+
+ # Covariates are NOT included in this basic RDD model
+ # if covariates:
+ # predictors.extend(covariates) # REMOVED as per user request
+
+ required_cols = [outcome] + predictors
+ missing_cols = [col for col in required_cols if col not in df_bw.columns]
+ if missing_cols:
+ raise ValueError(f"Fallback RDD missing columns: {missing_cols}")
+
+ df_analysis = df_bw[required_cols].dropna()
+ if df_analysis.empty:
+ raise ValueError("No data remaining after dropping NaNs for fallback RDD.")
+
+ X = df_analysis[predictors]
+ X = sm.add_constant(X)
+ y = df_analysis[outcome]
+
+ formula = f"{outcome} ~ {' + '.join(predictors)} + const"
+ logger.info(f"Running fallback RDD regression: {formula}")
+
+ model = sm.OLS(y, X)
+ # Use robust standard errors
+ results = model.fit(cov_type='HC1')
+
+ # The coefficient for 'above_cutoff' represents the jump at the cutoff
+ effect = results.params['above_cutoff']
+ p_value = results.pvalues['above_cutoff']
+ conf_int = results.conf_int().loc['above_cutoff'].tolist()
+ std_err = results.bse['above_cutoff']
+
+ return {
+ 'effect_estimate': effect,
+ 'p_value': p_value,
+ 'confidence_interval': conf_int,
+ 'standard_error': std_err,
+ 'method_details': f"Fallback Linear Interaction (Bandwidth: {bandwidth:.3f})",
+ 'formula': formula,
+ 'model_summary': results.summary()
+ }
+
+def effect_estimate_rdd(
+ df: pd.DataFrame,
+ outcome: str,
+ running_variable: str,
+ cutoff_value: float,
+ treatment: Optional[str] = None, # Kept for API consistency, but unused by evan-magnusson/rdd
+ covariates: Optional[List[str]] = None,
+ bandwidth: Optional[float] = None,
+ **kwargs
+) -> Dict[str, Any]:
+ """
+ Estimates RDD effect using the 'evan-magnusson/rdd' package.
+ Uses IK optimal bandwidth selection from the same package by default.
+ """
+ logger.info(f"Attempting RDD estimation using 'evan-magnusson/rdd' for outcome '{outcome}' and running variable '{running_variable}'.")
+
+
+
+ if treatment:
+ logger.info(f"Treatment variable '{treatment}' provided but is not explicitly used by the evan-magnusson/rdd estimation function.")
+ if covariates:
+ logger.warning("Covariates provided but are ignored by this 'evan-magnusson/rdd' implementation.")
+
+ # --- Bandwidth Selection ---
+ final_bandwidth = None
+ bandwidth_selection_method = "unknown"
+
+ if bandwidth is not None and bandwidth > 0:
+ logger.info(f"Using user-specified bandwidth: {bandwidth:.4f}")
+ final_bandwidth = bandwidth
+ bandwidth_selection_method = "user-specified"
+ else:
+ if bandwidth is not None and bandwidth <= 0:
+ logger.warning(f"User-specified bandwidth {bandwidth} is not positive. Attempting IK optimal bandwidth selection.")
+ try:
+ logger.info(f"Attempting IK optimal bandwidth selection using _rdd_optimal_bw_func_em for {outcome} ~ {running_variable} cut at {cutoff_value}.")
+ optimal_bw_val = rdd.optimal_bandwidth(df[outcome], df[running_variable], cut=cutoff_value)
+ if optimal_bw_val is not None and optimal_bw_val > 0:
+ final_bandwidth = optimal_bw_val
+ bandwidth_selection_method = "ik_optimal (evan-magnusson/rdd)"
+ logger.info(f"IK optimal bandwidth from evan-magnusson/rdd: {final_bandwidth:.4f}")
+ else:
+ logger.warning(f"IK optimal bandwidth from evan-magnusson/rdd was None or non-positive: {optimal_bw_val}. Falling back to default.")
+ except Exception as e:
+ logger.warning(f"IK optimal bandwidth selection from evan-magnusson/rdd failed: {e}. Falling back to default.")
+
+ if final_bandwidth is None: # Fallback if user did not specify and IK failed/invalid
+ logger.info("Falling back to default bandwidth (10% of running variable range).")
+ rv_min = df[running_variable].min()
+ rv_max = df[running_variable].max()
+ rv_range = rv_max - rv_min
+ if rv_range > 0:
+ final_bandwidth = 0.1 * rv_range
+ bandwidth_selection_method = "default_10_percent_range"
+ logger.info(f"Using default 10% range bandwidth: {final_bandwidth:.4f}")
+ else:
+ err_msg = "Running variable range is not positive. Cannot determine a default bandwidth for evan-magnusson/rdd."
+ logger.error(err_msg)
+ raise ValueError(err_msg)
+
+ if final_bandwidth is None or final_bandwidth <= 0:
+ raise ValueError(f"Could not determine a valid positive bandwidth for evan-magnusson/rdd. Last method: {bandwidth_selection_method}")
+
+ # --- RDD Estimation ---
+ try:
+ logger.info(f"Running RDD estimation with evan-magnusson/rdd: y='{outcome}', x='{running_variable}', cut={cutoff_value}, bw={final_bandwidth:.4f}")
+ # The evan-magnusson/rdd package's rdd function typically handles dataframes directly
+ # Ensure correct xname for truncated_data
+ data_rdd = rdd.truncated_data(df, running_variable,final_bandwidth, cut=cutoff_value)
+ model = rdd.rdd(
+ data_rdd,
+ xname=running_variable, # Correct: Name of the running variable column
+ yname=outcome, # Correct: Name of the outcome variable column
+ cut=cutoff_value
+ )
+
+ # Extract results - this package creates a treatment dummy 'TREATED'
+ # The 'model' object has a 'results' attribute which is a statsmodels result instance
+ sm_results = model.fit()
+ print(sm_results.summary())
+
+ # Extract results - using 'TREATED' based on the provided summary output
+ effect = sm_results.params.get('TREATED')
+ std_err = sm_results.bse.get('TREATED')
+ p_value = sm_results.pvalues.get('TREATED')
+
+ conf_int_series = sm_results.conf_int()
+ conf_int = conf_int_series.loc['TREATED'].tolist() if 'TREATED' in conf_int_series.index else [None, None]
+
+ n_obs = model.nobs # or model.n_ if nobs is not available (check package details)
+
+ # The formula is implicit in the local linear regression performed by the package
+ # Update to reflect 'TREATED' as the dummy variable name if consistently used by the package
+ formula_desc = f"Local linear RDD: {outcome} ~ TREATED + {running_variable}_centered + TREATED*{running_variable}_centered (implicit, from evan-magnusson/rdd)"
+
+ return {
+ 'effect_estimate': effect,
+ 'standard_error': std_err,
+ 'p_value': p_value,
+ 'confidence_interval': conf_int,
+ 'method_details': f"RDD (evan-magnusson/rdd package, Bandwidth: {final_bandwidth:.4f})",
+ 'bandwidth_used': final_bandwidth,
+ 'bandwidth_selection_method': bandwidth_selection_method,
+ 'n_obs_in_bandwidth': int(n_obs) if n_obs is not None else None,
+ 'formula': formula_desc,
+ 'model_summary': sm_results.summary().as_text() if sm_results else "Summary not available."
+ }
+
+ except Exception as e:
+ logger.error(f"RDD estimation using 'evan-magnusson/rdd' failed: {e}", exc_info=True)
+ # Consider re-raising or returning a more structured error
+ raise e # Or return a dict like in the import failure case
+
+def estimate_effect(
+ df: pd.DataFrame,
+ treatment: str,
+ outcome: str,
+ running_variable: str,
+ cutoff_value: float,
+ covariates: Optional[List[str]] = None,
+ bandwidth: Optional[float] = None, # Optional bandwidth param
+ query: Optional[str] = None,
+ llm: Optional[BaseChatModel] = None,
+ **kwargs # Capture other args like rd_estimator from DoWhy if needed
+) -> Dict[str, Any]:
+ """
+ Estimates the causal effect using Regression Discontinuity Design.
+
+ Tries DoWhy implementation first if use_dowhy=True, otherwise uses fallback.
+
+ Args:
+ df: Input DataFrame.
+ treatment: Name of the treatment variable (often implicitly defined by cutoff).
+ DoWhy might still need it, fallback doesn't use it directly.
+ outcome: Name of the outcome variable.
+ running_variable: Name of the variable determining treatment assignment.
+ cutoff: The threshold value for the running variable.
+ covariates: Optional list of covariate names (support varies).
+ bandwidth: Optional bandwidth around the cutoff. If None, a default is used.
+ use_dowhy: Whether to attempt using the DoWhy library first.
+ query: Optional user query for context.
+ llm: Optional Language Model instance.
+ **kwargs: Additional keyword arguments for underlying methods.
+
+ Returns:
+ Dictionary containing estimation results.
+ """
+ required_args = {
+ "running_variable": running_variable,
+ "cutoff_value": cutoff_value
+ }
+ if any(val is None for val in required_args.values()):
+ raise ValueError(f"Missing required RDD arguments: running_variable and cutoff must be provided.")
+
+ results = {}
+ rdd_em_estimation_error = None # Error from effect_estimate_rdd (evan-magnusson)
+ fallback_estimation_error = None # Error from estimate_effect_fallback
+
+ # --- Try effect_estimate_rdd (evan-magnusson/rdd) First ---
+ try:
+ logger.info("Attempting RDD estimation using 'effect_estimate_rdd' (evan-magnusson/rdd package).")
+ # Note: treatment is passed but might be unused, covariates are also passed but typically ignored by this specific rdd package
+ results = effect_estimate_rdd(
+ df,
+ outcome,
+ running_variable,
+ cutoff_value,
+ treatment=treatment, # For API consistency, though evan-magnusson/rdd doesn't use it explicitly
+ covariates=covariates,
+ bandwidth=bandwidth,
+ **kwargs
+ )
+ results['method_used'] = 'evan-magnusson/rdd' # Ensure method_used is set
+ logger.info("Successfully estimated effect using 'effect_estimate_rdd'.")
+ except ImportError as ie: # Specifically catch import errors for the rdd package
+ logger.warning(f"'effect_estimate_rdd' could not run due to ImportError (likely evan-magnusson/rdd package not available/functional): {ie}")
+ rdd_em_estimation_error = ie
+ except Exception as e:
+ logger.warning(f"'effect_estimate_rdd' failed during execution: {e}")
+ rdd_em_estimation_error = e
+
+ # --- Fallback to estimate_effect_fallback if effect_estimate_rdd failed ---
+ if not results: # If effect_estimate_rdd wasn't used or failed
+ logger.info("'effect_estimate_rdd' did not produce results. Attempting fallback using 'estimate_effect_fallback'.")
+ try:
+ fallback_results = estimate_effect_fallback(df, treatment, outcome, running_variable, cutoff_value, covariates, bandwidth=bandwidth, **kwargs)
+ results.update(fallback_results)
+ results['method_used'] = 'Fallback RDD (Linear Interaction with Robust Errors)'
+ fallback_estimation_error = None # Clear fallback error if it succeeded
+ logger.info("Successfully estimated effect using 'estimate_effect_fallback'.")
+ except Exception as e:
+ logger.error(f"Fallback RDD estimation ('estimate_effect_fallback') also failed: {e}")
+ fallback_estimation_error = e
+
+ # Determine final error status
+ final_estimation_error = None
+ if not results: # If still no results, determine which error to report
+ if fallback_estimation_error: # Fallback was attempted and failed
+ final_estimation_error = fallback_estimation_error
+ logger.error(f"All RDD estimation attempts failed. Last error (from fallback): {final_estimation_error}")
+ elif rdd_em_estimation_error: # effect_estimate_rdd was attempted and failed, fallback was not (or also failed but error not captured)
+ final_estimation_error = rdd_em_estimation_error
+ logger.error(f"All RDD estimation attempts failed. Last error (from effect_estimate_rdd): {final_estimation_error}")
+ else:
+ logger.error("All RDD estimation attempts failed for an unknown reason.")
+
+ if final_estimation_error:
+ raise ValueError(f"RDD estimation failed. Last error: {final_estimation_error}")
+ else:
+ raise ValueError("RDD estimation failed using all available methods for an unknown reason.")
+
+ # --- Diagnostics ---
+ try:
+ diag_results = run_rdd_diagnostics(df, outcome, running_variable, cutoff_value, covariates, bandwidth)
+ results['diagnostics'] = diag_results
+ except Exception as diag_e:
+ logger.error(f"RDD Diagnostics failed: {diag_e}")
+ results['diagnostics'] = {"status": "Failed", "error": str(diag_e)}
+
+ # --- Interpretation ---
+ try:
+ interpretation = interpret_rdd_results(results, results.get('diagnostics'), llm=llm)
+ results['interpretation'] = interpretation
+ except Exception as interp_e:
+ logger.error(f"RDD Interpretation failed: {interp_e}")
+ results['interpretation'] = "Interpretation failed."
+
+ # Add info about primary attempt if fallback was used
+ if rdd_em_estimation_error and results.get('method_used', '').startswith('Fallback'):
+ results['primary_rdd_em_error_info'] = str(rdd_em_estimation_error)
+
+ return results
diff --git a/auto_causal/methods/regression_discontinuity/llm_assist.py b/auto_causal/methods/regression_discontinuity/llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..77fbe9aab85d938bacd5c4ea5235a11f3db76d2f
--- /dev/null
+++ b/auto_causal/methods/regression_discontinuity/llm_assist.py
@@ -0,0 +1,127 @@
+"""
+LLM assistance functions for Regression Discontinuity Design (RDD).
+"""
+
+from typing import List, Dict, Any, Optional
+import logging
+
+# Imported for type hinting
+from langchain.chat_models.base import BaseChatModel
+
+# Import shared LLM helpers
+from auto_causal.utils.llm_helpers import call_llm_with_json_output
+
+logger = logging.getLogger(__name__)
+
+def suggest_rdd_parameters(
+ df_cols: List[str],
+ query: str,
+ llm: Optional[BaseChatModel] = None
+) -> Dict[str, Any]:
+ """
+ (Placeholder) Use LLM to suggest RDD parameters (running variable, cutoff).
+
+ Args:
+ df_cols: List of available column names.
+ query: User's causal query text.
+ llm: Optional LLM model instance.
+
+ Returns:
+ Dictionary containing suggested 'running_variable' and 'cutoff', or empty.
+ """
+ logger.info("LLM RDD parameter suggestion is not implemented yet.")
+ if llm:
+ # Placeholder: Analyze columns, distributions, query for potential
+ # running variables (e.g., 'score', 'age') and cutoffs (e.g., 50, 65).
+ pass
+ return {}
+
+def interpret_rdd_results(
+ results: Dict[str, Any],
+ diagnostics: Optional[Dict[str, Any]],
+ llm: Optional[BaseChatModel] = None
+) -> str:
+ """
+ Use LLM to interpret Regression Discontinuity Design (RDD) results.
+
+ Args:
+ results: Dictionary of estimation results from the RDD estimator.
+ diagnostics: Dictionary of diagnostic test results.
+ llm: Optional LLM model instance.
+
+ Returns:
+ String containing natural language interpretation.
+ """
+ default_interpretation = "LLM interpretation not available for RDD."
+ if llm is None:
+ logger.info("LLM not provided for RDD interpretation.")
+ return default_interpretation
+
+ try:
+ # --- Prepare summary for LLM ---
+ results_summary = {}
+ effect = results.get('effect_estimate')
+ p_val = results.get('p_value')
+ ci = results.get('confidence_interval')
+
+ results_summary['Method Used'] = results.get('method_used', 'RDD')
+ results_summary['Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
+ results_summary['P-value'] = f"{p_val:.3f}" if isinstance(p_val, (int, float)) else str(p_val)
+ if isinstance(ci, (list, tuple)) and len(ci) == 2:
+ results_summary['Confidence Interval'] = f"[{ci[0]:.3f}, {ci[1]:.3f}]"
+ else:
+ results_summary['Confidence Interval'] = str(ci) if ci is not None else "N/A"
+
+ diag_summary = {}
+ if diagnostics and diagnostics.get("status", "").startswith("Success"):
+ diag_details = diagnostics.get("details", {})
+ diag_summary['Covariate Balance Status'] = "Checked" if 'covariate_balance' in diag_details else "Not Checked"
+ if isinstance(diag_details.get('covariate_balance'), dict):
+ num_unbalanced = sum(1 for cov, res in diag_details['covariate_balance'].items() if isinstance(res, dict) and res.get('balanced', '').startswith("No"))
+ diag_summary['Number of Unbalanced Covariates (p<=0.05)'] = num_unbalanced
+
+ diag_summary['Density Continuity Test'] = diag_details.get('continuity_density_test', 'N/A')
+ diag_summary['Visual Inspection Recommended'] = "Yes" if 'visual_inspection' in diag_details else "No"
+ elif diagnostics:
+ diag_summary['Status'] = diagnostics.get("status", "Unknown")
+ if "error" in diagnostics:
+ diag_summary['Error'] = diagnostics["error"]
+ else:
+ diag_summary['Status'] = "Diagnostics not available or failed."
+
+ # --- Construct Prompt ---
+ prompt = f"""
+ You are assisting with interpreting Regression Discontinuity Design (RDD) results.
+
+ Estimation Results Summary:
+ {results_summary}
+
+ Diagnostics Summary:
+ {diag_summary}
+
+ Explain these RDD results in 2-4 concise sentences. Focus on:
+ 1. The estimated causal effect at the cutoff (magnitude, direction, statistical significance based on p-value < 0.05, if available).
+ 2. Key diagnostic findings (specifically mention covariate balance issues if present, and note that other checks like density continuity were not performed).
+ 3. Mention that visual inspection of the running variable vs outcome is recommended.
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "interpretation": ""
+ }}
+ """
+
+ # --- Call LLM ---
+ response = call_llm_with_json_output(llm, prompt)
+
+ # --- Process Response ---
+ if response and isinstance(response, dict) and \
+ "interpretation" in response and isinstance(response["interpretation"], str):
+ return response["interpretation"]
+ else:
+ logger.warning(f"Failed to get valid interpretation from LLM for RDD. Response: {response}")
+ return default_interpretation
+
+ except Exception as e:
+ logger.error(f"Error during LLM interpretation for RDD: {e}")
+ return f"Error generating interpretation: {e}"
+
diff --git a/auto_causal/methods/utils.py b/auto_causal/methods/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a24c28eda7fe04dcf582a576316cabfb76cc77d2
--- /dev/null
+++ b/auto_causal/methods/utils.py
@@ -0,0 +1,717 @@
+"""
+Utility functions for causal inference methods.
+
+This module provides common utility functions used across
+different causal inference methods.
+"""
+
+from typing import Dict, List, Set, Optional, Union, Any, Tuple
+import numpy as np
+import pandas as pd
+import scipy.stats as stats
+from sklearn.preprocessing import StandardScaler
+import matplotlib.pyplot as plt
+import seaborn as sns
+from statsmodels.stats.outliers_influence import variance_inflation_factor
+from sklearn.linear_model import LogisticRegression
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def check_binary_treatment(treatment_series: pd.Series) -> bool:
+ """
+ Check if treatment variable is binary.
+
+ Args:
+ treatment_series: Series containing treatment variable
+
+ Returns:
+ Boolean indicating if treatment is binary
+ """
+ unique_values = set(treatment_series.unique())
+ # Remove NaN values if present
+ unique_values = {x for x in unique_values if pd.notna(x)}
+
+ # Check if there are exactly 2 unique values
+ if len(unique_values) != 2:
+ return False
+
+ # Check if values are 0/1 or similar binary encoding
+ sorted_vals = sorted(unique_values)
+
+ # Check common binary encodings: 0/1, False/True, etc.
+ binary_pairs = [
+ (0, 1),
+ (False, True),
+ ("0", "1"),
+ ("no", "yes"),
+ ("false", "true")
+ ]
+
+ # Convert to strings for comparison if needed
+ if not all(isinstance(v, (int, float, bool)) for v in sorted_vals):
+ # Convert to lowercase strings for comparison
+ str_vals = [str(v).lower() for v in sorted_vals]
+ for pair in binary_pairs:
+ str_pair = [str(v).lower() for v in pair]
+ if str_vals == str_pair:
+ return True
+ return False
+
+ # For numeric values, check if they're 0/1 or can be easily mapped to 0/1
+ if sorted_vals == [0, 1]:
+ return True
+
+ # Check if there are only two values that could be easily mapped
+ return len(unique_values) == 2
+
+
+def calculate_standardized_differences(df: pd.DataFrame, treatment: str, covariates: List[str]) -> Dict[str, float]:
+ """
+ Calculate standardized differences between treated and control groups.
+
+ Args:
+ df: DataFrame containing the data
+ treatment: Name of treatment variable
+ covariates: List of covariate variable names
+
+ Returns:
+ Dictionary with standardized differences for each covariate
+ """
+ treated = df[df[treatment] == 1]
+ control = df[df[treatment] == 0]
+
+ std_diffs = {}
+
+ for cov in covariates:
+ # Skip if covariate has missing values
+ if df[cov].isna().any():
+ std_diffs[cov] = np.nan
+ continue
+
+ t_mean = treated[cov].mean()
+ c_mean = control[cov].mean()
+
+ t_var = treated[cov].var()
+ c_var = control[cov].var()
+
+ # Pooled standard deviation
+ pooled_std = np.sqrt((t_var + c_var) / 2)
+
+ # Avoid division by zero
+ if pooled_std == 0:
+ std_diffs[cov] = 0
+ else:
+ std_diffs[cov] = (t_mean - c_mean) / pooled_std
+
+ return std_diffs
+
+
+def check_overlap(df: pd.DataFrame, treatment: str, propensity_scores: np.ndarray,
+ threshold: float = 0.5) -> Dict[str, Any]:
+ """
+ Check overlap in propensity scores between treated and control groups.
+
+ Args:
+ df: DataFrame containing the data
+ treatment: Name of treatment variable
+ propensity_scores: Array of propensity scores
+ threshold: Threshold for sufficient overlap (proportion of range)
+
+ Returns:
+ Dictionary with overlap statistics
+ """
+ df_copy = df.copy()
+ df_copy['propensity_score'] = propensity_scores
+
+ treated = df_copy[df_copy[treatment] == 1]['propensity_score']
+ control = df_copy[df_copy[treatment] == 0]['propensity_score']
+
+ min_treated = treated.min()
+ max_treated = treated.max()
+ min_control = control.min()
+ max_control = control.max()
+
+ overall_min = min(min_treated, min_control)
+ overall_max = max(max_treated, max_control)
+
+ # Range of overlap
+ overlap_min = max(min_treated, min_control)
+ overlap_max = min(max_treated, max_control)
+
+ # Check if there is any overlap
+ if overlap_max < overlap_min:
+ overlap_proportion = 0
+ sufficient_overlap = False
+ else:
+ # Calculate proportion of overall range that has overlap
+ overall_range = overall_max - overall_min
+ if overall_range == 0:
+ # All values are the same
+ overlap_proportion = 1.0
+ sufficient_overlap = True
+ else:
+ overlap_proportion = (overlap_max - overlap_min) / overall_range
+ sufficient_overlap = overlap_proportion >= threshold
+
+ return {
+ "treated_range": (float(min_treated), float(max_treated)),
+ "control_range": (float(min_control), float(max_control)),
+ "overlap_range": (float(overlap_min), float(overlap_max)),
+ "overlap_proportion": float(overlap_proportion),
+ "sufficient_overlap": sufficient_overlap
+ }
+
+
+def plot_propensity_overlap(df: pd.DataFrame, treatment: str, propensity_scores: np.ndarray,
+ save_path: Optional[str] = None) -> None:
+ """
+ Plot overlap in propensity scores.
+
+ Args:
+ df: DataFrame containing the data
+ treatment: Name of treatment variable
+ propensity_scores: Array of propensity scores
+ save_path: Optional path to save the plot
+ """
+ df_copy = df.copy()
+ df_copy['propensity_score'] = propensity_scores
+
+ plt.figure(figsize=(10, 6))
+
+ # Plot histograms
+ sns.histplot(df_copy.loc[df_copy[treatment] == 1, 'propensity_score'],
+ bins=20, alpha=0.5, label='Treated', color='blue', kde=True)
+ sns.histplot(df_copy.loc[df_copy[treatment] == 0, 'propensity_score'],
+ bins=20, alpha=0.5, label='Control', color='red', kde=True)
+
+ plt.title('Propensity Score Distributions')
+ plt.xlabel('Propensity Score')
+ plt.ylabel('Count')
+ plt.legend()
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+
+ plt.show()
+
+
+def plot_covariate_balance(standardized_diffs: Dict[str, float], threshold: float = 0.1,
+ save_path: Optional[str] = None) -> None:
+ """
+ Plot standardized differences for covariates before and after matching.
+
+ Args:
+ standardized_diffs: Dictionary with standardized differences
+ threshold: Threshold for acceptable balance
+ save_path: Optional path to save the plot
+ """
+ # Convert to DataFrame for plotting
+ df = pd.DataFrame({
+ 'Covariate': list(standardized_diffs.keys()),
+ 'Standardized Difference': list(standardized_diffs.values())
+ })
+
+ # Sort by absolute standardized difference
+ df['Absolute Difference'] = np.abs(df['Standardized Difference'])
+ df = df.sort_values('Absolute Difference', ascending=False)
+
+ plt.figure(figsize=(12, len(standardized_diffs) * 0.4 + 2))
+
+ # Plot horizontal bars
+ ax = sns.barplot(x='Standardized Difference', y='Covariate', data=df,
+ palette=['red' if abs(x) > threshold else 'green' for x in df['Standardized Difference']])
+
+ # Add vertical lines for thresholds
+ plt.axvline(x=threshold, color='red', linestyle='--', alpha=0.7)
+ plt.axvline(x=-threshold, color='red', linestyle='--', alpha=0.7)
+ plt.axvline(x=0, color='black', linestyle='-', alpha=0.7)
+
+ plt.title('Covariate Balance: Standardized Differences')
+ plt.xlabel('Standardized Difference')
+ plt.tight_layout()
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+
+ plt.show()
+
+
+def check_temporal_structure(df: pd.DataFrame) -> Dict[str, Any]:
+ """
+ Check if dataset has temporal structure.
+
+ Args:
+ df: DataFrame to check
+
+ Returns:
+ Dictionary with temporal structure information
+ """
+ # Check for date/time columns
+ date_cols = []
+
+ for col in df.columns:
+ # Check if column has date in name
+ if any(date_term in col.lower() for date_term in ['date', 'time', 'year', 'month', 'day', 'period']):
+ date_cols.append(col)
+
+ # Check if column can be converted to datetime
+ if df[col].dtype == 'object':
+ try:
+ pd.to_datetime(df[col], errors='raise')
+ date_cols.append(col)
+ except:
+ pass
+
+ # Check for panel structure - look for ID columns
+ id_cols = []
+
+ for col in df.columns:
+ # Check if column has ID in name
+ if any(id_term in col.lower() for id_term in ['id', 'identifier', 'key', 'code']):
+ unique_count = df[col].nunique()
+ # If column has multiple values but fewer than 10% of rows, likely an ID
+ if 1 < unique_count < len(df) * 0.1:
+ id_cols.append(col)
+
+ # Check if there are multiple observations per unit
+ is_panel = False
+ panel_units = None
+
+ if id_cols and date_cols:
+ # For each ID column, check if there are multiple time periods
+ for id_col in id_cols:
+ obs_per_id = df.groupby(id_col).size()
+ if (obs_per_id > 1).any():
+ is_panel = True
+ panel_units = id_col
+ break
+
+ return {
+ "has_temporal_structure": len(date_cols) > 0,
+ "temporal_columns": date_cols,
+ "potential_id_columns": id_cols,
+ "is_panel_data": is_panel,
+ "panel_units": panel_units
+ }
+
+
+def check_for_discontinuities(df: pd.DataFrame, outcome: str,
+ threshold_zscore: float = 3.0) -> Dict[str, Any]:
+ """
+ Check for potential discontinuities in continuous variables.
+
+ Args:
+ df: DataFrame to check
+ outcome: Name of outcome variable
+ threshold_zscore: Z-score threshold for detecting discontinuities
+
+ Returns:
+ Dictionary with discontinuity information
+ """
+ potential_running_vars = []
+
+ # Check only numeric columns that aren't the outcome
+ numeric_cols = df.select_dtypes(include=[np.number]).columns
+ numeric_cols = [col for col in numeric_cols if col != outcome]
+
+ for col in numeric_cols:
+ # Skip if too many unique values (unlikely to be a running variable)
+ if df[col].nunique() > 100:
+ continue
+
+ # Sort values and calculate differences
+ sorted_vals = np.sort(df[col].unique())
+ if len(sorted_vals) <= 1:
+ continue
+
+ diffs = np.diff(sorted_vals)
+ mean_diff = np.mean(diffs)
+ std_diff = np.std(diffs)
+
+ # Skip if all differences are the same
+ if std_diff == 0:
+ continue
+
+ # Calculate z-scores of differences
+ zscores = (diffs - mean_diff) / std_diff
+
+ # Check if any z-score exceeds threshold
+ if np.any(np.abs(zscores) > threshold_zscore):
+ # Potential discontinuity found
+ max_idx = np.argmax(np.abs(zscores))
+ threshold = (sorted_vals[max_idx] + sorted_vals[max_idx + 1]) / 2
+
+ # Check if outcome means differ across threshold
+ below_mean = df[df[col] < threshold][outcome].mean()
+ above_mean = df[df[col] >= threshold][outcome].mean()
+
+ # Only include if outcome means differ substantially
+ if abs(above_mean - below_mean) > 0.1 * df[outcome].std():
+ potential_running_vars.append({
+ "variable": col,
+ "threshold": float(threshold),
+ "z_score": float(zscores[max_idx]),
+ "outcome_diff": float(above_mean - below_mean)
+ })
+
+ return {
+ "has_discontinuities": len(potential_running_vars) > 0,
+ "potential_running_variables": potential_running_vars
+ }
+
+
+def find_potential_instruments(df: pd.DataFrame, treatment: str, outcome: str,
+ correlation_threshold: float = 0.3) -> Dict[str, Any]:
+ """
+ Find potential instrumental variables.
+
+ Args:
+ df: DataFrame to check
+ treatment: Name of treatment variable
+ outcome: Name of outcome variable
+ correlation_threshold: Threshold for correlation with treatment
+
+ Returns:
+ Dictionary with potential instruments information
+ """
+ # Get numeric columns that aren't treatment or outcome
+ numeric_cols = df.select_dtypes(include=[np.number]).columns
+ potential_ivs = [col for col in numeric_cols if col != treatment and col != outcome]
+
+ iv_results = []
+
+ for col in potential_ivs:
+ # Skip if column has too many missing values
+ if df[col].isna().mean() > 0.1:
+ continue
+
+ # Check correlation with treatment
+ corr_treatment = df[[col, treatment]].corr().iloc[0, 1]
+
+ # Check correlation with outcome
+ corr_outcome = df[[col, outcome]].corr().iloc[0, 1]
+
+ # Potential IV should be correlated with treatment but not directly with outcome
+ if abs(corr_treatment) > correlation_threshold and abs(corr_outcome) < correlation_threshold/2:
+ iv_results.append({
+ "variable": col,
+ "correlation_with_treatment": float(corr_treatment),
+ "correlation_with_outcome": float(corr_outcome),
+ "strength": "Strong" if abs(corr_treatment) > 0.5 else "Moderate"
+ })
+
+ return {
+ "has_potential_instruments": len(iv_results) > 0,
+ "potential_instruments": iv_results
+ }
+
+
+def test_parallel_trends(df: pd.DataFrame, treatment: str, outcome: str,
+ time_var: str, unit_var: str) -> Dict[str, Any]:
+ """
+ Test for parallel trends assumption in difference-in-differences.
+
+ Args:
+ df: DataFrame to check
+ treatment: Name of treatment variable
+ outcome: Name of outcome variable
+ time_var: Name of time variable
+ unit_var: Name of unit variable
+
+ Returns:
+ Dictionary with parallel trends test results
+ """
+ # Ensure time_var is properly formatted
+ df = df.copy()
+
+ if df[time_var].dtype != 'int64':
+ # Try to convert to datetime and then to period
+ try:
+ df[time_var] = pd.to_datetime(df[time_var])
+ # Get unique periods and map to integers
+ periods = df[time_var].dt.to_period('M').unique()
+ period_dict = {p: i for i, p in enumerate(sorted(periods))}
+ df['time_period'] = df[time_var].dt.to_period('M').map(period_dict)
+ time_var = 'time_period'
+ except:
+ # If conversion fails, try to map unique values to integers
+ unique_times = df[time_var].unique()
+ time_dict = {t: i for i, t in enumerate(sorted(unique_times))}
+ df['time_period'] = df[time_var].map(time_dict)
+ time_var = 'time_period'
+
+ # Identify treatment and control groups
+ # Treatment indicator should be 0 or 1 for each unit (not time-varying)
+ unit_treatment = df.groupby(unit_var)[treatment].max()
+ treatment_units = unit_treatment[unit_treatment == 1].index
+ control_units = unit_treatment[unit_treatment == 0].index
+
+ # Find time of treatment implementation
+ if len(treatment_units) > 0:
+ treatment_time = df[df[unit_var].isin(treatment_units) & (df[treatment] == 1)][time_var].min()
+ else:
+ # No treated units found
+ return {
+ "parallel_trends": False,
+ "reason": "No treated units found",
+ "pre_trend_correlation": None,
+ "pre_trend_p_value": None
+ }
+
+ # Select pre-treatment periods
+ pre_treatment = df[df[time_var] < treatment_time]
+
+ # Calculate average outcome by time and group
+ treated_means = pre_treatment[pre_treatment[unit_var].isin(treatment_units)].groupby(time_var)[outcome].mean()
+ control_means = pre_treatment[pre_treatment[unit_var].isin(control_units)].groupby(time_var)[outcome].mean()
+
+ # Need enough pre-treatment periods to test
+ if len(treated_means) < 3:
+ return {
+ "parallel_trends": None,
+ "reason": "Insufficient pre-treatment periods",
+ "pre_trend_correlation": None,
+ "pre_trend_p_value": None
+ }
+
+ # Align indices and calculate trends
+ common_periods = sorted(set(treated_means.index).intersection(set(control_means.index)))
+
+ if len(common_periods) < 3:
+ return {
+ "parallel_trends": None,
+ "reason": "Insufficient common pre-treatment periods",
+ "pre_trend_correlation": None,
+ "pre_trend_p_value": None
+ }
+
+ treated_trends = np.diff(treated_means[common_periods])
+ control_trends = np.diff(control_means[common_periods])
+
+ # Calculate correlation between trends
+ correlation, p_value = stats.pearsonr(treated_trends, control_trends)
+
+ # Test if trends are parallel (high correlation, not significantly different)
+ parallel_trends = correlation > 0.7 and p_value < 0.05
+
+ return {
+ "parallel_trends": parallel_trends,
+ "reason": "Trends are parallel" if parallel_trends else "Trends are not parallel",
+ "pre_trend_correlation": float(correlation),
+ "pre_trend_p_value": float(p_value)
+ }
+
+
+def preprocess_data(df: pd.DataFrame, treatment_var: str, outcome_var: str,
+ covariates: List[str], verbose: bool = True) -> pd.DataFrame:
+ """
+ Preprocess the dataset to handle missing values and encode categorical variables.
+
+ Args:
+ df (pd.DataFrame): The dataset
+ treatment_var (str): The treatment variable name
+ outcome_var (str): The outcome variable name
+ covariates (list): List of covariate variable names
+ verbose (bool): Whether to print verbose output
+
+ Returns:
+ Tuple[pd.DataFrame, str, str, List[str], Dict[str, Any]]:
+ Preprocessed dataset, updated treatment var name,
+ updated outcome var name, updated covariates list,
+ and column mappings.
+ """
+ df_processed = df.copy()
+ column_mappings: Dict[str, Any] = {}
+
+ # Store original dtypes for mapping
+ original_dtypes = {col: str(df_processed[col].dtype) for col in df_processed.columns}
+
+ # Report missing values
+ all_vars = [treatment_var, outcome_var] + covariates
+ missing_data = df_processed[all_vars].isnull().sum()
+ total_missing = missing_data.sum()
+
+ if total_missing > 0:
+ if verbose:
+ logger.info(f"Dataset contains {total_missing} missing values:")
+ for col in missing_data[missing_data > 0].index:
+ percent = (missing_data[col] / len(df_processed)) * 100
+ if verbose:
+ logger.info(f" - {col}: {missing_data[col]} missing values ({percent:.2f}%)")
+ else:
+ if verbose:
+ logger.info("No missing values found in relevant columns.")
+ # return df_processed # No preprocessing needed if no missing values
+
+ # Handle missing values in treatment variable
+ if df_processed[treatment_var].isnull().sum() > 0:
+ if verbose:
+ logger.info(f"Filling missing values in treatment variable '{treatment_var}' with mode")
+ # For treatment, use mode (most common value)
+ mode_val = df_processed[treatment_var].mode()[0] if not df_processed[treatment_var].mode().empty else 0
+ df_processed[treatment_var] = df_processed[treatment_var].fillna(mode_val)
+
+ # Handle missing values in outcome variable
+ if df_processed[outcome_var].isnull().sum() > 0:
+ if verbose:
+ logger.info(f"Filling missing values in outcome variable '{outcome_var}' with mean")
+ # For outcome, use mean
+ mean_val = df_processed[outcome_var].mean()
+ df_processed[outcome_var] = df_processed[outcome_var].fillna(mean_val)
+
+ # Handle missing values in covariates
+ for col in covariates:
+ if df_processed[col].isnull().sum() > 0:
+ if pd.api.types.is_numeric_dtype(df_processed[col]):
+ # For numeric covariates, use mean
+ if verbose:
+ logger.info(f"Filling missing values in numeric covariate '{col}' with mean")
+ mean_val = df_processed[col].mean()
+ df_processed[col] = df_processed[col].fillna(mean_val)
+ elif pd.api.types.is_categorical_dtype(df_processed[col]) or df_processed[col].dtype == 'object':
+ # For categorical covariates, use mode
+ mode_val = df_processed[col].mode()[0] if not df_processed[col].mode().empty else "Missing"
+ if verbose:
+ logger.info(f"Filling missing values in categorical covariate '{col}' with mode ('{mode_val}')")
+ df_processed[col] = df_processed[col].fillna(mode_val)
+ else:
+ # For other types, create a "Missing" category
+ if verbose:
+ logger.info(f"Filling missing values in covariate '{col}' of type {df_processed[col].dtype} with 'Missing' category")
+ # Ensure the column is of object type before filling with string
+ if df_processed[col].dtype != 'object':
+ try:
+ df_processed[col] = df_processed[col].astype(object)
+ except Exception as e:
+ logger.warning(f"Could not convert column {col} to object type to fill NAs: {e}. Skipping fill.")
+ continue
+ df_processed[col] = df_processed[col].fillna("Missing")
+
+ # --- Categorical Encoding ---
+ updated_treatment_var = treatment_var
+ updated_outcome_var = outcome_var
+
+ # Helper function for label encoding binary categoricals
+ def label_encode_binary(series: pd.Series, var_name: str) -> Tuple[pd.Series, Dict[int, Any]]:
+ uniques = series.dropna().unique()
+ mapping = {}
+ if len(uniques) == 2:
+ # Try to map to 0 and 1 consistently, e.g., sort and assign
+ # Or if boolean, map True to 1, False to 0
+ if series.dtype == 'bool':
+ mapping = {0: False, 1: True}
+ return series.astype(int), mapping
+
+ # For non-boolean, sort to ensure consistent mapping
+ # However, direct replacement is safer to control which becomes 0 and 1
+ # For simplicity here, we'll make a simple map.
+ # A more robust approach might involve explicit mapping rules or user input.
+ sorted_uniques = sorted(uniques, key=lambda x: str(x)) # sort to make it deterministic
+ map_dict = {sorted_uniques[0]: 0, sorted_uniques[1]: 1}
+ mapping = {v: k for k, v in map_dict.items()} # Inverse map for column_mappings
+ if verbose:
+ logger.info(f"Label encoding binary variable '{var_name}': {map_dict}")
+ return series.map(map_dict), mapping
+ elif len(uniques) == 1: # Single unique value, treat as constant (encode as 0)
+ if verbose:
+ logger.info(f"Binary variable '{var_name}' has only one unique value '{uniques[0]}'. Encoding as 0.")
+ map_dict = {uniques[0]:0}
+ mapping = {0: uniques[0]}
+ return series.map(map_dict), mapping
+ return series, mapping # No change if not binary
+
+ # Encode Treatment Variable
+ if df_processed[treatment_var].dtype == 'object' or df_processed[treatment_var].dtype == 'category' or df_processed[treatment_var].dtype == 'bool':
+ original_series = df_processed[treatment_var].copy()
+ df_processed[treatment_var], value_map = label_encode_binary(df_processed[treatment_var], treatment_var)
+ if value_map: # If encoding happened
+ column_mappings[treatment_var] = {
+ 'original_dtype': original_dtypes[treatment_var],
+ 'transformed_as': 'label_encoded_binary',
+ 'new_column_name': treatment_var, # Name doesn't change
+ 'value_map': value_map
+ }
+ if verbose:
+ logger.info(f"Encoded treatment variable '{treatment_var}' to numeric.")
+
+ # Encode Outcome Variable
+ if df_processed[outcome_var].dtype == 'object' or df_processed[outcome_var].dtype == 'category' or df_processed[outcome_var].dtype == 'bool':
+ original_series = df_processed[outcome_var].copy()
+ df_processed[outcome_var], value_map = label_encode_binary(df_processed[outcome_var], outcome_var)
+ if value_map: # If encoding happened
+ column_mappings[outcome_var] = {
+ 'original_dtype': original_dtypes[outcome_var],
+ 'transformed_as': 'label_encoded_binary',
+ 'new_column_name': outcome_var, # Name doesn't change
+ 'value_map': value_map
+ }
+ if verbose:
+ logger.info(f"Encoded outcome variable '{outcome_var}' to numeric.")
+
+ # Encode Covariates (One-Hot Encoding for non-numeric)
+ updated_covariates = []
+ categorical_covariates_to_encode = []
+ for cov in covariates:
+ if cov not in df_processed.columns: # If a covariate was dropped or is an instrument etc.
+ if verbose:
+ logger.warning(f"Covariate '{cov}' not found in DataFrame columns after initial processing. Skipping encoding for it.")
+ continue
+
+ if df_processed[cov].dtype == 'object' or df_processed[cov].dtype == 'category' or pd.api.types.is_bool_dtype(df_processed[cov]):
+ # Check if it's binary - if so, can also label encode
+ # However, for consistency with get_dummies and to handle multi-category,
+ # we'll let get_dummies handle it, or apply label encoding for binary covariates too.
+ # For simplicity, let's stick to one-hot for all categorical covariates.
+ if len(df_processed[cov].dropna().unique()) > 1 : # Only encode if more than 1 unique value
+ categorical_covariates_to_encode.append(cov)
+ else: # If only one unique value or all NaNs (already handled), it's constant-like
+ if verbose:
+ logger.info(f"Categorical covariate '{cov}' has <= 1 unique value after NA handling. Treating as constant-like, not one-hot encoding.")
+ updated_covariates.append(cov) # Keep as is, will likely be numeric 0 or some constant
+ else: # Already numeric
+ updated_covariates.append(cov)
+
+ if categorical_covariates_to_encode:
+ if verbose:
+ logger.info(f"One-hot encoding categorical covariates: {categorical_covariates_to_encode} using pd.get_dummies (drop_first=True)")
+
+ # Store original columns before get_dummies to identify new ones
+ original_df_columns = set(df_processed.columns)
+
+ df_processed = pd.get_dummies(df_processed, columns=categorical_covariates_to_encode,
+ prefix_sep='_', drop_first=True, dummy_na=False) # dummy_na=False since we handled NAs
+
+ # Identify new columns created by get_dummies
+ new_dummy_columns = list(set(df_processed.columns) - original_df_columns)
+ updated_covariates.extend(new_dummy_columns)
+
+ for original_cov_name in categorical_covariates_to_encode:
+ # Find which dummy columns correspond to this original covariate
+ related_dummies = [col for col in new_dummy_columns if col.startswith(original_cov_name + '_')]
+ column_mappings[original_cov_name] = {
+ 'original_dtype': original_dtypes[original_cov_name],
+ 'transformed_as': 'one_hot_encoded',
+ 'encoded_columns': related_dummies,
+ # 'dropped_category': can be inferred if needed, but not explicitly stored for simplicity here
+ }
+ if verbose:
+ logger.info(f" Original covariate '{original_cov_name}' resulted in dummy variables: {related_dummies}")
+
+ if verbose:
+ logger.info("Preprocessing complete.")
+ if column_mappings:
+ logger.info(f"Column mappings generated: {column_mappings}")
+ else:
+ logger.info("No column encodings were applied.")
+
+ return df_processed, updated_treatment_var, updated_outcome_var, list(dict.fromkeys(updated_covariates)), column_mappings
+
+
+def check_collinearity(df: pd.DataFrame, covariates: List[str]) -> Optional[List[str]]:
+ # Implementation of check_collinearity function
+ # This function should return a list of collinear variables or None
+ pass
\ No newline at end of file
diff --git a/auto_causal/models.py b/auto_causal/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5f5cba2114c5143e100639b950bc63665303231
--- /dev/null
+++ b/auto_causal/models.py
@@ -0,0 +1,235 @@
+from typing import List, Optional, Union, Dict, Any, Tuple
+from pydantic import BaseModel, Field, validator
+import json
+
+# --- Pydantic models for LLM structured output ---
+# These models are used by query_interpreter and potentially other components
+# to structure the output received from Language Models.
+
+class LLMSelectedVariable(BaseModel):
+ """Pydantic model for selecting a single variable."""
+ variable_name: Optional[str] = Field(None, description="The single best column name selected.")
+
+class LLMSelectedCovariates(BaseModel):
+ """Pydantic model for selecting a list of covariates."""
+ covariates: List[str] = Field(default_factory=list, description="The list of selected covariate column names.")
+
+class LLMIVars(BaseModel):
+ """Pydantic model for identifying IVs."""
+ instrument_variable: Optional[str] = Field(None, description="The identified instrumental variable column name.")
+
+class LLMEstimand(BaseModel):
+ """Pydantic model for identifying estimand"""
+ estimand: Optional[str] = Field(None, description="The identified estimand")
+
+class LLMRDDVars(BaseModel):
+ """Pydantic model for identifying RDD variables."""
+ running_variable: Optional[str] = Field(None, description="The identified running variable column name.")
+ cutoff_value: Optional[Union[float, int]] = Field(None, description="The identified cutoff value.")
+
+class LLMRCTCheck(BaseModel):
+ """Pydantic model for checking if data is RCT."""
+ is_rct: Optional[bool] = Field(None, description="True if the data is from a randomized controlled trial, False otherwise, None if unsure.")
+ reasoning: Optional[str] = Field(None, description="Brief reasoning for the RCT conclusion.")
+
+class LLMTreatmentReferenceLevel(BaseModel):
+ reference_level: Optional[str] = Field(None, description="The identified reference/control level for the treatment variable, if specified in the query. Should be one of the actual values in the treatment column.")
+ reasoning: Optional[str] = Field(None, description="Brief reasoning for identifying this reference level.")
+
+
+class LLMInteractionSuggestion(BaseModel):
+ """Pydantic model for LLM suggestion on interaction terms."""
+ interaction_needed: Optional[bool] = Field(None, description="True if an interaction term is strongly suggested by the query or context. LLM should provide true, false, or omit for None.")
+ interaction_variable: Optional[str] = Field(None, description="The name of the covariate that should interact with the treatment. Null if not applicable or if the interaction is complex/multiple.")
+ reasoning: Optional[str] = Field(None, description="Brief reasoning for the suggestion for or against an interaction term.")
+
+# --- Pydantic models for Tool Inputs/Outputs and Data Structures ---
+
+class TemporalStructure(BaseModel):
+ """Represents detected temporal structure in the data."""
+ has_temporal_structure: bool
+ temporal_columns: List[str]
+ is_panel_data: bool
+ id_column: Optional[str] = None
+ time_column: Optional[str] = None
+ time_periods: Optional[int] = None
+ units: Optional[int] = None
+
+class DatasetInfo(BaseModel):
+ """Basic information about the dataset file."""
+ num_rows: int
+ num_columns: int
+ file_path: str
+ file_name: str
+
+class DatasetAnalysis(BaseModel):
+ """Results from the dataset analysis component."""
+ dataset_info: DatasetInfo
+ columns: List[str]
+ potential_treatments: List[str]
+ potential_outcomes: List[str]
+ temporal_structure_detected: bool
+ panel_data_detected: bool
+ potential_instruments_detected: bool
+ discontinuities_detected: bool
+ temporal_structure: TemporalStructure
+ column_categories: Optional[Dict[str, str]] = None
+ column_nunique_counts: Optional[Dict[str, int]] = None
+ sample_size: int
+ num_covariates_estimate: int
+ per_group_summary_stats: Optional[Dict[str, Dict[str, Any]]] = None
+ potential_instruments: Optional[List[str]] = None
+ overlap_assessment: Optional[Dict[str, Any]] = None
+
+# --- Model for Dataset Analyzer Tool Output ---
+
+class DatasetAnalyzerOutput(BaseModel):
+ """Structured output for the dataset analyzer tool."""
+ analysis_results: DatasetAnalysis
+ dataset_description: Optional[str] = None
+ workflow_state: Dict[str, Any]
+
+#TODO make query info consistent with the Data analysis out put
+class QueryInfo(BaseModel):
+ """Information extracted from the user's initial query."""
+ query_text: str
+ potential_treatments: Optional[List[str]] = None
+ potential_outcomes: Optional[List[str]] = None
+ covariates_hints: Optional[List[str]] = None
+ instrument_hints: Optional[List[str]] = None
+ running_variable_hints: Optional[List[str]] = None
+ cutoff_value_hint: Optional[Union[float, int]] = None
+
+class QueryInterpreterInput(BaseModel):
+ """Input structure for the query interpreter tool."""
+ query_info: QueryInfo
+ dataset_analysis: DatasetAnalysis
+ dataset_description: str
+ # Add original_query if it should be part of the standard input
+ original_query: Optional[str] = None
+
+class Variables(BaseModel):
+ """Structured variables identified by the query interpreter component."""
+ treatment_variable: Optional[str] = None
+ treatment_variable_type: Optional[str] = Field(None, description="Type of the treatment variable (e.g., 'binary', 'continuous', 'categorical_multi_value')")
+ outcome_variable: Optional[str] = None
+ instrument_variable: Optional[str] = None
+ covariates: Optional[List[str]] = Field(default_factory=list)
+ time_variable: Optional[str] = None
+ group_variable: Optional[str] = None # Often the unit ID
+ running_variable: Optional[str] = None
+ cutoff_value: Optional[Union[float, int]] = None
+ is_rct: Optional[bool] = Field(False, description="Flag indicating if the dataset is from an RCT.")
+ treatment_reference_level: Optional[str] = Field(None, description="The specified reference/control level for a multi-valued treatment variable.")
+ interaction_term_suggested: Optional[bool] = Field(False, description="Whether the query or context suggests an interaction term with the treatment might be relevant.")
+ interaction_variable_candidate: Optional[str] = Field(None, description="The covariate identified as a candidate for interaction with the treatment.")
+
+class QueryInterpreterOutput(BaseModel):
+ """Structured output for the query interpreter tool."""
+ variables: Variables
+ dataset_analysis: DatasetAnalysis
+ dataset_description: Optional[str]
+ workflow_state: Dict[str, Any]
+ original_query: Optional[str] = None
+
+# Input model for Method Selector Tool
+class MethodSelectorInput(BaseModel):
+ """Input structure for the method selector tool."""
+ variables: Variables# Uses the Variables model identified by QueryInterpreter
+ dataset_analysis: DatasetAnalysis # Uses the DatasetAnalysis model
+ dataset_description: Optional[str] = None
+ original_query: Optional[str] = None
+ # Note: is_rct is expected inside inputs.variables
+
+# --- Models for Method Validator Tool ---
+
+class MethodInfo(BaseModel):
+ """Information about the selected causal inference method."""
+ selected_method: Optional[str] = None
+ method_name: Optional[str] = None # Often a title-cased version for display
+ method_justification: Optional[str] = None
+ method_assumptions: Optional[List[str]] = Field(default_factory=list)
+ # Add alternative methods if it should be part of the standard info passed around
+ alternative_methods: Optional[List[str]] = Field(default_factory=list)
+
+class MethodValidatorInput(BaseModel):
+ """Input structure for the method validator tool."""
+ method_info: MethodInfo
+ variables: Variables
+ dataset_analysis: DatasetAnalysis
+ dataset_description: Optional[str] = None
+ original_query: Optional[str] = None
+
+# --- Model for Method Executor Tool ---
+
+class MethodExecutorInput(BaseModel):
+ """Input structure for the method executor tool."""
+ method: str = Field(..., description="The causal method name (use recommended method if validation failed).")
+ variables: Variables # Contains T, O, C, etc.
+ dataset_path: str
+ dataset_analysis: DatasetAnalysis
+ dataset_description: Optional[str] = None
+ # Include validation_info from validator output if needed by estimator or LLM assist later?
+ validation_info: Optional[Any] = None
+ original_query: Optional[str] = None
+# --- Model for Explanation Generator Tool ---
+
+class ExplainerInput(BaseModel):
+ """Input structure for the explanation generator tool."""
+ # Based on expected output from method_executor_tool and validator
+ method_info: MethodInfo
+ validation_info: Optional[Dict[str, Any]] = None # From validator tool
+ variables: Variables
+ results: Dict[str, Any] # Numerical results from executor
+ dataset_analysis: DatasetAnalysis
+ dataset_description: Optional[str] = None
+ # Add original query if needed for explanation context
+ original_query: Optional[str] = None
+
+# Add other shared models/schemas below as needed.
+
+class FormattedOutput(BaseModel):
+ """
+ Structured output containing the final formatted results and explanations
+ from a causal analysis run.
+ """
+ query: str = Field(description="The original user query.")
+ method_used: str = Field(description="The user-friendly name of the causal inference method used.")
+ causal_effect: Optional[float] = Field(None, description="The point estimate of the causal effect.")
+ standard_error: Optional[float] = Field(None, description="The standard error of the causal effect estimate.")
+ confidence_interval: Optional[Tuple[Optional[float], Optional[float]]] = Field(None, description="The confidence interval for the causal effect (e.g., 95% CI).")
+ p_value: Optional[float] = Field(None, description="The p-value associated with the causal effect estimate.")
+ summary: str = Field(description="A concise summary paragraph interpreting the main findings.")
+ method_explanation: Optional[str] = Field("", description="Explanation of the causal inference method used.")
+ interpretation_guide: Optional[str] = Field("", description="Guidance on how to interpret the results.")
+ limitations: Optional[List[str]] = Field(default_factory=list, description="List of limitations or potential issues with the analysis.")
+ assumptions: Optional[str] = Field("", description="Discussion of the key assumptions underlying the method and their validity.")
+ practical_implications: Optional[str] = Field("", description="Discussion of the practical implications or significance of the findings.")
+ # Optionally add dataset_analysis and dataset_description if they should be part of the final structure
+ # dataset_analysis: Optional[DatasetAnalysis] = None # Example if using DatasetAnalysis model
+ # dataset_description: Optional[str] = None
+
+ # This model itself doesn't include workflow_state, as it represents the *content*
+ # The tool using this component will add the workflow_state separately.
+
+class LLMParameterDetails(BaseModel):
+ parameter_name: str = Field(description="The full parameter name as found in the model results.")
+ estimate: float
+ p_value: float
+ conf_int_low: float
+ conf_int_high: float
+ std_err: float
+ reasoning: Optional[str] = Field(None, description="Brief reasoning for selecting this parameter and its values.")
+
+class LLMTreatmentEffectResults(BaseModel):
+ effects: Optional[Dict[str, LLMParameterDetails]] = Field(description="Dictionary where keys are treatment level names (e.g., 'LevelA', 'LevelB' if multi-level) or a generic key like 'treatment_effect' for binary/continuous treatments. Values are the statistical details for that effect.")
+ all_parameters_successfully_identified: Optional[bool] = Field(description="True if all expected treatment effect parameters were identified and their values extracted, False otherwise.")
+ overall_reasoning: Optional[str] = Field(None, description="Overall reasoning for the extraction process or if issues were encountered.")
+
+class RelevantParamInfo(BaseModel):
+ param_name: str = Field(description="The exact parameter name as it appears in the statsmodels results.")
+ param_index: int = Field(description="The index of this parameter in the original list of parameter names.")
+
+class LLMIdentifiedRelevantParams(BaseModel):
+ identified_params: List[RelevantParamInfo] = Field(description="A list of parameters identified as relevant to the query or representing all treatment effects for a general query.")
+ all_parameters_successfully_identified: bool = Field(description="True if LLM is confident it identified all necessary params based on query type (e.g., all levels for a general query).")
diff --git a/auto_causal/preprocess/json.py b/auto_causal/preprocess/json.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b233d4847c591796000d4b1b3490fd7401a976
--- /dev/null
+++ b/auto_causal/preprocess/json.py
@@ -0,0 +1,58 @@
+import pandas as pd
+import json
+import argparse
+from pathlib import Path
+from typing import List
+
+##TODO: later add logs
+
+
+def create_json(csv_file_loc:str, output_folder:str, output_file_name:str,
+ base_data_folder:str, data_attribute:str="data_files") -> List:
+ """
+ Creates a json file containing the causal query and its associated metadata from
+ the csv file
+
+ Args:
+ csv_file_loc: path to the csv file
+ output_folder: path to the folder where the json file is saved
+ output_file_name: name of the output json file
+ base_data_folder: path to the folder where the data is saved
+ data_attribute: name of the column in the csv file containing the data file name
+ """
+
+ try:
+ df = pd.read_csv(csv_file_loc)
+ except FileNotFoundError:
+ print(f"File not found:{csv_file_loc}. Make sure the file path is correct.")
+
+ json_df = df.to_dict(orient="records")
+
+ print("Checking if referenced csv files are available")
+ all_exists = True
+ for data in json_df:
+ #print(base_data_folder, data[data_attribute])
+ full_path = Path(base_data_folder) / data[data_attribute]
+ if not full_path.exists():
+ print(f"File not found: {full_path}. Re-check the name of the data file.")
+ all_exists = False & all_exists
+ else:
+ data[data_attribute] = str(full_path)
+
+ if not all_exists:
+ print("Some data files are missing or incorrectly name")
+ else:
+ print("All data files are available. Good to go.")
+
+ if ".json" not in output_file_name:
+ output_file_name = output_file_name + ".json"
+
+ output_path = Path(output_folder)
+ output_path.mkdir(parents=True, exist_ok=True)
+ output_file_path = output_path / output_file_name
+ with open(output_file_path, "w") as f:
+ json.dump(json_df, f, indent=4)
+ print(f"Json file created at {output_file_path}")
+ f.close()
+
+ return json_df
\ No newline at end of file
diff --git a/auto_causal/prompts/__init__.py b/auto_causal/prompts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..45312243c887874e6f2e936ffe3bbe8965ab927a
--- /dev/null
+++ b/auto_causal/prompts/__init__.py
@@ -0,0 +1,21 @@
+from .method_identification_prompts import (
+ IV_IDENTIFICATION_PROMPT_TEMPLATE,
+ RDD_IDENTIFICATION_PROMPT_TEMPLATE,
+ RCT_IDENTIFICATION_PROMPT_TEMPLATE
+)
+
+from .dataset_analysis_prompts import (
+ INSTRUMENT_IDENTIFICATION_PROMPT,
+ OVERLAP_ASSESSMENT_PROMPT
+)
+
+from .result_interpretation_prompts import QUERY_SPECIFIC_INTERPRETATION_PROMPT_TEMPLATE
+
+__all__ = [
+ 'IV_IDENTIFICATION_PROMPT_TEMPLATE',
+ 'RDD_IDENTIFICATION_PROMPT_TEMPLATE',
+ 'RCT_IDENTIFICATION_PROMPT_TEMPLATE',
+ 'INSTRUMENT_IDENTIFICATION_PROMPT',
+ 'OVERLAP_ASSESSMENT_PROMPT',
+ 'QUERY_SPECIFIC_INTERPRETATION_PROMPT_TEMPLATE'
+]
\ No newline at end of file
diff --git a/auto_causal/prompts/dataset_analysis_prompts.py b/auto_causal/prompts/dataset_analysis_prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..87d2cdaa5ac0c27402655ea49b5c40211aa7d332
--- /dev/null
+++ b/auto_causal/prompts/dataset_analysis_prompts.py
@@ -0,0 +1,69 @@
+"""
+Prompt templates for dataset analysis functions including identifying
+instrumental variables, assessing variable relationships, and overlap assessment.
+"""
+
+# Note: These templates use f-string formatting with dataset-specific variables
+
+INSTRUMENT_IDENTIFICATION_PROMPT = """
+You are an expert causal inference assistant helping to identify potential Instrumental Variables (IVs).
+
+I have a dataset with the following characteristics:
+- Treatment variable(s): {potential_treatments}
+- Outcome variable(s): {potential_outcomes}
+- All columns: {all_columns}
+- Column types: {column_types}
+- Variable relationships: {relationships_info}
+
+An instrumental variable must satisfy three conditions:
+1. Relevance: It must be correlated with the treatment variable
+2. Exclusion restriction: It must affect the outcome ONLY through the treatment variable (no direct effect)
+3. Independence: It must be independent of unobserved confounders affecting the outcome
+
+Based on the column names, types, and relationships, identify potential instrumental variables.
+For each potential IV, explain why it might satisfy these conditions.
+
+Return your answer as a list of dictionaries with the following structure:
+[
+ {{
+ "variable": "column_name",
+ "reason": "Brief explanation of why this could be an instrumental variable",
+ "data_type": "column data type",
+ "confidence": "high/medium/low",
+ "relevance_assessment": "Brief assessment of condition 1",
+ "exclusion_assessment": "Brief assessment of condition 2",
+ "independence_assessment": "Brief assessment of condition 3"
+ }}
+]
+
+If you cannot identify any potential IVs, return an empty list.
+"""
+
+OVERLAP_ASSESSMENT_PROMPT = """
+You are an expert causal inference assistant helping to assess covariate balance and overlap between treatment and control groups.
+
+Treatment variable: {treatment}
+Group sizes:
+- Treatment group: {treated_count} observations
+- Control group: {control_count} observations
+
+Covariate statistics:
+{covariate_stats}
+
+Based on this information, assess:
+1. Balance: Are there significant differences in covariates between treatment and control groups?
+2. Overlap: Is there sufficient overlap in covariate distributions to make causal comparisons?
+3. Sample size: Is the sample size adequate for the analysis?
+
+Your assessment should indicate whether methods like propensity score matching or weighting might be necessary.
+
+Return your assessment as a dictionary with the following structure:
+{{
+ "balance_assessment": "Good/Moderate/Poor",
+ "overlap_assessment": "Good/Moderate/Poor",
+ "sample_size_assessment": "Adequate/Limited",
+ "problematic_covariates": ["list", "of", "unbalanced", "covariates"],
+ "recommendation": "Brief recommendation for addressing any issues",
+ "reasoning": "Brief explanation of your assessment"
+}}
+"""
\ No newline at end of file
diff --git a/auto_causal/prompts/method_identification_prompts.py b/auto_causal/prompts/method_identification_prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..419e06816bae7048070c715edd5a1cfacfd1c6c4
--- /dev/null
+++ b/auto_causal/prompts/method_identification_prompts.py
@@ -0,0 +1,345 @@
+"""
+Prompt templates for identifying specific causal structures (IV, RDD, RCT)
+within the query_interpreter component.
+"""
+
+# Note: These templates expect f-string formatting with variables like:
+# query, description, column_info, treatment, outcome
+
+## TODO: Test is do we need to provide all this information to the LLM or we simply ask find the instrument?
+IV_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are a causal inference assistant tasked with assessing whether a valid Instrumental Variable (IV) exists in the dataset. A valid IV must satisfy **all** of the following conditions:
+
+1. **Relevance**: It must causally influence the Treatment.
+2. **Exclusion Restriction**: It must affect the Outcome only through the Treatment — not directly or indirectly via other paths.
+3. **Independence**: It must be as good as randomly assigned with respect to any unobserved confounders affecting the Outcome.
+4. **Compliance (for RCTs)**: If the dataset comes from a randomized controlled trial or experiment, IVs are only valid if compliance data is available — i.e., if some units did not follow their assigned treatment. In this case, the random assignment may be a valid IV, and compliance is the actual treatment variable. If compliance related variable is not available, do not select IV.
+5. The instrument must be one of the listed dataset columns (not the treatment itself), and must not be assumed or invented.
+
+You should **only suggest an IV if you are confident that all the conditions are satisfied**. Otherwise, return "NULL".
+
+Here is the information about the user query and the dataset:
+
+User Query: "{query}"
+Dataset Description: {description}
+Treatment: {treatment}
+Outcome: {outcome}
+
+Available Columns:
+{column_info}
+
+Return a JSON object with the structure:
+{{ "instrument_variable": "COLUMN_NAME_OR_NULL" }}
+"""
+
+DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are a causal inference assistant tasked with determining whether a valid Difference-in-Differences (DiD) **interaction term** already exists in the dataset.
+
+This DiD term should be a **binary variable** indicating whether a unit belongs to the **treatment group after treatment was applied**.
+
+For example, if a policy was enacted in 2020 for a particular state, then the DiD term would equal 1 for units from that state in years after 2020, and 0 otherwise.
+
+Here is the information:
+
+User Query: "{query}"
+
+Time variable: {time_variable}
+Group variable: {group_variable}
+
+Dataset Description:
+{description}
+
+Available Columns:
+{column_info}
+
+Column Types:
+{column_types}
+
+Return your answer as a valid JSON object with the following format:
+{{ "did_term": "COLUMN_NAME_OR_NULL" }}
+"""
+
+
+
+
+RDD_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are an expert causal inference assistant helping to determine if Regression Discontinuity Design (RDD) is applicable for quasi-experimental analysis.
+Here is the information about the user query and the dataset:
+
+User Query: "{query}"
+Dataset Description: {description}
+Identified Treatment (tentative): {treatment}
+Identified Outcome (tentative): {outcome}
+
+Available Columns:
+{column_info}
+
+Your goal is to check if there is 'Running Variable' i.e. a variable that determines treatment/treatment control. If the variable is above a certain cutoff, the unit is categorized as treat; if below, it is control.
+The running variable must be numeric and continuous. Do not use categorical or low-cardinality variables. Additionally, the treatment variable must be binary in this case. If not, RDD is not valid.
+
+Respond ONLY with a valid JSON object matching the required schema. If RDD is not suggested by the context, return null for both fields.
+Schema: {{ "running_variable": "COLUMN_NAME_OR_NULL", "cutoff_value": NUMERIC_VALUE_OR_NULL }}
+Example: {{ "running_variable": "test_score", "cutoff_value": 70 }} or {{ "running_variable": null, "cutoff_value": null }}
+"""
+
+RCT_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are an expert causal inference assistant helping to determine if the data comes from a Randomized Controlled Trial (RCT).
+Your goal is to assess if the treatment assignment mechanism described or implied was random.
+
+Here is the information about the user query and the dataset:
+
+User Query: "{query}"
+Dataset Description: {description}
+Identified Treatment (tentative): {treatment}
+Identified Outcome (tentative): {outcome}
+
+Available Columns:
+{column_info}
+
+Based on the above information, determine if the data comes a randmomized experiment / radomized controlled trial.
+
+Respond ONLY with a valid JSON object matching the required schema. Respond with true if RCT is likely, false if observational is likely, and null if unsure.
+Schema: {{ "is_rct": BOOLEAN_OR_NULL }}
+Example (RCT likely): {{ "is_rct": true }}
+Example (Observational likely): {{ "is_rct": false }}
+Example (Unsure): {{ "is_rct": null }}
+"""
+
+TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are a causal inference assistant.
+"
+Dataset Description: {description}
+Identified Treatment Variable: "{treatment_variable}"
+Unique Values in Treatment Variable (sample): {treatment_variable_values}
+
+User Query: "{query}
+
+
+Based on the user query, does it specify a particular category of the treatment variable '{treatment_variable}' that should be considered the control, baseline, or reference group for comparison?
+
+Examples:
+- Query: "Effect of DrugA vs Placebo" -> Reference for treatment "Drug" might be "Placebo"
+- Query: "Compare ActiveLearning and StandardMethod against NoIntervention" -> Reference for treatment "TeachingMethod" might be "NoIntervention"
+
+If a reference level is clearly specified or strongly implied AND it is one of the unique values provided for the treatment variable, identify it. Otherwise, state null.
+If multiple values seem like controls (e.g. "compare A and B vs C and D"), return null for now, as this requires more complex handling.
+
+Respond ONLY with a JSON object adhering to this Pydantic model:
+{{
+ "reference_level": "string_representing_the_level_or_null",
+ "reasoning": "string_or_null_brief_explanation"
+}}
+"""
+
+INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are a causal inference assistant.
+
+Your task is to determine whether the user query suggests the inclusion of an interaction term between the treatment and one covariate, specifically to assess heterogeneous treatment effects (HTE).
+
+User Query:
+"{query}"
+
+Dataset Description:
+"{description}"
+
+Identified Treatment Variable:
+"{treatment_variable}"
+
+Available Covariates (name: type):
+{covariates_list_with_types}
+
+Instructions:
+- ONLY suggest an interaction if the query explicitly mentions treatment across a subgroup.
+- DO NOT suggest an interaction if the query asks for an overall average effect or does not mention subgroup analysis.
+- If you're unsure, default to no interaction.
+
+Respond ONLY with a JSON object that follows this schema:
+
+{{
+ "interaction_needed": boolean, // True if subgroup comparison is clearly mentioned
+ "interaction_variable": string_or_null, // Name of covariate to interact with treatment, or null
+ "reasoning": string // Short explanation
+}}
+Example (interaction suggested):
+{{
+ "interaction_needed": true,
+ "interaction_variable": "gender",
+ "reasoning": "Query asks if the treatment effect if for men."
+}}
+
+Example (no interaction suggested):
+{{
+ "interaction_needed": false,
+ "interaction_variable": null,
+ "reasoning": "Query asks for the overall average treatment effect, no specific subgroups mentioned for effect heterogeneity."
+}}
+"""
+
+
+## This prompt is used to identify the treatment variable.
+TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are an expert in causal inference. Your task is to identify the **treatment variable** in a dataset in order to perform a causal analysis that answers the user's query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+List of Available Variables:
+{column_info}
+
+Based on the query, dataset description, and available variables, determine which variable is most likely to serve as the treatment variable.
+
+If a clear treatment variable cannot be determined from the provided information, return null.
+
+Return your response as a valid JSON object in the following format:
+{{ "treatment": "COLUMN_NAME_OR_NULL" }}
+"""
+
+## This prompt is used to identify the outcome variable.
+OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are an expert in causal inference. Your task is to identify the **outcome variable** in a dataset in order to perform a causal analysis that answers the user's query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+Based on the query, dataset description, and available variables, determine which variable is most likely to serve as the outcome variable in the causal analysis.
+
+Do not speculate. If a clear outcome variable cannot be identified from the provided information, return null.
+
+Return your response as a valid JSON object in the following format:
+{{ "outcome": "COLUMN_NAME_OR_NULL" }}
+"""
+
+COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are an expert in causal inference. Your task is to identify the **pre-treatment variables** in a dataset that can be used as controls in a causal estimation model to answer the user's query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+The treatment variable is: {treatment}
+The outcome variable is: {outcome}
+
+Pre-treatment variables are those that are measured **before** the treatment is applied and are **not affected** by the treatment. These variables can be used as controls in the causal model.
+For example, say we have an RCT with outcome Y, treatment T, and pre-treatment variables X1, X2, and X3. We can perform a regression of the form: Y ~ T + X1 + X2 + X3.
+
+Based on the information above, return a list of variables that qualify as pre-treatment variables from the available columns.
+If no suitable pre-treatment variables can be identified, return an empty list.
+
+Return your response as a valid JSON object in the following format:
+{{ "covariates": ["LIST_OF_COLUMN_NAMES_OR_EMPTY_LIST"] }}
+"""
+
+
+CAUSAL_GRAPH_PROMPT_TEMPLATE = """
+You are an expert in causal inference. Your task is to construct a causal graph to help answer a user query.
+
+Here is the user query:
+{query}
+
+Dataset Description:
+{description}
+
+Here are the treatment and outcome variables:
+Treatment: {treatment}
+Outcome: {outcome}
+
+Here are the available variables in the dataset:
+{column_names}
+
+Based on the query, dataset description, and available variables, list the most relevant direct causal relationships in the dataset.
+Return them as a list in the format "A -> B", where A is the cause and B is the effect. Use only variables present in the dataset. Do not invent or assume any variables.
+Return the result as a Python list of strings, like:
+["A -> B", "B -> C", "A -> C"]
+"""
+
+
+ESTIMAND_PROMPT_TEMPLATE = """
+You are an expert in causal inference. Your task is to determine the appropriate estimand to answer a given query.
+
+Here is the user query:
+{query}
+
+Additionally, the dataset has the following description:
+{dataset_description}
+
+Here are the variables in the dataset:
+{dataset_columns}
+
+Likewise, the treatment variable is: {treatment}, and the outcome variable is: {outcome}.
+
+Given this information, decide whether the Average Treatment Effect (ATE) or the Average Treatment Effect on the Treated (ATT) is more appropriate for answering the query.
+Only return the estimand name: "att" or "ate"
+"""
+
+
+MEDIATOR_PROMPT_TEMPLATE = """
+You are an expert in causal inference. The user is interested in estimating the effect of {treatment} on {outcome}.
+
+Here is the dataset description:
+{description}
+
+Taking into account the treatment, outcome, and the description, from the following variables, is there a valid mediator (i.e., affected by {treatment} and affecting {outcome})?
+{column_names}
+
+*** This should be a valid mediator. If there is no valid mediator, return "None." ***
+Return a single variable name if applicable. If none, return "None."
+"""
+
+CONFOUNDER_PROMPT_TEMPLATE = """
+You are an expert in causal inference.
+
+The user is interested in estimating the effect of {treatment} on {outcome}.
+Here is the dataset description:
+{description}
+
+List 3 to 5 variables from the following that are likely confounders (i.e., affect both {treatment} and {outcome}):
+{column_names}
+
+Return only a comma-separated list of variable names.
+"""
+
+
+CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are an expert in causal inference. Your task is to identify potential **confounders** in a dataset that should be adjusted for when estimating the causal effect described in the user query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+The treatment variable is: {treatment}
+The outcome variable is: {outcome}
+
+A **confounder** is a variable that:
+1. **Affects the treatment** (i.e., influences who receives the treatment), and
+2. **Affects the outcome**, and
+3. **Is not caused by the treatment** (i.e., it must be a pre-treatment variable),
+4. Is **not a mediator** between treatment and outcome.
+
+These variables can create spurious associations between treatment and outcome if not adjusted for.
+
+Based on the user query and the dataset description, identify which variables are likely to be confounders. Only include variables that you believe causally affect both treatment and outcome. If you're uncertain, only include variables where the justification is clear from the query or description.
+
+Return your response as a valid JSON object in the following format:
+{{ "confounders": ["LIST_OF_COLUMN_NAMES_OR_EMPTY_LIST"] }}
+"""
+
+
diff --git a/auto_causal/prompts/prompts.py b/auto_causal/prompts/prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd04dfd23362223d2a69aab97fc23efbcdbe7dc1
--- /dev/null
+++ b/auto_causal/prompts/prompts.py
@@ -0,0 +1,409 @@
+## This prompt is used to identify whether a dataset comes from a randomized trial or not.
+RCT_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to determine whether an input data comes from a **randomized controlled trial (RCT) / randomized experiment **.
+
+Here is the dataset description:
+{description}
+
+Here are the variables included in the dataset:
+{column_info}
+
+Based solely on the dataset description, assess whether the treatment was randomly assigned (i.e., whether the data comes from an RCT).
+RCTs are characterized by random assignment of treatment across the participating units.
+
+Do not speculate. If the description does not provide enough information to decide, return null.
+
+Return your response as a valid JSON object in the following format:
+{{ "is_rct": BOOLEAN_OR_NULL }}
+
+Examples:
+- RCT likely → {{ "is_rct": true }}
+- Observational likely → {{ "is_rct": false }}
+- Unclear or not enough information → {{ "is_rct": null }}
+"""
+
+## This prompt is used to identify the outcome variable.
+OUTCOME_VAR_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to identify the **outcome variable** in a dataset in order to perform a causal analysis that answers the user's query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+Based on the query, dataset description, and available variables, determine which variable is most likely to serve as the outcome variable in the causal analysis.
+
+Do not speculate. If a clear outcome variable cannot be identified from the provided information, return null.
+
+Return your response as a valid JSON object in the following format:
+{{ "outcome_variable": "COLUMN_NAME_OR_NULL" }}
+"""
+
+## This prompt is used to identify the treatment variable.
+TREATMENT_VAR_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to identify the **treatment variable** in a dataset in order to perform a causal analysis that answers the user's query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+List of Available Variables:
+{column_info}
+
+Based on the query, dataset description, and available variables, determine which variable is most likely to serve as the treatment in the causal analysis.
+
+If a clear treatment variable cannot be determined from the provided information, return null.
+
+Return your response as a valid JSON object in the following format:
+{{ "treatment_variable": "COLUMN_NAME_OR_NULL" }}
+"""
+
+
+## This prompt is used to identify whether the dataset comes from an encouragement design, which is a type of randomized experiment where individuals are
+## randomly encouraged to take a treatment, but not all who are encouraged actually comply. For instance, we could randomly selected inviduals and encourage them to
+## take a vaccine. However, we cannot guarantee that all individuals who were encouraged actually took the vaccine. In such case, the mechanism is describe as
+## Z -> T -> Y, where Z is the encouragement variable, T is the treatment variable, and Y is the outcome variable.
+ENCOURAGEMENT_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to determine whether the dataset follows an encouragement design.
+
+Here is the dataset description:
+{description}
+
+Here are the variables included in the dataset:
+{column_info}
+Recall that an encouragement design is a type of experiment where individuals are randomly encouraged to take a treatment, but not all who are encouraged actually comply.
+
+To identify such a design, the dataset must include both:
+1. A variable indicating whether a unit was encouraged (randomized assignment), and
+2. A variable indicating whether the unit actually received the treatment.
+
+If either of these variables is missing, or the description is insufficient, you should return null.
+
+Do not speculate. Base your decision strictly on the provided information.
+
+Return your response as a valid JSON object in the following format:
+{{ "is_encouragement": BOOLEAN_OR_NULL }}
+
+Examples:
+- Encouragement design likely → {{ "is_encouragement": true }}
+- Not an encouragement design → {{ "is_encouragement": false }}
+- Unclear or insufficient information → {{ "is_encouragement": null }}
+"""
+
+## This prompt is used to identify the encouragement variable and the treatment variable in an encouragement design.
+ENCOURAGEMENT_VAR_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to identify variables for performing an encouragement design (Instrumental Variable) analysis to answer the user’s query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+Based on the query, dataset description, and listed variables, identify:
+
+1. The **encouragement variable** — a randomized variable indicating whether a unit was encouraged to take the treatment.
+2. The **treatment variable** — indicating whether the unit actually took the the treatment.
+
+Do not speculate. If either the encouragement or the treatment variable cannot be clearly identified from the information provided, return null for the respective field.
+
+Return your response as a valid JSON object in the following format:
+{{ "encouragement_variable": "COLUMN_NAME_OR_NULL", "treatment_variable": "COLUMN_NAME_OR_NULL" }}
+"""
+
+## This prompt is used to identify pre-treatment variables that can be used as control in a causal estimation model.
+PRE_TREAT_VAR_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to identify the **pre-treatment variables** in a dataset that can be used as controls in a causal estimation model to answer the user's query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+The treatment variable is: {treatment}
+The outcome variable is: {outcome}
+
+Pre-treatment variables are those that are measured **before** the treatment is applied and are **not affected** by the treatment. These variables can be used as controls in the causal model.
+For example, say we have an RCT with outcome Y, treatment T, and pre-treatment variables X1, X2, and X3. We can perform a regression of the form: Y ~ T + X1 + X2 + X3.
+
+Based on the information above, return a list of variables that qualify as pre-treatment variables from the available columns.
+If no suitable pre-treatment variables can be identified, return an empty list.
+
+Return your response as a valid JSON object in the following format:
+{{ "pre_treat_variables": ["LIST_OF_COLUMN_NAMES_OR_EMPTY_LIST"] }}
+"""
+
+## This prompt is used to identify whether a Difference-in-Differences (DiD) analysis is appropriate for the dataset in relation to the user query.
+DiD_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to determine whether a difference-in-differences (DiD) analysis is appropriate for analyzing the given dataset to answer a user query.
+
+Here is the user query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+Recall that, DiD is used to estimate causal effects by comparing outcome changes over time between treated and control groups.This requires information on outcomes both before and after treatment.
+There are two common types of DiD designs:
+1. **Canonical DiD**: Two groups (treated and control) and two time periods (pre-treatment and post-treatment).
+2. **Staggered DiD**: Multiple groups and multiple time periods, with treatment staggered across groups over time.
+
+Based on the provided information, first determine whether DiD is applicable. If it is, indicate whether the design is canonical or staggered.
+
+Do not speculate. If the information is insufficient to make a determination, return null values.
+
+Return your response as a valid JSON object in the following format:
+{{
+ "is_did_applicable": BOOLEAN_OR_NULL,
+ "is_canonical_did": BOOLEAN_OR_NULL
+}}
+"""
+
+## This prompt is used to identify the temporal variable necessary for performing a Difference-in-Differences (DiD) analysis.
+## The temporal variable must indicate when the observation was recorded or it could be used to construct a post-treatment indicator.
+
+TEMPORAL_VAR_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to determine whether the dataset contains **temporal information relevant to treatment timing**,
+which is necessary to perform a Difference-in-Differences (DiD) analysis to answer a user's query.
+
+User Query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+To apply a DiD analysis, the dataset must allow comparison of outcomes **before and after treatment**.
+This requires temporal variables that help determine **when the treatment occurred** and **when each observation was recorded**. There are three possible cases:
+
+1. **Post-treatment indicator is directly available**:
+ - A binary variable (e.g., `post_treatment`) indicates whether the observation occurred after the treatment.
+
+2. **Post-treatment indicator can be constructed using a reference value, which represents the time when the treatment was applied**:
+ - The dataset contains a **time variable** (e.g., `year`, `date`) indicating when each observation occurred.
+ - A single **treatment time** (not a column) can be inferred from the description or query.
+ - From this, a post-treatment indicator can be constructed: `post = 1{year ≥ treatment_time}`.
+
+3. **Treatment is staggered across units i.e. this is a two-way fixed effects model**:
+ - The dataset includes a **time variable** indicating when each observation occurred.
+ - From these, we can construct post indicators or event time variables, such
+
+Only identify these variables if they are relevant for conducting a DiD-style analysis. Do **not** select time-related variables that are unrelated to treatment timing or if the query does not support a before-after causal comparison.
+
+Based on the query, dataset description, and available variables, return:
+
+- "post_treatment_variable": Name of the binary post-treatment indicator, if it exists.
+- "time_variable": Name of the variable indicating when the observation occurred (e.g., `year`, `date`).
+- "treatment_reference_time": A single reference period (not a column) when treatment was introduced, if inferable. Note that this is useful for canonical DiD analysis with two groups and two period: pre-and post-treatment. For stagged DiD, this is not needed. Return NULL.
+
+If any of these cannot be identified, return `null` for that field.
+
+Return your response as a valid JSON object in the following format:
+{{
+ "post_treatment_variable": "COLUMN_NAME_OR_NULL",
+ "time_variable": "COLUMN_NAME_OR_NULL",
+ "treatment_reference_time": YEAR_OR_NULL
+}}
+"""
+
+## This prompt is used to identify the group variable necessary for performing a Difference-in-Differences (DiD) analysis.
+## The group variable must indicate the treatment and control groups, or it could be a categorical variable in case of staggered DiD.
+STATE_VAR_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to determine whether the dataset contains a ** group variable*** necessary to perform a Difference-in-Differences (DiD) analysis to answer a user's query.
+
+Here is the user query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+To apply a DiD analysis, the dataset must allow comparison of outcomes between different groups (treatment vs control) over time. This requires a group variable that represents entites that are either treated or not treated.
+There are two possible cases:
+
+1. **Group variable is a binary indicator**:
+ - The data contains a binary variable indicating whether the observation belongs to the treatment or control group.
+2. **Group variable is categorical**:
+ - The data contains a categorical variable indicating different groups.
+
+Based on the query, dataset description, and available variables, return:
+- "group_variable": Name of the group variable that indicates the treatment and control groups or a categorical variable representing different groups in case of staggered DiD.
+- "group_reference": A single reference group (not a column) that corresponds to the treatment group. This is used in canonical DiD only. For example, say a policy was enacted in New Jersey, but not in other states.
+The reference group would be "New Jersey" or "NJ", since the policy was enacted there. We can contruct, TREAT variable as `TREAT = 1 if state == "New Jersey" else 0`.
+
+If a suitable group variable cannot be identified, return null for the "group_variable" field.
+Return your response as a valid JSON object in the following format:
+{{
+ "group_variable": "COLUMN_NAME_OR_NULL",
+ "group_reference": "GROUP_NAME_OR_NULL"
+}}
+"""
+
+## This prompt is used to identify whether a Regression Discontinuity Design (RDD) is appropriate for the dataset in relation to the user query.
+## The goal is to identify the running variable, cutoff value, and treatment variable for RDD analysis.
+
+RDD_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to determine whether a Regression Discontinuity Design (RDD) is appropriate for analyzing the dataset in relation to the user query.
+
+Here is the user query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+Recall that RDD requires a numeric, non-binary running variable that determines treatment assignment based on a cutoff—i.e., treatment = 1 if value > cutoff.
+In many cases, the treatment variable may already be included in the dataset and does not need to be computed. In such cases, return the name of the treatment variable as well.
+
+Based on the query, dataset description, and available variables, return:
+- "running_variable": Name of the running variable that determines treatment assignment based on a cutoff.
+- "cutoff_value": The numeric cutoff value that determines treatment assignment.
+- "treatment_variable": Name of the treatment variable indicating whether the unit received treatment (1) or control (0).
+
+Do not speculate. If the information is insufficient, return null for the relevant fields.
+
+Return your response as a valid JSON object in the following format:
+{{
+ "running_variable": "COLUMN_NAME_OR_NULL",
+ "cutoff_value": NUMERIC_VALUE_OR_NULL,
+ "treatment_variable": "COLUMN_NAME_OR_NULL"
+}}
+"""
+
+## This prompt is used to identify whether an Instrumental Variable (IV) is appropriate for the dataset in relation to the user query.
+## The output includes the instrument variable and the treatment variable that the instrument influences. We can check whether the treatment variable selected here is
+## the same as the treatment variable selected by the TREATMENT_VAR_IDENTIFICATION_PROMPT.
+
+INSTRUMENT_VAR_IDENTIFICATION_PROMPT = """
+You are an expert in causal inference. Your task is to determine whether an Instrumental Variable (IV) is appropriate for analyzing the dataset in relation to the user query.
+
+Here is the user query:
+{query}
+
+Dataset Description:
+{description}
+
+Available Variables:
+{column_info}
+
+Recall that a valid instrument must satisfy all of the following conditions:
+1. **Relevance**: It must causally influence the treatment variable.
+2. **Exclusion Restriction**: It must affect the outcome only through the treatment—not directly or through other pathways.
+3. **Independence**: It must be as good as randomly assigned, independent of unobserved confounders affecting the outcome.
+
+The instrument must be one of the variables listed in the dataset (not the treatment itself), and must not be assumed or invented.
+
+Based on the query, dataset description, and available variables, return:
+- "instrument_variable": Name of the variable that can serve as a valid instrument.
+- "treatment_variable": Name of the treatment variable that the instrument influences.
+
+Do not speculate. Only suggest an IV if you are confident that all conditions are satisfied. Otherwise, return null.
+
+Return your response as a valid JSON object in the following format:
+{{
+ "instrument_variable": "COLUMN_NAME_OR_NULL",
+ "treatment_variable": "COLUMN_NAME_OR_NULL"
+}}
+"""
+
+## This prompt is used to construct a causal graph based on the user query, dataset description, treatment, outcome, and available variables.
+## Once the graph is constructed, we can use dowhy to identify frontdoor, backdoor adjustment sets.
+
+CAUSAL_GRAPH_PROMPT = """
+You are an expert in causal inference. Your task is to construct a causal graph to help answer a user query.
+
+Here is the user query:
+{query}
+
+Dataset Description:
+{description}
+
+Here are the treatment and outcome variables:
+Treatment: {treatment}
+Outcome: {outcome}
+
+Here are the available variables in the dataset:
+{column_info}
+
+Based on the query, dataset description, and available variables, construct a causal graph that captures the relationships between the treatment, outcome, and other relevant variables.
+
+Use only variables present in the dataset. Do not invent or assume any variables. However, not all variables need to be included—only those that are relevant to the causal relationships should appear in the graph.
+
+Return the causal graph in DOT format. The DOT format should include:
+- Nodes for each included variable.
+- Directed edges representing causal relationships among variables.
+
+Also return the list of edges in the format "A -> B", where A and B are variable names.
+
+Here is an example of the DOT format:
+digraph G {
+ A -> B;
+ B -> C;
+ A -> C;
+}
+
+And the corresponding list of edges:
+["A -> B", "B -> C", "A -> C"]
+
+Return your response as a valid JSON object in the following format:
+{{
+ "causal_graph": "DOT_FORMAT_STRING",
+ "edges": ["EDGE_1", "EDGE_2", ...]
+}}
+"""
+
+## This prompt is used to determine whether an interaction term between the treatment variable and any covariate is needed in the causal model
+## In case, we need to look at the coefficient of the interaction term to answer the user query, set the boolean "interaction_term_query" to true.
+## If the interaction term is not needed, but its inclusion may still be statistically or substantively justified, set "interaction_term_query" to true. In this case,
+## we include the itneraction term in the mode, but we do not need to look at the coefficient of the interaction term. We use the coefficient of the treatment variable
+
+INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are an expert in causal inference. Your task is to determine whether an interaction term between the treatment variable and any covariate is needed in the causal model to answer the user query.
+
+Here is the user query:
+{query}
+
+Dataset Description:
+{description}
+
+Identified Treatment Variable: "{treatment_variable}"
+Available Covariates (with types): "{covariates_list_with_types}"
+Outcome Variable: "{outcome_variable}"
+
+Recall that an interaction term is needed when:
+- To answer the query, we need to examine the **coefficient of the interaction term**, in which case the interaction is strictly necessary.
+- The goal is to estimate **heterogeneous treatment effects** — i.e., whether the effect of the treatment on the outcome differs based on the level or value of a covariate.
+- In other cases, the interaction may not be required to answer the query, but its inclusion may still be statistically or substantively justified.
+
+Based on the information above, determine whether an interaction term between the treatment and any covariate is needed to answer the user query.
+
+Return your response as a valid JSON object in the following format:
+{{
+ "interaction_needed": BOOLEAN, // True if an interaction term is needed or beneficial
+ "interaction_variable": "COLUMN_NAME_OR_NULL", // Covariate to interact with treatment, or null if not needed
+ "reasoning": "REASONING_STRING", // Brief explanation of your decision
+ "interaction_term_query": BOOLEAN // True if the interaction is essential to answering the query
+}}
+"""
diff --git a/auto_causal/prompts/regression_prompts.py b/auto_causal/prompts/regression_prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5abcbbeef858ac5f3ee325c362e47e9bf7dbf25
--- /dev/null
+++ b/auto_causal/prompts/regression_prompts.py
@@ -0,0 +1,21 @@
+STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE = """
+You are a statistical assistant. Given a list of parameter names from a regression model,
+the user's query, and context about the treatment variable, identify the parameter names and their original indices
+that are relevant for answering the query or for providing a general overview of the treatment effect.
+
+User Query: "{user_query}"
+Treatment variable in formula (Patsy term): "{treatment_patsy_term}"
+Original treatment column name: "{treatment_col_name}"
+Is treatment multi-level categorical with a reference: {is_multilevel_case}
+Reference level (if multi-level): "{reference_level_for_prompt}"
+
+Available Parameter Names (with their original 0-based index):
+{indexed_param_names_str}
+
+Instructions:
+-Respond with best matching param or params in case multiple matches with their index/s
+-Exclude interaction terms (those containing ':') unless the query *specifically* asks for an interaction effect. This task is focused on main treatment effects.
+
+Respond ONLY with a valid JSON object matching this Pydantic model schema:
+{llm_response_schema_json}
+"""
\ No newline at end of file
diff --git a/auto_causal/prompts/result_interpretation_prompts.py b/auto_causal/prompts/result_interpretation_prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7298f75ddad4e03e1a7dd7e4f12b0681815f9ab
--- /dev/null
+++ b/auto_causal/prompts/result_interpretation_prompts.py
@@ -0,0 +1,35 @@
+"""
+Prompts for interpreting statistical model results in the context of a specific user query.
+"""
+
+QUERY_SPECIFIC_INTERPRETATION_PROMPT_TEMPLATE = """
+You are an AI assistant. Your task is to analyze the results of a statistical model and extract the specific information that answers the user's query.
+
+User Query: "{user_query}"
+
+Context from Model Execution:
+- Treatment Variable: "{treatment_variable}"
+- Reference Level for Treatment (if any): "{reference_level}"
+- Model Formula: "{formula}"
+- Estimated Effects by Treatment Level (compared to reference, if applicable):
+{effects_by_level_str}
+- Information on Interaction Term (if any):
+{interaction_info_str}
+
+Full Model Summary (for additional context if needed, prefer structured 'Estimated Effects' above):
+---
+{model_summary_text}
+---
+
+Instructions:
+1. Carefully read the User Query to understand what specific treatment effect or comparison they are interested in.
+2. Examine the 'Estimated Effects by Treatment Level' to find the statistics (estimate, p-value, confidence interval, std_err) for the treatment level or comparison most relevant to the query.
+3. If the query refers to a specific treatment level (e.g., "Civic Duty" when treatment variable is "treatment" with levels "Control", "Civic Duty", etc.), focus on that level's comparison to the reference.
+4. Determine if the identified effect is statistically significant (p-value < 0.05).
+5. If a significant interaction is noted in 'Information on Interaction Term' and it involves the identified treatment level, briefly state how it modifies the main effect in your interpretation. Do not perform complex calculations; just state the presence and direction if clear.
+6. Construct a concise 'interpretation_summary' that directly answers the User Query using the extracted statistics.
+7. If the query cannot be directly answered (e.g., the specific level isn't in the results, or the query is too abstract for the given data), explain this in 'unanswered_query_reason'.
+
+Respond ONLY with a valid JSON object matching this Pydantic model schema:
+{llm_response_schema_json}
+"""
\ No newline at end of file
diff --git a/auto_causal/synthetic/__init__.py b/auto_causal/synthetic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d478c6ff525c81eae2655f82402c2e5b7a77427
--- /dev/null
+++ b/auto_causal/synthetic/__init__.py
@@ -0,0 +1,2 @@
+
+from auto_causal.synthetic.generator import PSMGenerator, PSWGenerator, IVGenerator, RDDGenerator, RCTGenerator, DiDGenerator, MultiTreatRCTGenerator
diff --git a/auto_causal/synthetic/generator.py b/auto_causal/synthetic/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7ebc31cf1aef914d6abd74d7dec187ef82ba442
--- /dev/null
+++ b/auto_causal/synthetic/generator.py
@@ -0,0 +1,663 @@
+## This code contains the base classess used in generating synthetic data
+
+from linearmodels.iv import IV2SLS
+from dowhy import CausalModel
+from dowhy import datasets as dset
+from sklearn.linear_model import LogisticRegression
+import statsmodels.api as sm
+import statsmodels.formula.api as smf
+import numpy as np
+import pandas as pd
+from pathlib import Path
+import matplotlib.pyplot as plt
+
+class DataGenerator:
+ """
+ Base class for generating synthetic data
+
+ Attributes:
+ n_observations (int): Number of observations
+ n_continuous_covars (int): Number of covariates
+ n_covars (int): total number of covariates (continuous + binary)
+ n_treatments (int): Number of treatments
+ true_effect (float): True effect size
+ seed (int): Random seed for reproducibility
+ data (pd.DataFrame): Generated data
+ info (dict): Dictionary to store additional information about the data
+ method (str): the causal inference method assocated with the synthetic
+ mean (np.ndarray): mean of the covariates
+ covar (np.ndarray): covariance matrix for the covariates
+ heterogeneity (bool): whether or not the treatment effects are heterogeneous
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, mean=None,
+ covar = None, n_treatments=1, true_effect=0 ,seed=111, heterogeneity=0):
+
+ np.random.seed(seed)
+ self.n_observations = n_observations
+ self.n_continuous_covars = n_continuous_covars
+ self.n_covars = n_continuous_covars + n_binary_covars
+ self.n_treatments = n_treatments
+ self.n_binary_covars = n_binary_covars
+ self.data = None
+ self.seed = seed
+ self.true_effect = true_effect
+ self.method = None
+ self.mean = mean
+ self.covar = covar
+ if mean is None:
+ self.mean = np.random.randint(3, 20, size=self.n_continuous_covars)
+ if self.covar is None:
+ self.covar = np.identity(self.n_continuous_covars)
+ self.heterogeneity = heterogeneity
+
+ def generate_data(self):
+ """
+ Generates the synthetic data
+
+ Returns:
+ pd.DataFrame: The generated data
+ """
+
+ raise NotImplementedError("Invoke the method in the subclass")
+
+ def save_data(self, folder, filename):
+ """
+ Saves the generated data as a CSV file
+
+ Args:
+ folder (str): path to the folder where the data is saved
+ filename (str): name of the file
+ """
+
+ if self.data is None:
+ raise ValueError("Data not generated yet. Please generate data first.")
+ path = Path(folder)
+ path.mkdir(parents=True, exist_ok=True)
+ if not filename.endswith('.csv'):
+ filename += '.csv'
+ self.data.to_csv(path / filename, index=False)
+
+ def test_data(self, print_=False):
+ """
+ Test the generated data, using the appropriate method.
+ """
+
+ raise NotImplementedError("This method should be overridden by subclasses")
+
+ def generate_covariates(self):
+ """
+ Generate covariates. For continuous covariates, we use multivariate normal distribution, and for
+ binary covars, we use binomial distribution. The non-binary covariates are discretized to their floor
+ integer.
+ """
+
+ X_c = np.random.multivariate_normal(mean=self.mean, cov=self.covar,
+ size=self.n_observations)
+ p = np.random.uniform(0.3, 0.7)
+ X_b = np.random.binomial(1, p, size=(self.n_observations, self.n_binary_covars)).astype(int)
+ covariates = np.hstack((X_c, X_b))
+ covariates = covariates.astype(int)
+
+ return covariates
+
+class MultiTreatRCTGenerator(DataGenerator):
+ """
+ Base class for generating synthetic data for multi-treatment RCTs
+
+ Additional Attributes:
+ true_effect_vec (np.ndarray): the treatment effect for different treatments.
+ """
+ def __init__(self, n_observations, n_continuous_covars, n_treatments, n_binary_covars=2,
+ mean=None, covar=None, true_effect=1.0, true_effect_vec = None,
+ seed=111, heterogeneity=0):
+
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars,
+ mean=mean, covar=covar, true_effect=true_effect, seed=seed,
+ heterogeneity=heterogeneity, n_treatments=n_treatments)
+
+ self.method = "MultiTreatRCT"
+ self.true_effect_vec = true_effect_vec
+
+ ## if true effect vec is None, we set the treatment effects to be the same for all treatments
+ if true_effect_vec is None:
+ self.true_effect_vec = np.zeros(n_treatments)
+ for i in range(1, n_treatments):
+ self.true_effect_vec[i] = self.true_effect
+
+ def generate_data(self):
+
+ X = self.generate_covariates()
+ cols = [f"X{i+1}" for i in range(self.n_covars)]
+ df = pd.DataFrame(X, columns=cols)
+
+
+ df['D'] = np.random.randint(0, self.n_treatments+1, size=self.n_observations)
+ vec = np.random.uniform(0, 1, size=self.n_covars)
+ intercept = np.random.normal(50, 3)
+ noise = np.random.normal(0, 1, size=self.n_observations)
+
+ # Apply appropriate treatment effect per treatment arm
+ treatment_effects = np.array(self.true_effect_vec)
+ df['treat_effect'] = treatment_effects[df['D']]
+
+ df['Y'] = intercept + X.dot(vec) + df['treat_effect'] + noise
+
+ df.drop(columns='treat_effect', inplace=True)
+ self.data = df
+
+ return df
+
+
+ def test_data(self, print_=False):
+
+ if self.data is None:
+ raise ValueError("Data not generated yet. Please generate data first.")
+
+ model = smf.ols('Y ~ C(D)', data=self.data).fit()
+
+ result = model.summary()
+ if print_:
+ print(result)
+ return result
+
+
+# Front-Door Criterion Generator
+class FrontDoorGenerator(DataGenerator):
+ """
+ Generates synthetic data satisfying the front-door criterion.
+ D → M → Y, D ← U → Y
+ """
+ def __init__(self, n_observations, n_continuous_covars=2, n_binary_covars=2,
+ mean=None, covar=None, seed=111, true_effect=2.0, heterogeneity=0):
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars,
+ mean=mean, covar=covar, seed=seed, true_effect=true_effect,
+ n_treatments=1, heterogeneity=heterogeneity)
+ self.method = "FrontDoor"
+
+ def generate_data(self):
+ X = self.generate_covariates()
+ cols = [f"X{i+1}" for i in range(self.n_covars)]
+ df = pd.DataFrame(X, columns=cols)
+
+ # Latent confounder
+ U = np.random.normal(0, 1, self.n_observations)
+
+ # Treatment depends on U and X
+ vec_d = np.random.uniform(0.5, 1.5, size=self.n_covars)
+ df['D'] = (X @ vec_d + 0.8 * U + np.random.normal(0, 1, self.n_observations)) > 0
+ df['D'] = df['D'].astype(int)
+
+ # Mediator depends on D and X
+ vec_m = np.random.uniform(0.5, 1.5, size=self.n_covars)
+ df['M'] = X @ vec_m + df['D'] * 1.5 + np.random.normal(0, 1, self.n_observations)
+
+ # Outcome depends on M, U and X
+ vec_y = np.random.uniform(0.5, 1.5, size=self.n_covars)
+ df['Y'] = 50 + 2.0 * df['M'] + 1.0 * U + X @ vec_y + np.random.normal(0, 1, self.n_observations)
+
+ self.data = df
+ return df
+
+ def test_data(self, print_=False):
+ if self.data is None:
+ raise ValueError("Data not generated yet. Please generate data first.")
+
+ model_m = smf.ols("M ~ D", data=self.data).fit()
+ model_y = smf.ols("Y ~ M + D", data=self.data).fit()
+
+ if print_:
+ print("Regression: M ~ D")
+ print(model_m.summary())
+ print("\nRegression: Y ~ M + D")
+ print(model_y.summary())
+
+ return {"M~D": model_m.summary(), "Y~M+D": model_y.summary()}
+
+class ObservationalDataGenerator(DataGenerator):
+ """
+ Generate synthetic data for observational studies.
+
+ Additional Attributes:
+ self.weights (np.ndarray): the propoensity score weights for each observation
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, mean=None, covar=None,
+ true_effect=1.0, seed=111, heterogeneity=0):
+
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars, mean=mean, covar=covar,
+ true_effect=true_effect, seed=seed, heterogeneity=heterogeneity)
+
+ def generate_data(self):
+
+ X = self.generate_covariates()
+
+ cols = [f"X{i+1}" for i in range(self.n_covars)]
+ df = pd.DataFrame(X, columns=cols)
+ X_norm = (X - X.mean(axis=0)) / X.std(axis=0)
+
+ vec1 = np.random.normal(0, 0.5, size=self.n_covars)
+ lin = X_norm @ vec1 + np.random.normal(0, 1, self.n_observations)
+ ## the propensity score
+ ps = 1 / (1 + np.exp(-lin))
+ ## we do this for stability reasons
+ ps = np.clip(ps, 1e-3, 1 -1e-3)
+ df['D'] = np.random.binomial(1, ps).astype(int)
+ vec2 = np.random.normal(0, 0.5, size=self.n_covars)
+ intercept = np.random.normal(50, 3)
+ noise = np.random.normal(0, 1, size=self.n_observations)
+ df['Y'] = intercept + X @ vec2 + self.true_effect * df['D'] + noise
+
+ self.propensity = ps
+ self.weights = np.where(df['D'] == 1, 1 / ps, 1 / (1 - ps))
+ self.data = df
+
+ return self.data
+
+class PSMGenerator(ObservationalDataGenerator):
+ """
+ Generate synthetic data for Propensity Score Matching (PSM)
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, mean=None, covar=None,
+ true_effect=1.0, seed=111, heterogeneity=0):
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars, mean=mean, covar=covar,
+ true_effect=true_effect, seed=seed, heterogeneity=heterogeneity)
+ self.method = "PSM"
+
+ def test_data(self, print_=False):
+ """
+ Test the generated data
+ """
+ if self.data is None:
+ raise ValueError("Data not generated yet. Please generate data first.")
+
+ lr = LogisticRegression(solver='lbfgs')
+ X = self.data[[f"X{i+1}" for i in range(self.n_covars)]]
+ lr.fit(X, self.data['D'])
+ ps_hat = lr.predict_proba(X)[:, 1]
+ treated = self.data[self.data['D'] == 1]
+ control = self.data[self.data['D'] == 0]
+
+ ## perform matching using the propensity scores
+ match_idxs = [np.abs(ps_hat[control.index] - ps_hat[i]).argmin() for i in treated.index]
+ matches = control.iloc[match_idxs]
+ att = treated['Y'].mean() - matches['Y'].mean()
+
+ result = f"Estimated ATT (matching): {att:.3f} | True: {self.true_effect}"
+ if print_:
+ print(result)
+ return result
+
+class PSWGenerator(ObservationalDataGenerator):
+ """
+ Generate synthetic data for Propensity Score Weighting (PSW)
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, mean=None, covar=None,
+ true_effect=1.0, seed=111, heterogeneity=0):
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars, mean=mean, covar=covar,
+ true_effect=true_effect, seed=seed, heterogeneity=heterogeneity)
+ self.method = "PSW"
+
+ def test_data(self, print_=False):
+ """
+ Test the generated data
+ """
+ if self.data is None:
+ raise ValueError("Data not generated yet. Please generate data first.")
+
+ df = self.data.copy()
+ D = df['D']
+ Y = df['Y']
+
+ treated = D == 1
+ control = D == 0
+
+ w = np.zeros(self.n_observations)
+ w[control] = self.propensity[control] / (1 - self.propensity[control])
+ w[treated] = 1
+
+ Y1 = Y[treated].mean()
+ Y0_weighted = np.average(Y[control], weights=w[control])
+
+ att = Y1 - Y0_weighted
+ ate = np.average(Y * D / self.propensity - (1 - D) * Y / (1 - self.propensity))
+ result = f"Estimated ATT (IPW): {att:.3f} | True: {self.true_effect}\nEstimated ATE: {ate:.3f} | True:{self.true_effect}"
+ if print_:
+ print(result)
+
+ return result
+
+
+class RCTGenerator(DataGenerator):
+ """
+ Generate synthetic data for Randomized Controlled Trials (RCT)
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, mean=None,
+ covar=None, true_effect=1.0, seed=111, heterogeneity=0):
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars,
+ mean=mean, covar=covar, true_effect=true_effect, seed=seed,
+ heterogeneity=heterogeneity)
+ self.method = "RCT"
+
+ def generate_data(self):
+
+ X = self.generate_covariates()
+ cols = [f"X{i+1}" for i in range(self.n_covars)]
+ df = pd.DataFrame(X, columns=cols)
+ df['D'] = np.random.binomial(1, 0.5, size=self.n_observations)
+ vec = np.random.uniform(0, 1, size=self.n_covars)
+ intercept = np.random.normal(50, 3)
+ noise = np.random.normal(0, 1, size=self.n_observations)
+ df['Y'] = (intercept + X.dot(vec) + self.true_effect * df['D'] + noise)
+ self.data = df
+
+ def test_data(self, print=False):
+ if self.data is None:
+ raise ValueError("Data not generated yet. Please generate data first.")
+ model = smf.ols('Y ~ D', data=self.data).fit()
+ result = model.summary()
+ if print:
+ print(result)
+ est = model.params['D']
+ conf_int = model.conf_int().loc['D']
+ result = f"TRUE ATE: {self.true_effect:.3f}, ESTIMATED ATE: {est:.3f}, \
+ 95% CI: [{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
+
+ return result
+
+class IVGenerator(DataGenerator):
+ """
+ Generate synthetic data for Instrumental Variables (IV) analysis. We assume two forms:
+ 1. Encouragement Design:
+ Z -> D -> Y
+ In this setting, encouragements (Z) is randomized. For instance, consider the administering of vaccines.
+ We cannot force people to take vaccines, however we can encourage them to take the vaccine. We could run
+ a vaccine awareness campaign, where we randomly pick participants, and inform them about the benefits of
+ vaccine. The user can either comply (take the vaccine) or not comply (not take the vaccine). Likewise, in the control
+ group, the user can comply (not take the vaccine) or defy (take the vaccine)
+ 2.
+ U
+ / \
+ Z -> D -> Y
+ This is the classical setting where we have an unobserved confounder affecting both treatment (D) and outcome (Y).
+
+
+ Additional Attributes:
+ alpha (float): the effect of the instrument on the treatment (Z on D)
+ encouragement (bool): whether or not this is an encouragement design
+ beta_d (float): effect of the unobserved confounder (U) on treatment (D)
+ beta_y (float): effect of the unobserved confounders (U) on outcome (Y)
+
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, mean=None, beta_d = 1.0,
+ beta_y = 1.5, covar=None, true_effect=1.0, seed=111, heterogeneity=0, alpha=0.5,
+ encouragement=False):
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars, mean=mean,
+ covar=covar, true_effect=true_effect, seed=seed, heterogeneity=heterogeneity)
+ self.method = "IV"
+ self.alpha = alpha
+ self.encouragement = encouragement
+ self.beta_d = beta_d
+ self.beta_y = beta_y
+
+ def generate_data(self):
+ X = self.generate_covariates()
+
+ mean = np.random.randint(8, 13)
+ Z = np.random.normal(mean, 2, size=self.n_observations).astype(int)
+ U = np.random.normal(0, 1, size=self.n_observations)
+ vec1 = np.random.normal(0, 0.5, size=self.n_covars)
+ intercept1 = np.random.normal(30, 2)
+ D = self.alpha * Z + X @ vec1 + np.random.normal(size=self.n_observations) + intercept1
+ if self.encouragement:
+ D = (D > np.mean(D)).astype(int)
+ else:
+ D = D + self.beta_d * U
+ D = D.astype(int)
+
+ intercept2 = np.random.normal(50, 3)
+ vec2 = np.random.normal(0, 0.5, size=self.n_covars)
+ Y = self.true_effect * D + X @ vec2 + np.random.normal(size=self.n_observations) + intercept2
+ if not self.encouragement:
+ Y = Y + self.beta_y * U
+ df = pd.DataFrame(X, columns=[f"X{i+1}" for i in range(self.n_covars)])
+ df['Z'] = Z
+ df['D'] = D
+ df['Y'] = Y
+ self.data = df
+
+ return self.data
+
+
+ def test_data(self, print_=False):
+
+ if self.data is None:
+ raise ValueError("Data not generated yet.")
+ model = IV2SLS.from_formula('Y ~ 1 + [D ~ Z]', data=self.data).fit()
+ est = model.params['D']
+ conf_int = model.conf_int().loc['D']
+ result = f"TRUE LATE: {self.true_effect:.3f}, ESTIMATED LATE: {est:.3f}, \
+ 95% CI: [{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
+ if print_:
+ print(result)
+ return result
+
+
+class RDDGenerator(DataGenerator):
+ """
+ Generate synthetic data for (sharp) Regression Discontinuity Design (RDD).
+
+ Additional Attributes:
+ cutoff (float): the cutoff for treatment assignment
+ bandwidth (float): the bandwidth for the running variable we consider when estimating the treatment effects
+ plot (bool): whether we plot the data or not
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, mean=None, plot=False,
+ covar=None, true_effect=1.0, seed=111, heterogeneity=0, cutoff=10, bandwidth=0.1):
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars,
+ mean=mean, covar=covar, true_effect=true_effect, seed=seed,
+ heterogeneity=heterogeneity)
+ self.cutoff = cutoff
+ self.bandwidth = bandwidth
+ self.method = "RDD"
+ self.plot=plot
+
+ print("self.plot", self.plot)
+
+ def generate_data(self):
+
+ X = self.generate_covariates()
+ cols = [f"X{i+1}" for i in range(self.n_covars)]
+ df = pd.DataFrame(X, columns=cols)
+
+ df['running_X'] = np.random.normal(0, 2, size=self.n_observations) + self.cutoff
+
+
+ df['D'] = (df['running_X'] >= self.cutoff).astype(int)
+
+ intercept = 10
+ coeffs = np.random.normal(0, 0.1, size=self.n_covars)
+
+ ## slope of the line below the threshold
+ m_below = 1.5
+ ## slope of the line above the threshold
+ m_above = 0.8
+
+ df['running_centered'] = df['running_X'] - self.cutoff
+ # Use centered version for slope
+ df["Y"] = (intercept + self.true_effect * df['D'] + m_below * df['running_centered'] * (1 - df['D']) + \
+ m_above * df['running_centered'] * df['D'] + X @ coeffs + np.random.normal(0, 0.5, size=self.n_observations))
+
+ if self.plot:
+ plt.figure(figsize=(10, 6))
+ plt.scatter(df[df['D']==0]['running_X'], df[df['D']==0]['Y'],
+ alpha=0.5, label='Control', color='blue')
+ plt.scatter(df[df['D']==1]['running_X'], df[df['D']==1]['Y'],
+ alpha=0.5, label='Treatment', color='red')
+ plt.axvline(self.cutoff, color='black', linestyle='--', label='Cutoff')
+ plt.show()
+
+ self.data = df[[cols for cols in df.columns if cols != 'running_centered']]
+
+ return self.data
+
+
+ def test_data(self, print_=False):
+
+ if self.data is None:
+ raise ValueError("Data not generated yet.")
+ df = self.data.copy()
+ df['running_adj'] = df['running_X'].astype(float) - self.cutoff
+ df = df[np.abs(df['running_adj']) <= self.bandwidth].copy()
+ model = smf.ols('Y ~ D + running_adj + D:running_adj', data=df).fit()
+ est = model.params['D']
+ conf_int = model.conf_int().loc['D']
+
+ result = f"TRUE LATE: {self.true_effect:.3f}, ESTIMATED LATE: {est:.3f}, \
+ 95% CI: [{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
+ if print_:
+ print(result)
+ return result
+
+
+class DiDGenerator(DataGenerator):
+ """
+ Generate synthetic data for Difference-in-Differences (DiD) analysis
+
+ Additional Attributes:
+ 1. n_periods (int): number of time-periods
+ """
+
+ def __init__(self, n_observations, n_continuous_covars, n_binary_covars=2, n_periods=2,
+ mean=None, covar=None, true_effect=1.0, seed=111, heterogeneity=0):
+ super().__init__(n_observations, n_continuous_covars, n_binary_covars=n_binary_covars,
+ mean=mean, covar=covar, true_effect=true_effect,
+ seed=seed, heterogeneity=heterogeneity)
+
+ self.method = "DiD"
+ self.n_periods = n_periods
+
+ def canonical_did_model(self):
+ """
+ This is the classical DiD setting with two periods (pre and post treatment) and two groups (treatment and control)
+ """
+
+ ## fraction of observations that receives the treatment
+ frac_treated = np.random.uniform(0.35, 0.65)
+ n_treated = int(frac_treated * self.n_observations)
+ unit_ids = np.arange(self.n_observations)
+ treatment_status = np.zeros(self.n_observations, dtype=int)
+ treatment_status[:n_treated] = 1
+ np.random.shuffle(treatment_status)
+
+ X = self.generate_covariates()
+ cols = [f"X{i+1}" for i in range(self.n_covars)]
+ covar_df = pd.DataFrame(X, columns=cols)
+
+
+ vec = np.random.normal(0, 0.1, size=self.n_covars)
+
+ intercept = np.random.normal(50, 3)
+ treat_effect = np.random.normal(0, 1)
+ time_effect = np.random.normal(0, 1)
+
+ covar_term = X @ vec
+ pre_noise = np.random.normal(0, 1, self.n_observations)
+ pre_outcome = intercept + covar_term + pre_noise + treat_effect * treatment_status
+ pre_data = pd.DataFrame({'unit_id': unit_ids, 'post': 0, 'D': treatment_status,
+ 'Y': pre_outcome})
+ post_noise = np.random.normal(0, 1, self.n_observations)
+ post_outcome = (intercept + time_effect + covar_term + self.true_effect * treatment_status
+ + treat_effect * treatment_status + post_noise)
+
+ post_data = pd.DataFrame({'unit_id': unit_ids, 'post': 1, 'D': treatment_status,
+ 'Y': post_outcome})
+
+ df = pd.concat([pre_data, post_data], ignore_index=True)
+
+ df = df.merge(covar_df, left_on="unit_id", right_index=True)
+
+ return df[['unit_id', 'post', 'D', 'Y'] + cols]
+
+ def twfe_model(self):
+ """
+ Generate panel data for Two-Way Fixed Effects DiD model. This is a generalization of 2-period DiD for multi-year treatments
+ """
+
+ ## fraction of observations that receives the treatment
+ frac_treated = np.random.uniform(0.35, 0.65)
+ unit_ids = np.arange(1, self.n_observations + 1)
+ time_periods = np.arange(0, self.n_periods)
+
+ df = pd.DataFrame([(i, t) for i in unit_ids for t in time_periods],
+ columns=["unit", "time"])
+
+ X = self.generate_covariates()
+ for j in range(self.n_covars):
+ df[f"X{j+1}"] = np.repeat(X[:, j], self.n_periods)
+
+ ## Assign treatment timing
+ n_treated = int(frac_treated * self.n_observations)
+ treated_units = np.random.choice(unit_ids, size=n_treated, replace=False)
+ treatment_start = {unit: np.random.randint(1, self.n_periods) for unit in treated_units}
+
+ df["treat_post"] = df.apply(lambda row: int(row["unit"] in treatment_start and
+ row["time"] >= treatment_start[row["unit"]]),axis=1)
+
+ ## State fixed effects
+ unit_effects = dict(zip(unit_ids, np.random.normal(0, 1.0, self.n_observations)))
+ ## Time fixed effects
+ time_effects = dict(zip(time_periods, np.random.normal(0, 1, len(time_periods))))
+ df["unit_fe"] = df["unit"].map(unit_effects)
+ df["time_fe"] = df["time"].map(time_effects)
+
+ covar_effects = np.random.normal(0, 0.1, self.n_covars)
+ X_matrix = df[[f"X{j+1}" for j in range(self.n_covars)]].values
+ covar_term = X_matrix @ covar_effects
+ intercept = np.random.normal(50, 3)
+ noise = np.random.normal(0, 1, len(df))
+
+ df["Y"] = intercept + covar_term + df["unit_fe"] + df["time_fe"] + self.true_effect * df["treat_post"] + noise
+
+ final_df = df[["unit", "time", "treat_post", "Y"] + [f"X{j+1}" for j in range(self.n_covars)]]
+ final_df = final_df.rename(columns={"time": "year", "treat_post": "D"})
+
+ return final_df
+
+
+ def generate_data(self):
+
+ if self.n_periods == 2:
+ self.data = self.canonical_did_model()
+ else:
+ self.data = self.twfe_model()
+
+ return self.data
+
+
+ def test_data(self, print_=False):
+
+ estimated_att = None
+ if self.data is None:
+ raise ValueError("Data not generated yet.")
+ if self.n_periods == 2:
+ print("Testing canonical DiD model")
+ model = smf.ols('Y ~ D * post', data=self.data).fit()
+ estimated_att = model.params['D:post']
+ conf_int = model.conf_int().loc['D:post']
+ else:
+ print("Testing TWFE model")
+ model = smf.ols('Y ~ D + C(unit) + C(year)', data=self.data).fit()
+ estimated_att = model.params['D']
+ conf_int = model.conf_int().loc['D']
+
+ result = "TRUE ATT: {:.3f}, EMPRICAL ATT:{:.3f}\nCONFIDENCE INTERVAL:{}".format(
+ self.true_effect, estimated_att, conf_int)
+ if print_:
+ print(result)
+ return result
diff --git a/auto_causal/synthetic/io.py b/auto_causal/synthetic/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..184ee34aa8f05b712230ba8a40dee0991395de15
--- /dev/null
+++ b/auto_causal/synthetic/io.py
@@ -0,0 +1,508 @@
+## This file contains the functions that uses the classes in generator.py to generate the synthetic data
+
+from .generator import PSMGenerator, PSWGenerator, IVGenerator, RDDGenerator, RCTGenerator, DiDGenerator, MultiTreatRCTGenerator, FrontDoorGenerator
+import pandas as pd
+import numpy as np
+from pathlib import Path
+import logging
+import logging.config
+import json
+
+from .util import export_info
+
+Path("reproduce_results/logs").mkdir(parents=True, exist_ok=True)
+logging.config.fileConfig('reproduce_results/log_config.ini')
+
+def config_hyperparameters(base_seed, base_mean, base_cov_diag, max_cont, max_bin, n_obs,
+ max_obs, min_obs, max_treat=2, max_periods=5, cutoff_max=25):
+ """
+ configure the hyperparameters for the data generation process.
+
+ Args:
+ base_seed (int): Base seed for random number generation
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov_diag (np.ndarray): Base (diagonal) covariance matrix for the covariates
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ n_obs (int): Number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ min_obs (int): Minimum number of observations to generate
+ max_treat (int): Maximum number of treatment groups (default is 2)
+ max_periods (int): Maximum number of periods for DiD data (default is 5)
+ cutoff_max (int): Maximum value for the cutoff in RDD data (default is 25)
+
+ Returns:
+ dict: A dictionary containing the hyperparameters for data generation.
+ (str) attribute -> (int) value
+
+
+ """
+
+ base_cov_mat = np.diag(base_cov_diag)
+ np.random.seed(base_seed)
+ n_treat = np.random.randint(2, max_treat + 1)
+ true_effect = np.random.uniform(1, 10)
+ true_effect_vec = np.array([0] + [np.random.uniform(1, 10) for i in range(n_treat)])
+ n_continuous = np.random.randint(2, max_cont + 1)
+ n_binary = np.random.randint(2, max_bin)
+ n_observations = np.random.randint(min_obs, max_obs + 1)
+ if n_obs is not None:
+ n_observations = n_obs
+ n_periods = np.random.randint(3, max_periods + 1)
+ cutoff = np.random.randint(2, cutoff_max + 1)
+ mean_vec = base_mean[0:n_continuous]
+ cov_mat = base_cov_mat[0:n_continuous, 0:n_continuous]
+
+
+ param_dict = {'tau': true_effect, 'continuous': n_continuous, 'binary': n_binary,
+ 'obs': n_observations, 'mean': mean_vec, 'covar': cov_mat,
+ 'tau_vec':true_effect_vec, "treat":n_treat, "periods": n_periods,
+ 'cutoff':cutoff}
+
+ return param_dict
+
+
+def generate_observational_data(base_mean, base_cov, dset_size, max_cont, max_bin, min_obs,
+ max_obs, data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generate observational data using the PSMGenerator class.
+
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+
+ logger = logging.getLogger("observational_data_logger")
+ logger.info("Generating observational data")
+ metadata_dict = {}
+ base_seed = 31
+ for i in range(dset_size):
+ logger.info("Iteration: {}".format(i))
+ seed = (i + 1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin,
+ n_obs, max_obs, min_obs)
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}".format(
+ params['obs'], params['continuous'], params['binary']))
+ logger.info("true_effect: {}".format(params['tau']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = PSMGenerator(params['obs'], params['continuous'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect=params['tau'], seed=seed*2)
+ data = gen.generate_data()
+ name = "observational_data_{}.csv".format(i)
+ data_dict = {"true_effect": params['tau'], "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "observational"}
+ test_result = gen.test_data()
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+ export_info(metadata_dict, metadata_save_loc, "observational")
+
+
+def generate_rct_data(base_mean, base_cov, dset_size, max_cont, max_bin, min_obs, max_obs,
+ data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generates RCT data
+
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+
+ logger = logging.getLogger("rct_data_logger")
+ logger.info("Generating RCT data")
+ metadata_dict = {}
+ base_seed = 197
+ for i in range(dset_size):
+ logger.info("Iteration: {}".format(i))
+ seed = (i + 1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs)
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}".format(
+ params['obs'], params['continuous'], params['binary']))
+ logger.info("true_effect: {}".format(params['tau']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = RCTGenerator(params['obs'], params['continuous'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect=params['tau'], seed=seed)
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ data_dict = {"true_effect": params['tau'], "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "rct"}
+ name = "rct_data_{}.csv".format(i)
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+ export_info(metadata_dict, metadata_save_loc, "rct")
+
+
+def generate_multi_rct_data(base_mean, base_cov, dset_size, max_n_treat, max_cont, max_bin, min_obs, max_obs,
+ data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generate multi-treatment RCT data
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_n_treat (int): Maximum number of treatment groups
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+ logger = logging.getLogger("multi_rct_data_logger")
+ logger.info("Generating multi-treatment RCT data")
+ metadata_dict = {}
+ base_seed = 173
+ for i in range(dset_size):
+ logger.info("Iteration: {}".format(i))
+ seed = (i+1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs, max_treat=max_n_treat)
+ n_treat = params['treat']
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}, n_treat: {}".format(
+ params['obs'], params['continuous'], params['binary'], n_treat))
+ logger.info("true_effect: {}".format(params['tau_vec']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = MultiTreatRCTGenerator(params['obs'], params['continuous'], params['treat'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect_vec=params['tau_vec'], seed=seed,
+ true_effect=params['tau'])
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ data_dict = {"true_effect": list(params['tau_vec']), "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "multi_rct"}
+ name = "multi_rct_data_{}.csv".format(i)
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+ export_info(metadata_dict, metadata_save_loc, "multi_rct")
+
+
+def generate_frontdoor_data(base_mean, base_cov, dset_size, max_cont, max_bin, min_obs, max_obs,
+ data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generates front-door data
+
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Max number of continuous covariates
+ max_bin (int): Max number of binary covariates
+ min_obs (int): Minimum number of observations
+ max_obs (int): Maximum number of observations
+ data_save_loc (str): Folder to save generated CSV files
+ metadata_save_loc (str): Folder to save metadata JSON
+ n_obs (int or None): Fixed number of observations (if provided)
+ """
+
+ logger = logging.getLogger("frontdoor_data_logger")
+ logger.info("Generating Front-Door synthetic data")
+ metadata_dict = {}
+ base_seed = 311
+
+ for i in range(dset_size):
+ logger.info(f"Iteration: {i}")
+ seed = (i + 1) * base_seed
+
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs)
+
+ logger.info("n_observations: {}, n_continuous: {}, n_binary: {}".format(
+ params['obs'], params['continuous'], params['binary']))
+ logger.info("true_effect: {}".format(params['tau']))
+
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+
+ gen = FrontDoorGenerator(
+ n_observations=params['obs'],
+ n_continuous_covars=params['continuous'],
+ n_binary_covars=params['binary'],
+ mean=mean_vec,
+ covar=cov_mat,
+ true_effect=params['tau'],
+ seed=seed
+ )
+
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ logger.info("Test result: {}\n".format(test_result))
+
+ # Save CSV
+ filename = f"frontdoor_data_{i}.csv"
+ gen.save_data(data_save_loc, filename)
+
+ # Metadata
+ data_dict = {
+ "true_effect": params['tau'],
+ "observation": params['obs'],
+ "continuous": params['continuous'],
+ "binary": params['binary'],
+ "type": "frontdoor"
+ }
+ metadata_dict[filename] = data_dict
+
+ # Save metadata JSON
+ export_info(metadata_dict, metadata_save_loc, "frontdoor")
+
+
+
+def generate_canonical_did_data(base_mean, base_cov, dset_size, max_cont, max_bin, min_obs, max_obs,
+ data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generate canonical DiD data
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+ logger = logging.getLogger("did_data_logger")
+ logger.info("Generating canonical DiD data")
+ metadata_dict = {}
+ base_seed = 281
+ for i in range(dset_size):
+ logger.info("Iteration: {}".format(i))
+ seed = (i + 1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs)
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}".format(
+ params['obs'], params['continuous'], params['binary']))
+ logger.info("true_effect: {}".format(params['tau']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = DiDGenerator(params['obs'], params['continuous'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect=params['tau'], seed=seed)
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ data_dict = {"true_effect": params['tau'], "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "did_canonical"}
+ name = "did_canonical_data_{}.csv".format(i)
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+
+ export_info(metadata_dict, metadata_save_loc, "did_canonical")
+
+def generate_data_iv(base_mean, base_cov, dset_size, max_cont, max_bin, min_obs, max_obs,
+ data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generate IV data
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+
+ logger = logging.getLogger("iv_data_logger")
+ logger.info("Generating IV data")
+ metadata_dict = {}
+ base_seed = 343
+ for i in range(dset_size):
+ logger.info("Iteration: {}".format(i))
+ seed = (i + 1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs)
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}".format(
+ params['obs'], params['continuous'], params['binary']))
+ logger.info("true_effect: {}".format(params['tau']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = IVGenerator(params['obs'], params['continuous'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect=params['tau'], seed=seed)
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ data_dict = {"true_effect": params['tau'], "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "IV"}
+ name = "iv_data_{}.csv".format(i)
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+
+ export_info(metadata_dict, metadata_save_loc, "iv")
+
+def generate_twfe_did_data(base_mean, base_cov, dset_size, max_cont, max_bin, n_periods,
+ min_obs, max_obs, data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generate TWFE DiD data
+
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ n_periods (int): Number of periods for the DiD data
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+
+ logger = logging.getLogger("did_data_logger")
+ logger.info("Generating TWFE DiD data")
+ metadata_dict = {}
+ base_seed = 447
+ print("preiods: ", n_periods)
+ for i in range(dset_size):
+ logger.info("Iteration: {}".format(i))
+ seed = (i + 1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs, max_periods=n_periods)
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}, n_periods:{}".format(
+ params['obs'], params['continuous'], params['binary'], params['periods']))
+ logger.info("true_effect: {}".format(params['tau']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = DiDGenerator(params['obs'], params['continuous'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect=params['tau'], seed=seed,
+ n_periods=n_periods)
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ data_dict = {"true_effect": params['tau'], "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "did_twfe", "periods": params['periods']}
+ name = "did_twfe_data_{}.csv".format(i)
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+
+ export_info(metadata_dict, metadata_save_loc, "did_twfe")
+
+def generate_encouragement_data(base_mean, base_cov, dset_size, max_cont, max_bin, min_obs, max_obs,
+ data_save_loc, metadata_save_loc, n_obs=None):
+ """
+ Generate encouragement design data
+
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+
+ logger = logging.getLogger("iv_data_logger")
+ logger.info("Generating encouragement design data")
+ metadata_dict = {}
+ base_seed = 571
+ for i in range(dset_size):
+ logger.info("Iteration: {}".format(i))
+ seed = (i + 1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs)
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}".format(
+ params['obs'], params['continuous'], params['binary']))
+ logger.info("true_effect: {}".format(params['tau']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = IVGenerator(params['obs'], params['continuous'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect=params['tau'], seed=seed,
+ encouragement=True)
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ data_dict = {"true_effect": params['tau'], "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "encouragement"}
+ name = "iv_encouragement_data_{}.csv".format(i)
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+
+ export_info(metadata_dict, metadata_save_loc, "iv_encouragement")
+
+
+def generate_rdd_data(base_mean, base_cov, dset_size, max_cont, max_bin, max_cutoff,
+ min_obs, max_obs, data_save_loc, metadata_save_loc, n_obs=None):
+
+ """
+ Generates (sharp) RDD data
+
+ Args:
+ base_mean (np.ndarray): Base mean vector for the covariates
+ base_cov (np.ndarray): Base covariance matrix for the covariates
+ dset_size (int): Number of datasets to generate
+ max_cont (int): Maximum number of continuous covariates
+ max_bin (int): Maximum number of binary covariates
+ max_cutoff (int): Maximum value for the cutoff in RDD data
+ min_obs (int): Minimum number of observations to generate
+ max_obs (int): Maximum number of observations to generate
+ data_save_loc (str): Directory to save the generated data files
+ metadata_save_loc (str): Directory to save the metadata information
+ n_obs (int, None): number of observations. If None, it will be randomly
+ generated within the range of min_obs and max_obs.
+ """
+
+ logger = logging.getLogger("rdd_data_logger")
+ logger.info("Generating RDD data")
+ metadata_dict = {}
+ base_seed = 683
+ for i in range(dset_size):
+ logger.info("Iteration:{}".format(i))
+ seed = (i + 1) * base_seed
+ params = config_hyperparameters(seed, base_mean, base_cov, max_cont, max_bin, n_obs,
+ max_obs, min_obs, cutoff_max=max_cutoff)
+ logger.info("n_observations:{}, n_continuous: {}, n_binary: {}, cutoff:{}".format(
+ params['obs'], params['continuous'], params['binary'], params['cutoff']))
+ logger.info("true_effect: {}".format(params['tau']))
+ mean_vec = params['mean']
+ cov_mat = params['covar']
+ gen = RDDGenerator(params['obs'], params['continuous'], n_binary_covars=params['binary'],
+ mean=mean_vec, covar=cov_mat, true_effect=params['tau'], seed=seed,
+ cutoff=params['cutoff'], plot=True)
+
+ data = gen.generate_data()
+ test_result = gen.test_data()
+ data_dict = {"true_effect": params['tau'], "observation": params['obs'], "continuous": params['continuous'],
+ "binary": params['binary'], "type": "rdd", 'cutoff': params['cutoff']}
+ name = "rdd_data_{}.csv".format(i)
+ logger.info("Test result: {}\n".format(test_result))
+ metadata_dict[name] = data_dict
+ gen.save_data(data_save_loc, name)
+
+ export_info(metadata_dict, metadata_save_loc, "rdd")
diff --git a/auto_causal/synthetic/prompts.py b/auto_causal/synthetic/prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d5bc4e7c19aefc7bf1729122fbaceb0df0951f
--- /dev/null
+++ b/auto_causal/synthetic/prompts.py
@@ -0,0 +1,144 @@
+## This file contains the functions that can be used to create prompts for generating synthetic data contexts.
+
+def generate_data_summary(df, n_cont_vars, n_bin_vars, method, cutoff=None) -> str:
+ """
+ Generate a summary of the input dataset. The summary includes information about column headings
+ for continuuous, binary, treatment, and outcome variables. Additionally, it also includes information on the method
+ used to generate the dataset and the basic statistical summary.
+
+ Args:
+ df (pd.DataFrame): The input dataset.
+ n_cont_vars (int): Number of continuous variables in the dataset
+ n_bin_vars (int): Number of binary variables in the dataset
+ method (str): The method used to generate the dataset
+ cutff (float, None): The cutoff value for RDD data
+
+ Returns:
+ str: Summary of the (raw) dataset.
+
+ """
+
+ continuous_vars = [f"X{i}" for i in range(1, n_cont_vars + 1)]
+ binary_vars = [f"X{i}" for i in range(n_cont_vars + 1, n_cont_vars + n_bin_vars + 1)]
+
+ information = "The dataset contains the following **continuous covariates**: " + ", ".join(continuous_vars) + ".\n"
+ information += "The dataset contains the following **binary covariates**: " + ", ".join(binary_vars) + ".\n"
+ information += "The **outcome variable** is Y.\n"
+ information += "The **treatment variable** is D.\n"
+
+ if method == "encouragement":
+ information += "This is an encouragement design where Z is the instrument, i.e., the \
+ , the initial treatment assignment \ n"
+ elif method == "IV":
+ information += "This is an IV design where Z is the instrument \n"
+ elif method == "rdd":
+ information += "The running variable is running_X, and the cutoff is {}\n".format(cutoff)
+ elif method == "did_twfe":
+ information += "This is a staggered Difference in Difference where D indicates whether or not the unit is treated \
+ at time t. Similarly, year denotes the time at which the data was measured.\n"
+ elif method == "did_canonical":
+ information += "This is a canonical Difference in Difference where D indicates whether or not the unit is treated \
+ at time t. Similarly, post is a binary variable indicating post / pre-intervention time points \
+ , post = 1 indicates post-intervention time points.\n"
+
+ information += "Here is the statistical summary of the variables: \n " + str(df.describe(include='all')) + "\n"
+
+ return information
+
+
+def create_prompt(summary, method, domain, history):
+
+ """
+ Creates a prompt for the OpenAI API to generate a context for the given dataset
+
+ Args:
+ summary (str): Summary of the dataset
+ method (str): The method used to generate the dataset
+ domain (str): The domain of the dataset
+ history (str): Previous contexts that have been used. We use this to avoid overlap in contexts
+
+ """
+
+ method_names = {"encouragement": "Encouragement Design", "did_twfe": "Difference in Differences with Two-Way Fixed Effects",
+ "did_canonical": "Canonical Difference in Differences", "IV": "Instrumental Variable",
+ "multi_rct": "Multi-Treatment Randomized Control Trial", "rdd": "Regression Discontinuity Design",
+ "observational": "Observational", "rct": "Randomized Control Trial", "frontdoor": "Front-Door Causal Inference"}
+
+ domain_guides = {
+ "education": "Education data often includes student performance, school-level features, socioeconomic background, and intervention types like tutoring or online classes.",
+ "healthcare": "Healthcare data may include treatments, diagnoses, hospital visits, recovery outcomes, or demographic details.",
+ "labor": "Labor datasets typically include income, education, job type, employment history, and training programs.",
+ "policy": "Policy evaluation data may track program participation, regional differences, economic impact, and public outcomes like housing, safety, or benefits."
+ }
+
+ prompt = f"""
+You are a helpful assistant generating realistic, domain-specific contexts for synthetic datasets.
+
+The current dataset is designed for **{method_names[method]}** studies in the **{domain}** domain.
+
+### Dataset Summary
+{summary}
+
+### Previously Used Contexts (avoid duplication)
+{history}
+
+### Domain-Specific Guidance
+{domain_guides.get(domain, '')}
+
+
+---
+
+### Your Tasks:
+1. Propose a **realistic real-world scenario** that fits a {method_names[method]} study in the {domain} domain. Mention whether the data was collected from a randomized trial, policy rollout, or real-world observation.
+2a. Assign **realistic and concise variable names** in snake_case. Map original variable names like `"X1"` to names like `"education_years"`.
+2b. Provide a **one-line natural-language description for each variable** (e.g., `education_years: total years of formal schooling completed by the individual.`). Use newline-separated key-value format.
+3. Write a **paragraph** describing the dataset's background: who collected it, what was studied, why, and how.
+4. Write a **natural language causal question** the dataset could answer. The question should:
+ - Relate implicitly to the treatment and outcome
+ - Avoid any statistical or causal terminology
+ - Avoid naming variables directly
+ - Feel like it belongs in a news article or report
+5. Write a **1-2 sentence summary** capturing the dataset's overall intent and contents.
+
+---
+
+
+ Return your output as a JSON object with the following keys:
+ - "variable_labels": {{ "X1": "education_years", ... }}
+ - "description": ""
+ - "question": ""
+ - "summary":
+ - "domain": ""
+
+ Return only a valid JSON object. Do not include any markdown, explanations, or extra text.
+ """
+
+
+
+ return prompt
+
+def filter_question(question):
+ """
+ Filter the question to remove explicit mentions of variables.
+
+ Args:
+ question (str): The original causal query
+
+ Returns:
+ str: The filtered causal query
+ """
+
+ prompt = """
+ You are a helpful assistant. Help me filter this causal query.
+
+ The query is: {}
+ The query should not provide information on what variables one needs to consider in course of causal analysis.
+ For example,
+ Bad question: "What is the effect of the training program on job outcomes considering education and experience?"
+ Good question: "What is the effect of the training program on job outcomes?"
+
+ If the question is already filtered, return it as is.
+ Return only the filtered query. Do not say anything else.
+ """
+
+ return prompt.format(question)
diff --git a/auto_causal/synthetic/util.py b/auto_causal/synthetic/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef77bb32d0e7be07339378cb6bfa4ecc25238821
--- /dev/null
+++ b/auto_causal/synthetic/util.py
@@ -0,0 +1,10 @@
+import json
+from pathlib import Path
+
+
+def export_info(info, folder, name):
+ Path(folder).mkdir(parents=True, exist_ok=True)
+ if ".json" not in name:
+ name = name + ".json"
+ with open(f"{folder}/{name}", "w") as f:
+ json.dump(info, f, indent=4)
diff --git a/auto_causal/tools/__init__.py b/auto_causal/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b992048ab94ab9cdc8075648ecd6586f382935
--- /dev/null
+++ b/auto_causal/tools/__init__.py
@@ -0,0 +1,32 @@
+"""
+Auto Causal tools package.
+
+This package contains the tool wrappers for the auto_causal LangChain agent,
+each providing an interface to a specific component.
+"""
+
+from auto_causal.tools.input_parser_tool import input_parser_tool
+from auto_causal.tools.dataset_analyzer_tool import dataset_analyzer_tool
+from auto_causal.tools.query_interpreter_tool import query_interpreter_tool
+from auto_causal.tools.method_selector_tool import method_selector_tool
+from auto_causal.tools.method_validator_tool import method_validator_tool
+from auto_causal.tools.method_executor_tool import method_executor_tool
+from auto_causal.tools.explanation_generator_tool import explanation_generator_tool
+from auto_causal.tools.output_formatter_tool import output_formatter_tool
+
+# Removed imports for DataAnalyzer, DecisionTreeEngine, MethodImplementer
+# These are components, not tools, or have been removed.
+# from causalscientist.auto_causal.tools.data_analyzer import DataAnalyzer
+# from causalscientist.auto_causal.tools.decision_tree import DecisionTreeEngine
+# from causalscientist.auto_causal.tools.method_implementer import MethodImplementer
+
+__all__ = [
+ "input_parser_tool",
+ "dataset_analyzer_tool",
+ "query_interpreter_tool",
+ "method_selector_tool",
+ "method_validator_tool",
+ "method_executor_tool",
+ "explanation_generator_tool",
+ "output_formatter_tool"
+]
diff --git a/auto_causal/tools/data_analyzer.py b/auto_causal/tools/data_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec58b33a6ad5ad133821cb90e345d6a41859102c
--- /dev/null
+++ b/auto_causal/tools/data_analyzer.py
@@ -0,0 +1,212 @@
+"""
+Data Analyzer class for causal inference pipelines.
+
+This module provides the DataAnalyzer class for analyzing datasets
+and extracting relevant information for causal inference.
+"""
+
+import pandas as pd
+import numpy as np
+from typing import Dict, List, Any, Optional
+
+
+class DataAnalyzer:
+ """
+ Data analyzer for causal inference datasets.
+
+ This class provides methods for analyzing datasets to extract
+ relevant information for causal inference, such as variables,
+ relationships, and temporal structures.
+ """
+
+ def __init__(self, verbose=False):
+ """
+ Initialize the data analyzer.
+
+ Args:
+ verbose: Whether to print verbose information
+ """
+ self.verbose = verbose
+
+ def analyze_dataset(self, dataset_path: str) -> Dict[str, Any]:
+ """
+ Analyze a dataset and extract relevant information.
+
+ Args:
+ dataset_path: Path to the dataset file
+
+ Returns:
+ Dictionary with dataset analysis results
+ """
+ try:
+ # Load the dataset
+ df = pd.read_csv(dataset_path)
+
+ # Get basic statistics
+ n_rows, n_cols = df.shape
+ columns = list(df.columns)
+
+ # Get column types and categories
+ column_types = {col: str(df[col].dtype) for col in columns}
+ column_categories = self._categorize_columns(df)
+
+ # Check for temporal structure
+ temporal_structure = self._check_temporal_structure(df)
+
+ # Identify potential confounders
+ variable_relationships = self._identify_relationships(df)
+
+ # Look for potential instruments
+ potential_instruments = self._identify_potential_instruments(df)
+
+ # Check for discontinuities
+ discontinuities = self._check_discontinuities(df)
+
+ # Construct the analysis result
+ analysis = {
+ "filepath": dataset_path,
+ "n_rows": n_rows,
+ "n_cols": n_cols,
+ "columns": columns,
+ "column_types": column_types,
+ "column_categories": column_categories,
+ "temporal_structure": temporal_structure,
+ "variable_relationships": variable_relationships,
+ "potential_instruments": potential_instruments,
+ "discontinuities": discontinuities
+ }
+
+ if self.verbose:
+ print(f"Dataset analysis completed: {n_rows} rows, {n_cols} columns")
+
+ return analysis
+
+ except Exception as e:
+ if self.verbose:
+ print(f"Error analyzing dataset: {str(e)}")
+
+ return {
+ "error": str(e),
+ "filepath": dataset_path,
+ "n_rows": 0,
+ "n_cols": 0,
+ "columns": [],
+ "column_types": {},
+ "column_categories": {},
+ "temporal_structure": {"has_temporal_structure": False},
+ "variable_relationships": {"potential_confounders": []},
+ "potential_instruments": [],
+ "discontinuities": {"has_discontinuities": False}
+ }
+
+ def _categorize_columns(self, df: pd.DataFrame) -> Dict[str, str]:
+ """
+ Categorize columns by data type.
+
+ Args:
+ df: Pandas DataFrame
+
+ Returns:
+ Dictionary mapping column names to categories
+ """
+ categories = {}
+ for col in df.columns:
+ if df[col].dtype == 'bool':
+ categories[col] = 'binary'
+ elif pd.api.types.is_numeric_dtype(df[col]):
+ if len(df[col].unique()) <= 2:
+ categories[col] = 'binary'
+ else:
+ categories[col] = 'continuous'
+ else:
+ unique_values = df[col].nunique()
+ if unique_values <= 2:
+ categories[col] = 'binary'
+ elif unique_values <= 10:
+ categories[col] = 'categorical'
+ else:
+ categories[col] = 'high_cardinality'
+
+ return categories
+
+ def _check_temporal_structure(self, df: pd.DataFrame) -> Dict[str, Any]:
+ """
+ Check for temporal structure in the dataset.
+
+ Args:
+ df: Pandas DataFrame
+
+ Returns:
+ Dictionary with temporal structure information
+ """
+ # Look for date/time columns
+ date_cols = [col for col in df.columns if
+ any(keyword in col.lower() for keyword in
+ ['date', 'time', 'year', 'month', 'day', 'period'])]
+
+ # Check for panel data structure
+ id_cols = [col for col in df.columns if
+ any(keyword in col.lower() for keyword in
+ ['id', 'group', 'entity', 'unit'])]
+
+ return {
+ "has_temporal_structure": len(date_cols) > 0,
+ "is_panel_data": len(date_cols) > 0 and len(id_cols) > 0,
+ "time_variables": date_cols,
+ "id_variables": id_cols
+ }
+
+ def _identify_relationships(self, df: pd.DataFrame) -> Dict[str, List[str]]:
+ """
+ Identify potential variable relationships.
+
+ Args:
+ df: Pandas DataFrame
+
+ Returns:
+ Dictionary with relationship information
+ """
+ # This is a simplified implementation
+ # A real implementation would use statistical tests or causal discovery
+
+ return {
+ "potential_confounders": []
+ }
+
+ def _identify_potential_instruments(self, df: pd.DataFrame) -> List[str]:
+ """
+ Identify potential instrumental variables.
+
+ Args:
+ df: Pandas DataFrame
+
+ Returns:
+ List of potential instrumental variables
+ """
+ # This is a simplified implementation
+ # A real implementation would use statistical tests
+
+ # Look for variables that might be instruments based on naming
+ potential_instruments = [col for col in df.columns if
+ any(keyword in col.lower() for keyword in
+ ['instrument', 'random', 'assignment', 'iv'])]
+
+ return potential_instruments
+
+ def _check_discontinuities(self, df: pd.DataFrame) -> Dict[str, Any]:
+ """
+ Check for potential discontinuities for RDD.
+
+ Args:
+ df: Pandas DataFrame
+
+ Returns:
+ Dictionary with discontinuity information
+ """
+ # This is a simplified implementation
+ # A real implementation would use statistical tests
+
+ return {
+ "has_discontinuities": False,
+ "potential_running_variables": []
+ }
\ No newline at end of file
diff --git a/auto_causal/tools/dataset_analyzer_tool.py b/auto_causal/tools/dataset_analyzer_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e282ef4254864859805bc7cfaa05ead6d5b763a
--- /dev/null
+++ b/auto_causal/tools/dataset_analyzer_tool.py
@@ -0,0 +1,112 @@
+"""
+Tool for analyzing datasets for causal inference.
+
+This module provides a LangChain tool for analyzing datasets to detect
+characteristics relevant for causal inference, such as temporal structure,
+potential instrumental variables, and variable relationships.
+"""
+
+from typing import Dict, Any, Optional
+from langchain.tools import tool
+import logging
+
+from auto_causal.components.dataset_analyzer import analyze_dataset
+from auto_causal.components.state_manager import create_workflow_state_update
+from langchain_core.language_models import BaseChatModel
+
+from auto_causal.config import get_llm_client
+
+# Import the required Pydantic models
+from auto_causal.models import DatasetAnalysis, DatasetAnalyzerOutput
+from auto_causal import models
+
+logger = logging.getLogger(__name__)
+
+
+@tool
+def dataset_analyzer_tool(dataset_path: str,
+ dataset_description: Optional[str] = None,
+ original_query: Optional[str] = None) -> DatasetAnalyzerOutput:
+ """
+ Analyze dataset to identify important characteristics for causal inference.
+
+ This tool loads the dataset, calculates summary statistics, checks for temporal
+ structure, identifies potential treatments/outcomes/instruments, and assesses
+ variable relationships relevant for selecting a causal method.
+
+ Args:
+ dataset_path: Path to the dataset file.
+ dataset_description: Optional description string from input.
+ llm: Optional LLM client for enhanced analysis.
+
+ Returns:
+ A Pydantic model containing the structured dataset analysis results and workflow state.
+ """
+ logger.info(f"Running dataset_analyzer_tool on path: {dataset_path}")
+ # Call the component function with the LLM if available
+ llm = get_llm_client()
+
+ try:
+ # Call the component function
+ analysis_dict = analyze_dataset(dataset_path, llm_client=llm, dataset_description=dataset_description, original_query=original_query)
+
+ # Check for errors returned explicitly by the component
+ if isinstance(analysis_dict, dict) and "error" in analysis_dict:
+ logger.error(f"Dataset analysis component failed: {analysis_dict['error']}")
+ raise ValueError(analysis_dict['error'])
+
+ # Validate and structure the analysis using Pydantic
+ # This assumes analyze_dataset returns a dict compatible with DatasetAnalysis
+ # Handle potential missing keys or type mismatches gracefully
+ analysis_results_model = DatasetAnalysis(**analysis_dict)
+
+ except Exception as e:
+ logger.error(f"Error during dataset analysis or Pydantic model creation: {e}", exc_info=True)
+ error_state = create_workflow_state_update(
+ current_step="data_analysis",
+ step_completed_flag=False,
+ next_tool="dataset_analyzer_tool", # Retry or error handler?
+ next_step_reason=f"Dataset analysis failed: {e}"
+ )
+
+ minimal_info = models.DatasetInfo(num_rows=0, num_columns=0, file_path=dataset_path, file_name="unknown")
+ empty_temporal = models.TemporalStructure(has_temporal_structure=False, temporal_columns=[], is_panel_data=False)
+ error_analysis = models.DatasetAnalysis(
+ dataset_info=minimal_info,
+ columns=[],
+ potential_treatments=[],
+ potential_outcomes=[],
+ temporal_structure_detected=False,
+ panel_data_detected=False,
+ potential_instruments_detected=False,
+ discontinuities_detected=False,
+ temporal_structure=empty_temporal,
+ sample_size=0,
+ num_covariates_estimate=0
+ )
+ return DatasetAnalyzerOutput(
+ analysis_results=error_analysis,
+ dataset_description=dataset_description,
+ workflow_state=error_state.get('workflow_state', {})
+ )
+
+ # Create workflow state update for success
+ workflow_update = create_workflow_state_update(
+ current_step="data_analysis",
+ step_completed_flag="dataset_analyzed",
+ next_tool="query_interpreter_tool",
+ next_step_reason="Now we need to map query concepts to actual dataset variables"
+ )
+
+ # Construct the final Pydantic output object
+ output = DatasetAnalyzerOutput(
+ analysis_results=analysis_results_model,
+ dataset_description=dataset_description,
+ dataset_path=dataset_path,
+ workflow_state=workflow_update.get('workflow_state', {})
+ )
+
+ # print(output)
+
+ logger.info("dataset_analyzer_tool finished successfully.")
+ return output
\ No newline at end of file
diff --git a/auto_causal/tools/explanation_generator_tool.py b/auto_causal/tools/explanation_generator_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..b75acb6b4d18f08bfa4a0fa1d0f17856be13ba3e
--- /dev/null
+++ b/auto_causal/tools/explanation_generator_tool.py
@@ -0,0 +1,149 @@
+"""
+Explanation generator tool for causal inference methods.
+
+This tool generates explanations for the selected causal inference method,
+including what the method does, its assumptions, and how it will be applied.
+"""
+
+from typing import Dict, Any, Optional, List, Union
+from langchain.tools import tool
+import logging
+
+from auto_causal.components.explanation_generator import generate_explanation
+from auto_causal.components.state_manager import create_workflow_state_update
+from auto_causal.config import get_llm_client
+
+# Import shared models from central location
+from auto_causal.models import (
+ Variables,
+ TemporalStructure, # Needed indirectly by DatasetAnalysis
+ DatasetInfo, # Needed indirectly by DatasetAnalysis
+ DatasetAnalysis,
+ MethodInfo,
+ ExplainerInput # Keep for type hinting arguments
+)
+
+logger = logging.getLogger(__name__)
+
+# --- Removed local Pydantic definitions ---
+# class Variables(BaseModel): ...
+# class TemporalStructure(BaseModel): ...
+# class DatasetInfo(BaseModel): ...
+# class DatasetAnalysis(BaseModel): ...
+# class MethodInfo(BaseModel): ...
+# class ExplainerInput(BaseModel): ...
+
+# --- Tool Definition ---
+@tool(args_schema=ExplainerInput)
+# Change signature to accept individual arguments
+def explanation_generator_tool(
+ method_info: MethodInfo,
+ variables: Variables,
+ results: Dict[str, Any],
+ dataset_analysis: DatasetAnalysis,
+ validation_info: Optional[Dict[str, Any]] = None,
+ dataset_description: Optional[str] = None,
+ original_query: Optional[str] = None # Get original query if passed
+) -> Dict[str, Any]:
+ """
+ Generate a single comprehensive explanation string using structured Pydantic input.
+
+ Args:
+ method_info: Pydantic model with method details.
+ variables: Pydantic model with identified variables.
+ results: Dictionary containing numerical results from execution.
+ dataset_analysis: Pydantic model with dataset analysis results.
+ validation_info: Optional dictionary with validation results.
+ dataset_description: Optional string description of the dataset.
+ original_query: Optional original user query string.
+
+ Returns:
+ Dictionary with the final explanation text, context, and workflow state.
+ """
+ logger.info("Running explainer_tool with direct arguments...")
+
+ # Use arguments directly, dump models to dicts if needed by component
+ method_info_dict = method_info.model_dump()
+ print('------------------------')
+ print(method_info_dict)
+ print('------------------------')
+ validation_result_dict = validation_info # Already dict or None
+ variables_dict = variables.model_dump()
+ # results is already a dict
+ dataset_analysis_dict = dataset_analysis.model_dump()
+ # dataset_description is already str or None
+
+ # Include original_query in variables_dict if the component expects it there
+ if original_query:
+ variables_dict['original_query'] = original_query
+
+ # Get LLM instance if needed by generate_explanation
+ llm_instance = None
+ try:
+ llm_instance = get_llm_client()
+ except Exception as e:
+ logger.warning(f"Could not get LLM client for explainer: {e}")
+
+ # Call component to generate the single explanation string
+ try:
+ explanation_dict = generate_explanation(
+ method_info=method_info_dict,
+ validation_result=validation_result_dict,
+ variables=variables_dict,
+ results=results, # Pass results dict directly
+ dataset_analysis=dataset_analysis_dict,
+ dataset_description=dataset_description,
+ llm=llm_instance # Pass LLM if component uses it
+ )
+ if not isinstance(explanation_dict, dict):
+ raise TypeError(f"generate_explanation component did not return a dict. Got: {type(explanation_dict)}")
+
+ except Exception as e:
+ logger.error(f"Error during generate_explanation execution: {e}", exc_info=True)
+ # Provide missing args for the error state update
+ workflow_update = create_workflow_state_update(
+ current_step="result_explanation",
+ step_completed_flag=False,
+ error=f"Component failed: {e}",
+ next_tool="explanation_generator_tool", # Indicate failed tool
+ next_step_reason=f"Explanation generation component failed: {e}" # Provide reason
+ )
+ # Return structure consistent with success case, but with error info
+ return {
+ "error": f"Explanation generation component failed: {e}",
+ # Pass necessary context for potential retry or next step
+ "query": original_query or "N/A",
+ "method": method_info_dict.get('selected_method', "N/A"),
+ "results": results, # Include results even if explanation failed
+ "explanation": {"error": str(e)}, # Include error in explanation part
+ "dataset_analysis": dataset_analysis_dict,
+ "dataset_description": dataset_description,
+ **workflow_update.get('workflow_state', {})
+ }
+
+ # Create workflow state update
+ workflow_update = create_workflow_state_update(
+ current_step="result_explanation",
+ step_completed_flag="results_explained",
+ next_tool="output_formatter_tool", # Step 8: Format output
+ next_step_reason="Finally, we need to format the output for presentation"
+ )
+
+ # Prepare result dict for the next tool (formatter)
+ result_for_formatter = {
+ # Pass the necessary pieces for the formatter
+ "query": original_query or "N/A", # Use original_query directly
+ "method": method_info_dict.get('selected_method', 'N/A'),
+ "results": results, # Pass the numerical results directly
+ "explanation": explanation_dict, # Pass the structured explanation
+ # Avoid passing full analysis if not needed by formatter? Check formatter needs.
+ # For now, keep them.
+ "dataset_analysis": dataset_analysis_dict,
+ "dataset_description": dataset_description
+ }
+
+ # Add workflow state to the result
+ result_for_formatter.update(workflow_update)
+
+ logger.info("explanation_generator_tool finished successfully.")
+ return result_for_formatter
\ No newline at end of file
diff --git a/auto_causal/tools/input_parser_tool.py b/auto_causal/tools/input_parser_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..79792fe33e20b1748079772fee82f816fc9f61e6
--- /dev/null
+++ b/auto_causal/tools/input_parser_tool.py
@@ -0,0 +1,103 @@
+"""
+Tool for parsing causal inference queries.
+
+This module provides a LangChain tool for parsing causal inference queries,
+extracting key elements, and guiding the workflow to the next step.
+"""
+
+import logging
+import re
+from typing import Dict, Any, Optional
+from langchain_core.tools import tool
+
+from auto_causal.components.input_parser import parse_input
+from auto_causal.config import get_llm_client
+from auto_causal.components.state_manager import create_workflow_state_update
+import json
+logger = logging.getLogger(__name__)
+
+@tool
+def input_parser_tool(input_text: str) -> Dict[str, Any]:
+ """
+ Parse the user's initial input text to extract query, dataset path, and description.
+
+ This tool uses regex to find structured information within the input text
+ and then leverages an LLM for more complex NLP tasks on the extracted query.
+
+ Args:
+ input_text: The combined initial input string from the user/system.
+
+ Returns:
+ Dict containing parsed query information, path, description, and workflow state.
+ """
+ logger.info(f"Running input_parser_tool on input: '{input_text[:100]}...'")
+
+ # --- Extract structured info using Regex ---
+ query = None
+ dataset_path = None
+ dataset_description = None
+
+ query_match = re.search(r"My question is: (.*?)\n", input_text, re.IGNORECASE)
+ if query_match:
+ query = query_match.group(1).strip()
+
+ path_match = re.search(r"The dataset is located at: (.*?)\n", input_text, re.IGNORECASE)
+ if path_match:
+ dataset_path = path_match.group(1).strip()
+
+ # Use re.search to find the description potentially anywhere after its label
+ desc_match = re.search(r"Dataset Description: (.*)", input_text, re.DOTALL | re.IGNORECASE)
+ if desc_match:
+ # Strip leading/trailing whitespace/newlines from the captured group
+ dataset_description = desc_match.group(1).strip()
+
+ if not query:
+ logger.warning("Could not extract query from input_text using regex. Attempting full text as query.")
+ # Fallback: This is risky if input_text contains boilerplate
+ query = input_text
+
+ logger.info(f"Extracted - Query: '{query[:50]}...', Path: '{dataset_path}', Desc: '{str(dataset_description)[:50]}...'")
+
+ # --- Get LLM and Parse Query ---
+ try:
+ llm_instance = get_llm_client()
+ except Exception as e:
+ logger.error(f"Failed to initialize LLM for input_parser_tool: {e}")
+ return {"error": f"LLM Initialization failed: {e}", "workflow_state": {}}
+
+ # Call the component function to parse the extracted query
+ try:
+ parsed_info = parse_input(
+ query=query,
+ dataset_path_arg=dataset_path, # Use extracted path
+ dataset_info=None, # This arg seems unused by parse_input now
+ llm=llm_instance
+ )
+ except Exception as e:
+ logger.error(f"Error during parse_input execution: {e}", exc_info=True)
+ return {"error": f"Input parsing failed: {e}", "workflow_state": {}}
+
+ # Create workflow state update
+ workflow_update = create_workflow_state_update(
+ current_step="input_processing",
+ step_completed_flag="query_parsed",
+ next_tool="dataset_analyzer_tool",
+ next_step_reason="Now that we understand the query, we need to analyze the dataset structure"
+ )
+
+ # Combine results with workflow state
+ result = {
+ "original_query": parsed_info.get("original_query", query), # Fallback to regex query
+ "dataset_path": parsed_info.get("dataset_path") or dataset_path, # Use extracted if component missed it
+ "query_type": parsed_info.get("query_type"),
+ "extracted_variables": parsed_info.get("extracted_variables", {}),
+ "constraints": parsed_info.get("constraints", []),
+ # Pass dataset_description along
+ "dataset_description": dataset_description
+ }
+ print('before workflow: ', result)
+ # Add workflow state to the result
+ result.update(workflow_update)
+ print('after workflow: ', result)
+ logger.info("input_parser_tool finished successfully.")
+ return result
\ No newline at end of file
diff --git a/auto_causal/tools/method_executor_tool.py b/auto_causal/tools/method_executor_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..92edefc5f9d8192a63820c3117ccfd39ddaaaae7
--- /dev/null
+++ b/auto_causal/tools/method_executor_tool.py
@@ -0,0 +1,216 @@
+"""
+Method Executor Tool for the causal inference agent.
+
+Executes the selected causal inference method using its implementation function.
+"""
+
+import pandas as pd
+from typing import Dict, Any, Optional, List, Union
+from langchain.tools import tool
+import traceback # For error logging
+import logging # Add logging
+
+# Import the mapping and potentially preprocessing utils
+from auto_causal.methods import METHOD_MAPPING
+from auto_causal.methods.utils import preprocess_data # Assuming preprocess exists
+from auto_causal.components.state_manager import create_workflow_state_update
+from auto_causal.config import get_llm_client # IMPORT LLM Client Factory
+
+# Import shared models from central location
+from auto_causal.models import (
+ Variables,
+ TemporalStructure, # Needed indirectly by DatasetAnalysis
+ DatasetInfo, # Needed indirectly by DatasetAnalysis
+ DatasetAnalysis,
+ MethodExecutorInput
+)
+
+# Add this module-level variable, typically near imports or at the top
+CURRENT_OUTPUT_LOG_FILE = None
+
+logger = logging.getLogger(__name__)
+
+@tool
+def method_executor_tool(inputs: MethodExecutorInput, original_query: Optional[str] = None) -> Dict[str, Any]: # Use Pydantic Input
+ '''Execute the selected causal inference method function using structured input.
+
+ Args:
+ inputs: Pydantic model containing method, variables, dataset_path,
+ dataset_analysis, and dataset_description.
+
+ Returns:
+ Dict with numerical results, context for next step, and workflow state.
+ '''
+ # Access data from input model
+ method = inputs.method
+ variables_dict = inputs.variables.model_dump()
+ dataset_path = inputs.dataset_path
+ dataset_analysis_dict = inputs.dataset_analysis.model_dump()
+ dataset_description_str = inputs.dataset_description
+ validation_info = inputs.validation_info # Can be passed if needed
+
+ logger.info(f"Executing method: {method}")
+
+ try:
+ # --- Get LLM Instance ---
+ llm_instance = None
+ try:
+ llm_instance = get_llm_client()
+ except Exception as llm_e:
+ logger.warning(f"Could not get LLM client in method_executor_tool: {llm_e}. LLM-dependent features in method will be disabled.")
+
+ # 1. Load Data
+ if not dataset_path:
+ raise ValueError("Dataset path is missing.")
+ df = pd.read_csv(dataset_path)
+
+ # 2. Extract Key Variables needed by estimate_func signature
+ treatment = variables_dict.get("treatment_variable")
+ outcome = variables_dict.get("outcome_variable")
+ covariates = variables_dict.get("covariates", [])
+ query_str = original_query if original_query is not None else inputs.original_query
+
+ if not all([treatment, outcome]):
+ raise ValueError("Treatment or Outcome variable not found in 'variables' dict.")
+
+ # 3. Preprocess Data
+ required_cols_for_method = [treatment, outcome] + covariates
+ # Add method-specific required vars from the variables_dict
+ if method == "instrumental_variable" and variables_dict.get("instrument_variable"):
+ required_cols_for_method.append(variables_dict["instrument_variable"])
+ elif method == "regression_discontinuity" and variables_dict.get("running_variable"):
+ required_cols_for_method.append(variables_dict["running_variable"])
+
+ missing_df_cols = [col for col in required_cols_for_method if col not in df.columns]
+ if missing_df_cols:
+ raise ValueError(f"Dataset at {dataset_path} is missing required columns for method '{method}': {missing_df_cols}")
+
+ df_processed, updated_treatment, updated_outcome, updated_covariates, column_mappings = \
+ preprocess_data(df, treatment, outcome, covariates, verbose=False)
+
+ # 4. Get the correct method execution function
+ if method not in METHOD_MAPPING:
+ raise ValueError(f"Method '{method}' not found in METHOD_MAPPING.")
+ estimate_func = METHOD_MAPPING[method]
+
+ # 5. Execute the method
+ # Pass only necessary args from variables_dict as kwargs
+ # (e.g., instrument_variable, running_variable, cutoff_value, etc.)
+ # Avoid passing the entire variables_dict as estimate_func expects specific args
+ kwargs_for_method = {}
+ for key in ["instrument_variable", "time_variable", "group_variable",
+ "running_variable", "cutoff_value"]:
+ if key in variables_dict and variables_dict[key] is not None:
+ kwargs_for_method[key] = variables_dict[key]
+
+ # Add new fields from the Variables model (which is inputs.variables)
+ if hasattr(inputs, 'variables'): # ensure variables object exists on inputs
+ if inputs.variables.treatment_reference_level is not None:
+ kwargs_for_method['treatment_reference_level'] = inputs.variables.treatment_reference_level
+ if inputs.variables.interaction_term_suggested is not None: # boolean, so check for None to allow False
+ kwargs_for_method['interaction_term_suggested'] = inputs.variables.interaction_term_suggested
+ if inputs.variables.interaction_variable_candidate is not None:
+ kwargs_for_method['interaction_variable_candidate'] = inputs.variables.interaction_variable_candidate
+
+ # Add query if needed by llm_assist functions within the method
+ kwargs_for_method['query'] = query_str
+ kwargs_for_method['column_mappings'] = column_mappings
+
+
+ results_dict = estimate_func(
+ df=df_processed,
+ treatment=updated_treatment,
+ outcome=updated_outcome,
+ covariates=updated_covariates,
+ dataset_description=dataset_description_str,
+ query_str=query_str,
+ llm=llm_instance,
+ **kwargs_for_method # Pass specific args needed by the method
+ )
+
+ # 6. Prepare output
+ logger.info(f"Method execution successful. Effect estimate: {results_dict.get('effect_estimate')}")
+
+ # Add workflow state
+ workflow_update = create_workflow_state_update(
+ current_step="method_execution",
+ step_completed_flag="method_executed",
+ next_tool="explainer_tool",
+ next_step_reason="Now we need to explain the results and their implications"
+ )
+
+ # --- Prepare Output Dictionary ---
+ # Structure required by explainer_tool: context + nested "results"
+ final_output = {
+ # Nested dictionary for numerical results and diagnostics
+ "results": {
+ # Core estimation results (extracted from results_dict)
+ "effect_estimate": results_dict.get("effect_estimate"),
+ "confidence_interval": results_dict.get("confidence_interval"),
+ "standard_error": results_dict.get("standard_error"),
+ "p_value": results_dict.get("p_value"),
+ "method_used": results_dict.get("method_used"),
+ "llm_assumption_check": results_dict.get("llm_assumption_check"),
+ "raw_results": results_dict.get("raw_results"),
+ # Diagnostics and Refutation results
+ "diagnostics": results_dict.get("diagnostics"),
+ "refutation_results": results_dict.get("refutation_results")
+ },
+ # Top-level context to be passed along
+ "variables": variables_dict,
+ "dataset_analysis": dataset_analysis_dict,
+ "dataset_description": dataset_description_str,
+ "validation_info": validation_info, # Pass validation info
+ "original_query": inputs.original_query,
+ "column_mappings": column_mappings # Add column_mappings to the output
+ # Workflow state will be added next
+ }
+
+ # Add workflow state to the final output
+ final_output.update(workflow_update.get('workflow_state', {}))
+
+ # --- Logging logic (moved from output_formatter.py) ---
+ # Prepare a summary dict for logging
+ summary_keys = {"query", "method_used", "causal_effect", "standard_error", "confidence_interval"}
+ # Try to get these from the available context
+ summary_dict = {
+ "query": inputs.original_query if hasattr(inputs, 'original_query') else None,
+ "method_used": results_dict.get("method_used"),
+ "causal_effect": results_dict.get("effect_estimate"),
+ "standard_error": results_dict.get("standard_error"),
+ "confidence_interval": results_dict.get("confidence_interval")
+ }
+ print(f"summary_dict: {summary_dict}")
+ print(f"CURRENT_OUTPUT_LOG_FILE: {CURRENT_OUTPUT_LOG_FILE}")
+ if CURRENT_OUTPUT_LOG_FILE and summary_dict:
+ try:
+ import json
+ log_entry = {"type": "analysis_result", "data": summary_dict}
+ with open(CURRENT_OUTPUT_LOG_FILE, mode='a', encoding='utf-8') as log_file:
+ log_file.write('\n' + json.dumps(log_entry) + '\n')
+ except Exception as e:
+ print(f"[ERROR] method_executor_tool.py: Failed to write analysis results to log file '{CURRENT_OUTPUT_LOG_FILE}': {e}")
+
+ return final_output
+
+ except Exception as e:
+ error_message = f"Error executing method {method}: {str(e)}"
+ logger.error(error_message, exc_info=True)
+
+ # Return error state, include context if available
+ workflow_update = create_workflow_state_update(
+ current_step="method_execution",
+ step_completed_flag=False,
+ next_tool="explainer_tool", # Or error handler?
+ next_step_reason=f"Failed during method execution: {error_message}"
+ )
+ # Ensure error output still contains necessary context keys if possible
+ error_result = {"error": error_message,
+ "variables": variables_dict if 'variables_dict' in locals() else {},
+ "dataset_analysis": dataset_analysis_dict if 'dataset_analysis_dict' in locals() else {},
+ "dataset_description": dataset_description_str if 'dataset_description_str' in locals() else None,
+ "original_query": inputs.original_query if hasattr(inputs, 'original_query') else None,
+ "column_mappings": column_mappings if 'column_mappings' in locals() else {} # Also add to error output
+ }
+ error_result.update(workflow_update.get('workflow_state', {}))
+ return error_result
\ No newline at end of file
diff --git a/auto_causal/tools/method_selector_tool.py b/auto_causal/tools/method_selector_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab1d926e60c788d5956d488c774aa1a9b94bf8fe
--- /dev/null
+++ b/auto_causal/tools/method_selector_tool.py
@@ -0,0 +1,171 @@
+"""
+Method Selector Tool for selecting causal inference methods.
+
+This module provides a LangChain tool for selecting appropriate
+causal inference methods based on dataset characteristics and query details.
+"""
+
+import logging # Add logging
+from typing import Dict, List, Any, Optional, Union
+from langchain_core.tools import tool # Use langchain_core
+
+# Import component function and central LLM factory
+from auto_causal.components.decision_tree import rule_based_select_method # Rule-based
+from auto_causal.components.decision_tree_llm import DecisionTreeLLMEngine # LLM-based
+from auto_causal.config import get_llm_client # Updated import path
+from auto_causal.components.state_manager import create_workflow_state_update
+
+# Import shared models from central location
+from auto_causal.models import (
+ Variables,
+ DatasetAnalysis,
+ MethodSelectorInput # Still needed for args_schema
+)
+
+logger = logging.getLogger(__name__)
+
+@tool(args_schema=MethodSelectorInput)
+# Option 1: Modify signature to match args_schema fields
+def method_selector_tool(
+ variables: Variables,
+ dataset_analysis: DatasetAnalysis,
+ dataset_description: Optional[str] = None,
+ original_query: Optional[str] = None,
+ excluded_methods: Optional[List[str]] = None
+) -> Dict[str, Any]:
+ """
+ Select the most appropriate causal inference method based on structured input.
+
+ Applies decision logic based on dataset analysis and identified variables (including is_rct).
+
+ Args:
+ variables: Pydantic model containing identified variables (T, O, C, IV, RDD, is_rct, etc.).
+ dataset_analysis: Pydantic model containing results of dataset analysis.
+ dataset_description: Optional textual description of the dataset.
+ original_query: Optional original user query string.
+ excluded_methods: Optional list of method names to exclude from selection.
+
+ Returns:
+ Dictionary with method selection details, context for next step, and workflow state.
+ """
+ logger.info("Running method_selector_tool with individual args...")
+
+ # Access data directly from arguments (they are already Pydantic models)
+ variables_model = variables
+ dataset_analysis_model = dataset_analysis
+ dataset_description_str = dataset_description
+ is_rct_flag = variables_model.is_rct # Get is_rct directly from variables argument
+
+ # Convert Pydantic models to dicts for the component call (select_method expects dicts)
+ variables_dict = variables_model.model_dump()
+ dataset_analysis_dict = dataset_analysis_model.model_dump()
+
+ # Basic validation
+ treatment = variables_dict.get("treatment_variable")
+ outcome = variables_dict.get("outcome_variable")
+ if not all([treatment, outcome]):
+ logger.error("Missing treatment or outcome variable in input.")
+ # Construct error output, including passed-along context
+ workflow_update = create_workflow_state_update(
+ current_step="method_selection",
+ step_completed_flag=False,
+ next_tool="method_selector_tool",
+ next_step_reason="Missing treatment/outcome variable in input",
+ error="Missing treatment/outcome variable in input"
+ )
+ # Use model_dump() for analysis dict
+ return { "error": "Missing treatment/outcome",
+ "variables": variables_dict,
+ "dataset_analysis": dataset_analysis_dict,
+ "dataset_description": dataset_description_str,
+ **workflow_update.get('workflow_state', {})}
+
+ # Get LLM instance (optional for component)
+ try:
+ llm_instance = get_llm_client()
+ except Exception as e:
+ logger.warning(f"Failed to initialize LLM for method_selector_tool: {e}. Proceeding without LLM features.")
+ llm_instance = None
+
+ # --- Configuration for switching ---
+ USE_LLM_DECISION_TREE = False # Set to False to use the original rule-based tree
+
+ # Call the component function
+ try:
+ if USE_LLM_DECISION_TREE:
+ logger.info("Using LLM-based Decision Tree Engine for method selection.")
+ if not llm_instance:
+ logger.warning("LLM instance is required for DecisionTreeLLMEngine but not available. Falling back to rule-based or error.")
+ # Potentially raise an error or explicitly call rule-based here if LLM is mandatory for this path
+ # For now, it will proceed and DecisionTreeLLMEngine will handle the missing llm
+ llm_engine = DecisionTreeLLMEngine(verbose=True) # You can set verbosity as needed
+ method_selection_dict = llm_engine.select_method_llm(
+ dataset_analysis=dataset_analysis_dict,
+ variables=variables_dict,
+ is_rct=is_rct_flag if isinstance(is_rct_flag, bool) else False,
+ llm=llm_instance,
+ excluded_methods=excluded_methods
+ )
+ else:
+ logger.info("Using Rule-based Decision Tree Engine for method selection.")
+ # Pass dicts and the is_rct flag
+ method_selection_dict = rule_based_select_method(
+ dataset_analysis=dataset_analysis_dict,
+ variables=variables_dict,
+ is_rct=is_rct_flag if isinstance(is_rct_flag, bool) else False, # Handle None case
+ llm=llm_instance,
+ dataset_description = dataset_description,
+ original_query = original_query,
+ excluded_methods = excluded_methods
+ )
+ except Exception as e:
+ logger.error(f"Error during method selection execution: {e}", exc_info=True)
+ # Construct error output
+ workflow_update = create_workflow_state_update(
+ current_step="method_selection",
+ step_completed_flag=False,
+ next_tool="error_handler_tool",
+ next_step_reason=f"Component failed: {e}",
+ error=f"Component failed: {e}"
+ )
+ return { "error": f"Method selection logic failed: {e}",
+ "variables": variables_dict,
+ "dataset_analysis": dataset_analysis_dict,
+ "dataset_description": dataset_description_str,
+ **workflow_update.get('workflow_state', {})}
+
+ # --- Prepare Output Dictionary ---
+ method_selected_flag = bool(method_selection_dict.get("selected_method") and method_selection_dict["selected_method"] != "Error")
+
+ # Create the 'method_info' sub-dictionary required by the validator
+ # Include alternative_methods if present in the selection output
+ method_info = {
+ "selected_method": method_selection_dict.get("selected_method"),
+ "method_name": method_selection_dict.get("selected_method", "").replace("_", " ").title() if method_selected_flag else None,
+ "method_justification": method_selection_dict.get("method_justification"),
+ "method_assumptions": method_selection_dict.get("method_assumptions", []),
+ "alternative_methods": method_selection_dict.get("alternatives", []) # Include alternatives
+ }
+
+ # Create the final output dictionary for the agent
+ result = {
+ "method_info": method_info,
+ "variables": variables_dict,
+ "dataset_analysis": dataset_analysis_dict,
+ "dataset_description": dataset_description_str,
+ "original_query": original_query # Pass original query argument
+ }
+
+ # Determine workflow state for the next step
+ next_tool_name = "method_validator_tool" if method_selected_flag else "error_handler_tool"
+ next_reason = "Now we need to validate the assumptions of the selected method" if method_selected_flag else "Method selection failed or returned an error."
+ workflow_update = create_workflow_state_update(
+ current_step="method_selection",
+ step_completed_flag=method_selected_flag,
+ next_tool=next_tool_name,
+ next_step_reason=next_reason
+ )
+ result.update(workflow_update.get('workflow_state', {})) # Add workflow state dict
+
+ logger.info(f"method_selector_tool finished. Selected: {method_info.get('selected_method')}")
+ return result
\ No newline at end of file
diff --git a/auto_causal/tools/method_validator_tool.py b/auto_causal/tools/method_validator_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b05d65d0aed75cabcfb4f3b0092985ed98a3e5
--- /dev/null
+++ b/auto_causal/tools/method_validator_tool.py
@@ -0,0 +1,181 @@
+"""
+Method validator tool for causal inference methods.
+
+This tool validates the selected causal inference method against
+dataset characteristics and available variables.
+"""
+
+from typing import Dict, Any, Optional, List, Union
+from langchain.tools import tool
+import logging
+
+from auto_causal.components.method_validator import validate_method
+from auto_causal.components.state_manager import create_workflow_state_update
+from auto_causal.components.decision_tree import rule_based_select_method
+
+# Import shared models from central location
+from auto_causal.models import (
+ Variables,
+ TemporalStructure, # Needed indirectly by DatasetAnalysis
+ DatasetInfo, # Needed indirectly by DatasetAnalysis
+ DatasetAnalysis,
+ MethodInfo,
+ MethodValidatorInput
+)
+
+logger = logging.getLogger(__name__)
+
+def extract_properties_from_inputs(inputs: MethodValidatorInput) -> Dict[str, Any]:
+ """
+ Helper function to extract dataset properties from MethodValidatorInput
+ for use with the decision tree.
+ """
+ variables_dict = inputs.variables.model_dump()
+ dataset_analysis_dict = inputs.dataset_analysis.model_dump()
+
+ return {
+ "treatment_variable": variables_dict.get("treatment_variable"),
+ "outcome_variable": variables_dict.get("outcome_variable"),
+ "instrument_variable": variables_dict.get("instrument_variable"),
+ "covariates": variables_dict.get("covariates", []),
+ "time_variable": variables_dict.get("time_variable"),
+ "running_variable": variables_dict.get("running_variable"),
+ "treatment_variable_type": variables_dict.get("treatment_variable_type", "binary"),
+ "has_temporal_structure": dataset_analysis_dict.get("temporal_structure", {}).get("has_temporal_structure", False),
+ "frontdoor_criterion": variables_dict.get("frontdoor_criterion", False),
+ "cutoff_value": variables_dict.get("cutoff_value"),
+ "covariate_overlap_score": variables_dict.get("covariate_overlap_result", 0),
+ "is_rct": variables_dict.get("is_rct", False)
+ }
+
+# --- Removed local Pydantic definitions ---
+# class Variables(BaseModel): ...
+# class TemporalStructure(BaseModel): ...
+# class DatasetInfo(BaseModel): ...
+# class DatasetAnalysis(BaseModel): ...
+# class MethodInfo(BaseModel): ...
+# class MethodValidatorInput(BaseModel): ...
+
+# --- Tool Definition ---
+@tool
+def method_validator_tool(inputs: MethodValidatorInput) -> Dict[str, Any]: # Use Pydantic Input
+ """
+ Validate the assumptions of the selected causal method using structured input.
+
+ Args:
+ inputs: Pydantic model containing method_info, dataset_analysis, variables, and dataset_description.
+
+ Returns:
+ Dictionary with validation results, context for next step, and workflow state.
+ """
+ logger.info(f"Running method_validator_tool for method: {inputs.method_info.selected_method}")
+
+ # Access data from input model (converting to dicts for component)
+ method_info_dict = inputs.method_info.model_dump()
+ dataset_analysis_dict = inputs.dataset_analysis.model_dump()
+ variables_dict = inputs.variables.model_dump()
+ dataset_description_str = inputs.dataset_description
+
+ # Call the component function to validate the method
+ try:
+ validation_results = validate_method(method_info_dict, dataset_analysis_dict, variables_dict)
+ if not isinstance(validation_results, dict):
+ raise TypeError(f"validate_method component did not return a dict. Got: {type(validation_results)}")
+
+ except Exception as e:
+ logger.error(f"Error during validate_method execution: {e}", exc_info=True)
+ # Construct error output
+ workflow_update = create_workflow_state_update(
+ current_step="method_validation", method_validated=False, error=f"Component failed: {e}"
+ )
+ # Pass context even on error
+ return {"error": f"Method validation component failed: {e}",
+ "variables": variables_dict,
+ "dataset_analysis": dataset_analysis_dict,
+ "dataset_description": dataset_description_str,
+ **workflow_update.get('workflow_state', {})}
+
+ # Determine if assumptions are valid based on component output
+ assumptions_valid = validation_results.get("valid", False)
+ failed_assumptions = validation_results.get("concerns", [])
+ original_method = method_info_dict.get("selected_method")
+ recommended_method = validation_results.get("recommended_method", original_method)
+
+ # If validation failed, attempt to backtrack through decision tree
+ if not assumptions_valid and failed_assumptions:
+ logger.info(f"Method {original_method} failed validation due to: {failed_assumptions}")
+ logger.info("Attempting to backtrack and select alternative method...")
+
+ try:
+ # Extract properties for decision tree
+ dataset_props = extract_properties_from_inputs(inputs)
+
+ # Get LLM instance (may be None)
+ from auto_causal.config import get_llm_client
+ try:
+ llm_instance = get_llm_client()
+ except Exception as e:
+ logger.warning(f"Failed to get LLM instance: {e}")
+ llm_instance = None
+
+ # Re-run decision tree with failed method excluded
+ excluded_methods = [original_method]
+ new_selection = rule_based_select_method(
+ dataset_analysis=inputs.dataset_analysis.model_dump(),
+ variables=inputs.variables.model_dump(),
+ is_rct=inputs.variables.is_rct or False,
+ llm=llm_instance,
+ dataset_description=inputs.dataset_description,
+ original_query=inputs.original_query,
+ excluded_methods=excluded_methods
+ )
+
+ recommended_method = new_selection.get("selected_method", original_method)
+ logger.info(f"Backtracking selected new method: {recommended_method}")
+
+ # Update validation results to include backtracking info
+ validation_results["backtrack_attempted"] = True
+ validation_results["backtrack_method"] = recommended_method
+ validation_results["excluded_methods"] = excluded_methods
+
+ except Exception as e:
+ logger.error(f"Backtracking failed: {e}")
+ validation_results["backtrack_attempted"] = True
+ validation_results["backtrack_error"] = str(e)
+ # Keep original recommended method
+
+ # Prepare output dictionary for the next tool (method_executor)
+ result = {
+ # --- Data for Method Executor ---
+ "method": recommended_method, # Use recommended method going forward
+ "variables": variables_dict, # Pass along all identified variables
+ "dataset_path": dataset_analysis_dict.get('dataset_info',{}).get('file_path'), # Extract path
+ "dataset_analysis": dataset_analysis_dict, # Pass full analysis
+ "dataset_description": dataset_description_str, # Pass description string
+ "original_query": inputs.original_query, # Pass original query
+
+ # --- Validation Results ---
+ "validation_info": {
+ "original_method": method_info_dict.get("selected_method"),
+ "recommended_method": recommended_method,
+ "assumptions_valid": assumptions_valid,
+ "failed_assumptions": failed_assumptions,
+ "warnings": validation_results.get("warnings", []),
+ "suggestions": validation_results.get("suggestions", [])
+ }
+ }
+
+ # Determine workflow state
+ method_validated_flag = assumptions_valid # Or perhaps always True if validation ran?
+ next_tool_name = "method_executor_tool" if method_validated_flag else "error_handler_tool" # Go to executor even if assumptions failed?
+ next_reason = "Method assumptions checked. Proceeding to execution." if method_validated_flag else "Method assumptions failed validation."
+ workflow_update = create_workflow_state_update(
+ current_step="method_validation",
+ step_completed_flag=method_validated_flag,
+ next_tool=next_tool_name,
+ next_step_reason=next_reason
+ )
+ result.update(workflow_update) # Add workflow state
+
+ logger.info(f"method_validator_tool finished. Assumptions valid: {assumptions_valid}")
+ return result
\ No newline at end of file
diff --git a/auto_causal/tools/output_formatter_tool.py b/auto_causal/tools/output_formatter_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cca9cac8c7cd76f7b5e84dc6c13b4b083ab04bd7
--- /dev/null
+++ b/auto_causal/tools/output_formatter_tool.py
@@ -0,0 +1,107 @@
+"""
+Output formatter tool for causal inference results.
+
+This tool provides the LangChain interface for the output formatter component.
+"""
+
+# REVERT Pydantic approach for this tool temporarily
+
+from typing import Dict, Any, Optional#, List, Union # Keep only needed
+# from pydantic import BaseModel, Field # REVERT
+import logging
+import json # Ensure json is imported
+
+# Add import for @tool decorator
+from langchain.tools import tool
+
+from auto_causal.components import output_formatter
+# Import the Pydantic model returned by the component
+from auto_causal.models import FormattedOutput
+
+# --- REVERT: Remove Pydantic Model Definitions ---
+# class Variables(BaseModel):
+# ... (Remove all re-defined models)
+# class OutputFormatterInput(BaseModel):
+# ... (Remove definition)
+
+# --- Tool Definition ---
+logger = logging.getLogger(__name__)
+
+@tool
+# REVERT to original signature with individual arguments
+def output_formatter_tool(
+ query: str,
+ method: str,
+ results: Dict[str, Any], # Output from method_executor_tool
+ explanation: Dict[str, Any], # Output from explainer_tool
+ dataset_analysis: Optional[Dict[str, Any]] = None, # Use Dict
+ dataset_description: Optional[str] = None
+) -> Dict[str, Any]:
+ """
+ Formats the final explanation and results using the output_formatter component,
+ packages it into a dictionary, adds workflow state, and a JSON representation.
+
+ Args:
+ query: Original user query.
+ method: The method used (string name).
+ results: Numerical results dict from method_executor_tool.
+ explanation: Structured explanation dict from explainer_tool.
+ dataset_analysis: Optional results from dataset_analyzer_tool.
+ dataset_description: Optional initial description string.
+
+ Returns:
+ Dict containing the formatted output fields, workflow state, and a JSON string.
+ """
+ logger.info("Running output_formatter_tool...")
+
+ try:
+ # Call component function - it now returns a FormattedOutput Pydantic model
+ formatted_output_model: FormattedOutput = output_formatter.format_output(
+ query=query,
+ method=method,
+ results=results,
+ explanation=explanation, # Pass explanation dict directly
+ # Pass analysis dict directly, handle None case for component
+ dataset_analysis=dataset_analysis if dataset_analysis else None,
+ dataset_description=dataset_description
+ )
+
+ # Convert the Pydantic model back to a dictionary for tool output
+ # Use model_dump() for Pydantic v2+, or .dict() for v1
+ try:
+ # Attempt model_dump first (Pydantic v2)
+ formatted_output_dict = formatted_output_model.model_dump(mode='json') # mode='json' handles complex types
+ except AttributeError:
+ # Fallback to dict() (Pydantic v1)
+ formatted_output_dict = formatted_output_model.dict()
+
+ # Generate JSON representation of the dictionary
+ try:
+ # Exclude workflow_state if it accidentally got included in the model dump
+ dict_for_json = {k: v for k, v in formatted_output_dict.items() if k != 'workflow_state'}
+ json_output_str = json.dumps(dict_for_json, indent=4)
+ formatted_output_dict["json_output"] = json_output_str
+ except TypeError as json_err:
+ logger.error(f"Failed to serialize output to JSON: {json_err}")
+ formatted_output_dict["json_output"] = f'{{"error": "Failed to serialize output to JSON: {json_err}"}}'
+
+ # Add workflow state information - analysis is complete
+ formatted_output_dict["workflow_state"] = {
+ "current_step": "output_formatting",
+ "analysis_complete": True
+ }
+
+ logger.info("Output formatting successful.")
+ return formatted_output_dict # Return the final dictionary
+
+ except Exception as e:
+ logger.error(f"Error during output formatting: {e}", exc_info=True)
+ # Return error structure
+ return {
+ "error": f"Failed to format output: {e}",
+ "workflow_state": {
+ "current_step": "output_formatting",
+ "analysis_complete": False, # Indicate failure
+ "error": f"Formatting component failed: {e}"
+ }
+ }
\ No newline at end of file
diff --git a/auto_causal/tools/query_interpreter_tool.py b/auto_causal/tools/query_interpreter_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f01133546f23ddd526be0882c100c3c3af11ee5
--- /dev/null
+++ b/auto_causal/tools/query_interpreter_tool.py
@@ -0,0 +1,117 @@
+"""
+Tool for interpreting causal queries in the context of a dataset.
+
+This module provides a LangChain tool for matching query concepts to actual
+dataset variables, identifying treatment, outcome, and covariate variables.
+"""
+
+# Removed Pydantic import, will be imported via models
+# from pydantic import BaseModel, Field
+from typing import Dict, List, Any, Optional, Union # Keep Any, Dict for workflow_state
+import logging
+
+# Import shared Pydantic models from the central location
+from auto_causal.models import (
+ TemporalStructure,
+ DatasetInfo,
+ DatasetAnalysis,
+ QueryInfo,
+ QueryInterpreterInput,
+ Variables,
+ QueryInterpreterOutput
+)
+
+# --- Removed local Pydantic definitions ---
+# class TemporalStructure(BaseModel): ...
+# class DatasetInfo(BaseModel): ...
+# class DatasetAnalysis(BaseModel): ...
+# class QueryInfo(BaseModel): ...
+# class QueryInterpreterInput(BaseModel): ...
+# class Variables(BaseModel): ...
+# class QueryInterpreterOutput(BaseModel): ...
+
+logger = logging.getLogger(__name__)
+
+from langchain.tools import tool
+from auto_causal.components.query_interpreter import interpret_query
+from auto_causal.components.state_manager import create_workflow_state_update
+
+
+@tool()
+# Modify signature to accept individual Pydantic models/types as arguments
+def query_interpreter_tool(
+ query_info: QueryInfo,
+ dataset_analysis: DatasetAnalysis,
+ dataset_description: str,
+ original_query: Optional[str] = None # Keep optional original_query
+) -> QueryInterpreterOutput:
+ """
+ Interpret a causal query in the context of a specific dataset.
+
+ Args:
+ query_info: Pydantic model with parsed query information.
+ dataset_analysis: Pydantic model with dataset analysis results.
+ dataset_description: String description of the dataset.
+ original_query: The original user query string (optional).
+
+ Returns:
+ A Pydantic model containing identified variables (including is_rct), dataset analysis, description, and workflow state.
+ """
+ logger.info("Running query_interpreter_tool with direct arguments...")
+
+ # Use arguments directly, dump models to dicts for the component call
+ query_info_dict = query_info.model_dump()
+ dataset_analysis_dict = dataset_analysis.model_dump()
+ # dataset_description is already a string
+ # Call the component function
+ try:
+ # Assume interpret_query returns a dictionary compatible with Variables model
+ # AND that interpret_query now attempts to determine is_rct
+ interpretation_dict = interpret_query(query_info_dict, dataset_analysis_dict, dataset_description)
+ if not isinstance(interpretation_dict, dict):
+ raise TypeError(f"interpret_query component did not return a dictionary. Got: {type(interpretation_dict)}")
+
+ # Validate and structure the interpretation using Pydantic
+ # This will raise validation error if interpret_query didn't return expected fields
+ variables_output = Variables(**interpretation_dict)
+
+ except Exception as e:
+ logger.error(f"Error during query interpretation component call: {e}", exc_info=True)
+ workflow_update = create_workflow_state_update(
+ current_step="variable_identification",
+ step_completed_flag=False,
+ next_tool="query_interpreter_tool", # Or error handler
+ next_step_reason=f"Component execution failed: {e}"
+ )
+ error_vars = Variables()
+ # Use the passed dataset_analysis object directly in case of error
+ error_analysis = dataset_analysis
+ # Return Pydantic output even on error
+ return QueryInterpreterOutput(
+ variables=error_vars,
+ dataset_analysis=error_analysis,
+ dataset_description=dataset_description,
+ original_query=original_query, # Pass original query if available
+ workflow_state=workflow_update.get('workflow_state', {})
+ )
+
+ # Create workflow state update for success
+ workflow_update = create_workflow_state_update(
+ current_step="variable_identification",
+ step_completed_flag="variables_identified",
+ next_tool="method_selector_tool",
+ next_step_reason="Now that we have identified the variables, we can select an appropriate causal inference method"
+ )
+
+ # Construct the Pydantic output object
+ output = QueryInterpreterOutput(
+ variables=variables_output,
+ # Pass the original dataset_analysis Pydantic model
+ dataset_analysis=dataset_analysis,
+ dataset_description=dataset_description,
+ original_query=original_query, # Pass along original query
+ workflow_state=workflow_update.get('workflow_state', {}) # Extract state dict
+ )
+
+ logger.info("query_interpreter_tool finished successfully.")
+ return output
\ No newline at end of file
diff --git a/auto_causal/utils/__init__.py b/auto_causal/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0519ecba6ea913e21689ec692e81e9e4973fbf73
--- /dev/null
+++ b/auto_causal/utils/__init__.py
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/auto_causal/utils/agent.py b/auto_causal/utils/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..025a8b9a28245bb58abf5b51cd8fd12ec554f2e7
--- /dev/null
+++ b/auto_causal/utils/agent.py
@@ -0,0 +1,365 @@
+"""
+LangChain agent for the auto_causal module.
+
+This module configures a LangChain agent with specialized tools for causal inference,
+allowing for an interactive approach to analyzing datasets and applying appropriate
+causal inference methods.
+"""
+
+import logging
+from typing import Dict, List, Any, Optional
+from langchain.agents.react.agent import create_react_agent
+from langchain.agents import AgentExecutor, create_structured_chat_agent, create_tool_calling_agent
+from langchain.chains.conversation.memory import ConversationBufferMemory
+from langchain_core.messages import SystemMessage, HumanMessage
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
+from langchain.tools import tool
+# Import the callback handler
+from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
+# Import tool rendering utility
+from langchain.tools.render import render_text_description
+# Import LCEL components
+from langchain.agents.format_scratchpad.tools import format_to_tool_messages
+from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.language_models import BaseChatModel
+from langchain_anthropic.chat_models import convert_to_anthropic_tool
+# Import actual tools from the tools directory
+from auto_causal.tools.input_parser_tool import input_parser_tool
+from auto_causal.tools.dataset_analyzer_tool import dataset_analyzer_tool
+from auto_causal.tools.query_interpreter_tool import query_interpreter_tool
+from auto_causal.tools.method_selector_tool import method_selector_tool
+from auto_causal.tools.method_validator_tool import method_validator_tool
+from auto_causal.tools.method_executor_tool import method_executor_tool
+from auto_causal.tools.explanation_generator_tool import explanation_generator_tool
+from auto_causal.tools.output_formatter_tool import output_formatter_tool
+#from auto_causal.prompts import SYSTEM_PROMPT # Assuming SYSTEM_PROMPT is defined here or imported
+from langchain_core.output_parsers import StrOutputParser
+# Import the centralized factory function
+from .config import get_llm_client
+#from .prompts import SYSTEM_PROMPT
+from langchain_core.messages import AIMessage, AIMessageChunk
+import re
+import json
+from typing import Union
+from langchain_core.output_parsers import BaseOutputParser
+from langchain.schema import AgentAction, AgentFinish
+from langchain_anthropic.output_parsers import ToolsOutputParser
+from langchain.agents.react.output_parser import ReActOutputParser
+from langchain.agents import AgentOutputParser
+from langchain.agents.agent import AgentAction, AgentFinish, OutputParserException
+import re
+from typing import Union, List
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.exceptions import OutputParserException
+
+from langchain.agents.agent import AgentOutputParser
+from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
+
+FINAL_ANSWER_ACTION = "Final Answer:"
+MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action:' after 'Thought:'"
+)
+MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action Input:' after 'Action:'"
+)
+FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
+ "Parsing LLM output produced both a final answer and parse-able actions"
+)
+
+
+class ReActMultiInputOutputParser(AgentOutputParser):
+ """Parses ReAct-style output that may contain multiple tool calls."""
+
+ def get_format_instructions(self) -> str:
+ # You can reuse the original FORMAT_INSTRUCTIONS,
+ # but let the model know it may emit multiple actions.
+ return FORMAT_INSTRUCTIONS + (
+ "\n\nIf you need to call more than one tool, simply repeat:\n"
+ "Action: \n"
+ "Action Input: \n"
+ "…for each tool in sequence."
+ )
+
+ @property
+ def _type(self) -> str:
+ return "react-multi-input"
+
+ def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+ includes_answer = FINAL_ANSWER_ACTION in text
+ print('-------------------')
+ print(text)
+ print('-------------------')
+ # Grab every Action / Action Input block
+ pattern = (
+ r"Action\s*\d*\s*:[\s]*(.*?)\s*"
+ r"Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*?)(?=(?:Action\s*\d*\s*:|$))"
+ )
+ matches = list(re.finditer(pattern, text, re.DOTALL))
+
+ # If we found tool calls…
+ if matches:
+ if includes_answer:
+ # both a final answer *and* tool calls is ambiguous
+ raise OutputParserException(
+ f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
+ )
+
+ actions: List[AgentAction] = []
+ for m in matches:
+ tool_name = m.group(1).strip()
+ tool_input = m.group(2).strip().strip('"')
+ print('\n--------------------------')
+ print(tool_input)
+ print('--------------------------')
+ actions.append(AgentAction(tool_name, json.loads(tool_input), text))
+
+ return actions
+
+ # Otherwise, if there's a final answer, finish
+ if includes_answer:
+ answer = text.split(FINAL_ANSWER_ACTION, 1)[1].strip()
+ return AgentFinish({"output": answer}, text)
+
+ # No calls and no final answer → figure out which error to throw
+ if not re.search(r"Action\s*\d*\s*:", text):
+ raise OutputParserException(
+ f"Could not parse LLM output: `{text}`",
+ observation=MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
+ llm_output=text,
+ send_to_llm=True,
+ )
+ if not re.search(r"Action\s*\d*\s*Input\s*\d*:", text):
+ raise OutputParserException(
+ f"Could not parse LLM output: `{text}`",
+ observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
+ llm_output=text,
+ send_to_llm=True,
+ )
+
+ # Fallback
+ raise OutputParserException(f"Could not parse LLM output: `{text}`")
+
+
+# Set up basic logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# --- Centralized LLM Client Factory (REMOVED FROM HERE) ---
+# load_dotenv() # Moved to config
+# def get_llm_client(...): # Moved to config
+# ...
+# --- End Removed Section ---
+
+def create_agent_prompt(tools: List[tool]) -> ChatPromptTemplate:
+ """Create the prompt template for the causal inference agent, emphasizing workflow and data handoff.
+ (This is the version required by the LCEL agent structure below)
+ """
+ # Get the tool descriptions
+ tool_description = render_text_description(tools)
+ tool_names = ", ".join([t.name for t in tools])
+
+ # Define the system prompt template string
+ system_template = """
+You are a causal inference expert helping users answer causal questions by following a strict workflow using specialized tools.
+
+TOOLS:
+------
+You have access to the following tools:
+
+{tools}
+
+To use a tool, please use the following format:
+
+```
+Thought: Do I need to use a tool? Yes
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action, as a single, valid JSON object string. Check the tool definition for required arguments and structure.
+Observation: the result of the action, often containing structured data like 'variables', 'dataset_analysis', 'method_info', etc.
+```
+
+When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
+
+```
+Thought: Do I need to use a tool? No
+Final Answer: [your response here]
+```
+
+DO NOT UNDER ANY CIRCUMSTANCE CALL MORE THAN ONE TOOL IN A STEP
+
+**IMPORTANT TOOL USAGE:**
+1. **Action Input Format:** The value for 'Action Input' MUST be a single, valid JSON object string. Do NOT include any other text or formatting around the JSON string.
+2. **Argument Gathering:** You MUST gather ALL required arguments for the Action Input JSON from the initial Human input AND the 'Observation' outputs of PREVIOUS steps. Look carefully at the required arguments for the tool you are calling.
+3. **Data Handoff:** The 'Observation' from a previous step often contains structured data needed by the next tool. For example, the 'variables' output from `query_interpreter_tool` contains fields like `treatment_variable`, `outcome_variable`, `covariates`, `time_variable`, `instrument_variable`, `running_variable`, `cutoff_value`, and `is_rct`. When calling `method_selector_tool`, you MUST construct its required `variables` input argument by including **ALL** these relevant fields identified by the `query_interpreter_tool` in the previous Observation. Similarly, pass the full `dataset_analysis`, `dataset_description`, and `original_query` when required by the next tool.
+
+IMPORTANT WORKFLOW:
+-------------------
+You must follow this exact workflow, selecting the appropriate tool for each step:
+
+1. ALWAYS start with `input_parser_tool` to understand the query
+2. THEN use `dataset_analyzer_tool` to analyze the dataset
+3. THEN use `query_interpreter_tool` to identify variables (output includes `variables` and `dataset_analysis`)
+4. THEN use `method_selector_tool` (input requires `variables` and `dataset_analysis` from previous step)
+5. THEN use `method_validator_tool` (input requires `method_info` and `variables` from previous step)
+6. THEN use `method_executor_tool` (input requires `method`, `variables`, `dataset_path`)
+7. THEN use `explanation_generator_tool` (input requires results, method_info, variables, etc.)
+8. FINALLY use `output_formatter_tool` to return the results
+
+REASONING PROCESS:
+------------------
+EXPLICITLY REASON about:
+1. What step you're currently on (based on previous tool's Observation)
+2. Why you're selecting a particular tool (should follow the workflow)
+3. How the output of the previous tool (especially structured data like `variables`, `dataset_analysis`, `method_info`) informs the inputs required for the current tool.
+
+IMPORTANT RULES:
+1. Do not make more than one tool call in a single step.
+2. Do not include ``` in your output at all.
+3. Don't use action names like default_api.dataset_analyzer_tool, instead use tool names like dataset_analyzer_tool.
+4. Always start, action, and observation with a new line.
+5. Don't use '\\' before double quotes
+6. Don't include ```json for Action Input
+Begin!
+"""
+
+ # Create the prompt template
+ prompt = ChatPromptTemplate.from_messages([
+ ("system", system_template),
+ MessagesPlaceholder("chat_history", optional=True), # Use MessagesPlaceholder
+ # MessagesPlaceholder("agent_scratchpad"),
+
+ ("human", "{input}\n Thought:{agent_scratchpad}"),
+ # ("ai", "{agent_scratchpad}"),
+ # MessagesPlaceholder("agent_scratchpad" ), # Use MessagesPlaceholder
+ # "agent_scratchpad"
+ ])
+ return prompt
+
+def create_causal_agent(llm: BaseChatModel) -> AgentExecutor:
+ """
+ Create and configure the LangChain agent with causal inference tools.
+ (Using explicit LCEL construction, compatible with shared LLM client)
+ """
+ # Define tools available to the agent
+ agent_tools = [
+ input_parser_tool,
+ dataset_analyzer_tool,
+ query_interpreter_tool,
+ method_selector_tool,
+ method_validator_tool,
+ method_executor_tool,
+ explanation_generator_tool,
+ output_formatter_tool
+ ]
+ # anthropic_agent_tools = [ convert_to_anthropic_tool(anthropic_tool) for anthropic_tool in agent_tools]
+ # Create the prompt using the helper
+ prompt = create_agent_prompt(agent_tools)
+
+ # Bind tools to the LLM (using the passed shared instance)
+ llm_with_tools = llm.bind_tools(agent_tools)
+
+ # Create memory
+ # Consider if memory needs to be passed in or created here
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+
+ # Manually construct the agent runnable using LCEL
+ from langchain_anthropic.output_parsers import ToolsOutputParser
+ from langchain.agents.output_parsers.json import JSONAgentOutputParser
+ # from langchain.agents.react.output_parser import MultiActionAgentOutputParsers ReActMultiInputOutputParser
+ agent = create_react_agent(llm_with_tools, agent_tools, prompt, output_parser=ReActMultiInputOutputParser())
+
+ # Create executor (should now work with the manually constructed agent)
+ executor = AgentExecutor(
+ agent=agent,
+ tools=agent_tools,
+ memory=memory, # Pass the memory object
+ verbose=True,
+ callbacks=[ConsoleCallbackHandler()], # Optional: for console debugging
+ handle_parsing_errors=True, # Let AE handle parsing errors
+ max_retries = 100
+ )
+
+ return executor
+
+def run_causal_analysis(query: str, dataset_path: str,
+ dataset_description: Optional[str] = None,
+ api_key: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Run causal analysis on a dataset based on a user query.
+
+ Args:
+ query: User's causal question
+ dataset_path: Path to the dataset
+ dataset_description: Optional textual description of the dataset
+ api_key: Optional OpenAI API key (DEPRECATED - will be ignored)
+
+ Returns:
+ Dictionary containing the final formatted analysis results from the agent's last step.
+ """
+ # Log the start of the analysis
+ logger.info("Starting causal analysis run...")
+
+ try:
+ # --- Instantiate the shared LLM client ---
+ shared_llm = get_llm_client(temperature=0) # Or read provider/model from env
+
+ # --- Dependency Injection Note (REMAINS RELEVANT) ---
+ # If tools need the LLM, they must be adapted. Example using partial:
+ # from functools import partial
+ # from .components import input_parser
+ # # Assume input_parser.parse_input needs llm
+ # input_parser_tool_with_llm = tool(partial(input_parser.parse_input, llm=shared_llm))
+ # Use input_parser_tool_with_llm in the tools list passed to the agent below.
+ # Similar adjustments needed for decision_tree._recommend_ps_method if used.
+ # --- End Note ---
+
+ # --- Create agent using the shared LLM ---
+ agent_executor = create_causal_agent(shared_llm)
+
+ # Construct input, including description if available
+ # IMPORTANT: Agent now expects 'input' and potentially 'chat_history'
+ # The input needs to contain all initial info the first tool might need.
+ initial_input_dict = {
+ "query": query,
+ "dataset_path": dataset_path,
+ "dataset_description": dataset_description
+ }
+ # Maybe format this into a single input string if the prompt expects {input}
+ input_text = f"My question is: {query}\n"
+ input_text += f"The dataset is located at: {dataset_path}\n"
+ if dataset_description:
+ input_text += f"Dataset Description: {dataset_description}\n"
+ input_text += "Please perform the causal analysis following the workflow."
+
+ # Log the constructed input text
+ logger.info(f"Constructed input for agent: \n{input_text}")
+
+ result = agent_executor.invoke({
+ "input": input_text,
+})
+
+
+ # AgentExecutor returns dict. Extract the final output dictionary.
+ logger.info("Causal analysis run finished.")
+
+ # Ensure result is a dict and extract the 'output' part
+ if isinstance(result, dict):
+ final_output = result.get("output")
+ if isinstance(final_output, dict):
+ return final_output # Return only the dictionary from the final tool
+ else:
+ logger.error(f"Agent result['output'] was not a dictionary: {type(final_output)}. Returning error dict.")
+ return {"error": "Agent did not produce the expected dictionary output in the 'output' key.", "raw_agent_result": result}
+ else:
+ logger.error(f"Agent returned non-dict type: {type(result)}. Returning error dict.")
+ return {"error": "Agent did not return expected dictionary output.", "raw_output": str(result)}
+
+ except ValueError as e:
+ logger.error(f"Configuration Error: {e}")
+ # Return an error dictionary in case of exception too
+ return {"error": f"Error: Configuration issue - {e}"} # Ensure consistent error return type
+ except Exception as e:
+ logger.error(f"An unexpected error occurred during causal analysis: {e}", exc_info=True)
+ # Return an error dictionary in case of exception too
+ return {"error": f"An unexpected error occurred: {e}"}
\ No newline at end of file
diff --git a/auto_causal/utils/llm_helpers.py b/auto_causal/utils/llm_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53183d10f773c3827c42791ecb2b3cf5ab55348
--- /dev/null
+++ b/auto_causal/utils/llm_helpers.py
@@ -0,0 +1,295 @@
+"""
+Utility functions for LLM interactions within the auto_causal module.
+"""
+
+from typing import Dict, Any, Optional, List
+# Assume pandas is available for get_columns_info and sample_data
+import pandas as pd
+import logging
+import json # Ensure json is imported
+
+# Added import for type hint
+from langchain.chat_models.base import BaseChatModel
+from langchain_core.messages import AIMessage # For type hinting llm.invoke response
+
+logger = logging.getLogger(__name__)
+
+# Placeholder for actual LLM calling logic
+def call_llm_with_json_output(llm: Optional[BaseChatModel], prompt: str) -> Optional[Dict[str, Any]]:
+ """
+ Calls the provided LLM with a prompt, expecting a JSON object in the response.
+ It parses the JSON string (after attempting to remove markdown fences)
+ and returns it as a Python dictionary.
+
+ Args:
+ llm: An instance of BaseChatModel (e.g., from Langchain). If None,
+ the function will log a warning and return None.
+ prompt: The prompt string to send to the LLM.
+
+ Returns:
+ A dictionary parsed from the LLM's JSON response, or None if:
+ - llm is None.
+ - The LLM call fails.
+ - The LLM response content cannot be extracted as a string.
+ - The response content is empty after stripping markdown.
+ - The response is not valid JSON.
+ - The parsed JSON is not a dictionary.
+ """
+ if not llm:
+ logger.warning("LLM client (BaseChatModel) not provided to call_llm_with_json_output. Cannot make LLM call.")
+ return None
+
+ logger.info(f"Attempting LLM call with {type(llm).__name__} for JSON output.")
+ # Full prompt logging can be verbose, using DEBUG level.
+ logger.debug(f"LLM Prompt for JSON output:\\n{prompt}")
+
+ raw_response_content = "" # For logging in case of errors before parsing
+ processed_content_for_json = "" # For logging in case of JSON parsing error
+
+ try:
+ llm_response_obj = llm.invoke(prompt)
+
+ # Extract string content from LLM response object
+ if hasattr(llm_response_obj, 'content') and isinstance(llm_response_obj.content, str):
+ raw_response_content = llm_response_obj.content
+ elif isinstance(llm_response_obj, str):
+ raw_response_content = llm_response_obj
+ else:
+ # Fallback for other potential response structures
+ logger.warning(
+ f"LLM response is not a string and has no '.content' attribute of type string. "
+ f"Type: {type(llm_response_obj)}. Trying '.text' attribute."
+ )
+ if hasattr(llm_response_obj, 'text') and isinstance(llm_response_obj.text, str):
+ raw_response_content = llm_response_obj.text
+
+ if not raw_response_content:
+ logger.warning(f"LLM invocation returned no extractable string content. Response object type: {type(llm_response_obj)}")
+ return None
+
+ # Prepare content for JSON parsing: strip whitespace and markdown fences.
+ # Using the same stripping logic as in llm_identify_temporal_and_unit_vars for consistency.
+ processed_content_for_json = raw_response_content.strip()
+
+ if processed_content_for_json.startswith("```json"):
+ # Removes "```json" prefix and "```" suffix, then strips whitespace.
+ # Assumes the format is "```json\\nCONTENT\\n```" or similar.
+ processed_content_for_json = processed_content_for_json[7:-3].strip()
+ elif processed_content_for_json.startswith("```"):
+ # Removes generic "```" prefix and "```" suffix, then strips.
+ processed_content_for_json = processed_content_for_json[3:-3].strip()
+
+ if not processed_content_for_json: # Check if empty after stripping
+ logger.warning(
+ "LLM response content became empty after attempting to strip markdown. "
+ f"Original raw content snippet: '{raw_response_content[:200]}...'"
+ )
+ return None
+
+ parsed_json = json.loads(processed_content_for_json)
+
+ if not isinstance(parsed_json, dict):
+ logger.warning(
+ "LLM response was successfully parsed as JSON, but it is not a dictionary. "
+ f"Type: {type(parsed_json)}. Parsed content snippet: '{str(parsed_json)[:200]}...'"
+ )
+ return None
+
+ logger.info(f"Successfully received and parsed JSON response from {type(llm).__name__}.")
+ return parsed_json
+
+ except json.JSONDecodeError as e:
+ logger.error(
+ f"Failed to decode JSON from LLM response. Error: {e}. "
+ f"Content processed for parsing (snippet): '{processed_content_for_json[:500]}...'"
+ )
+ return None
+ except Exception as e:
+ # This catches errors from llm.invoke() or other unexpected issues.
+ logger.error(f"An unexpected error occurred during LLM call or JSON processing: {e}", exc_info=True)
+ # Log raw content if available and different from processed, for better debugging
+ if raw_response_content and raw_response_content[:500] != processed_content_for_json[:500]:
+ logger.debug(f"Original raw LLM response content (snippet): '{raw_response_content[:500]}...'")
+ return None
+
+# Placeholder for processing LLM response
+def process_llm_response(response: Dict[str, Any], method: str) -> Dict[str, Any]:
+ # Validate and structure the LLM response based on the method
+ # For now, just return the response
+ return response
+
+# Placeholder for getting column info
+def get_columns_info(df: pd.DataFrame) -> Dict[str, str]:
+ return {col: str(dtype) for col, dtype in df.dtypes.items()}
+
+
+def analyze_dataset_for_method(df: pd.DataFrame, query: str, method: str) -> Dict[str, Any]:
+ """Use LLM to analyze dataset for appropriate method parameters.
+
+ Args:
+ df: Input DataFrame
+ query: User's causal query
+ method: The causal method being considered
+
+ Returns:
+ Dictionary with suggested parameters and validation checks from LLM.
+ """
+ # Prepare prompt with dataset information
+ columns_info = get_columns_info(df)
+ try:
+ # Attempt to get sample data safely
+ sample_data = df.head(5).to_dict(orient='records')
+ except Exception:
+ sample_data = "Error retrieving sample data."
+
+ # --- Revised Prompt ---
+ prompt = f"""
+ Given the dataset with columns {columns_info} and the causal query "{query}",
+ suggest SENSIBLE INITIAL DEFAULT parameters for applying the {method} method.
+ Do NOT attempt complex optimization; provide common starting points.
+
+ The first 5 rows of data look like:
+ {sample_data}
+
+ Specifically for {method}:
+ - If PS.Matching:
+ - For 'caliper': Suggest a common heuristic value like 0.01, 0.02, or 0.05 (this is relative to std dev of logit score, but just suggest the number). If unsure, suggest 0.02.
+ - For 'n_neighbors': Suggest 1.
+ - For 'propensity_model_type': Suggest 'logistic' unless the context strongly implies a more complex model is needed.
+ - If PS.Weighting:
+ - For 'weight_type': Suggest 'ATE' unless the query specifically asks for ATT or ATC.
+ - For 'trim_threshold': Suggest a small value like 0.01 or 0.05 if the data seems noisy or has extreme propensity scores, otherwise suggest null (no trimming). Default to null if unsure.
+ - Add other parameters if relevant for the specific method.
+
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
+ {{
+ "parameters": {{
+ // method-specific parameters based on the guidelines above
+ }},
+ "validation": {{
+ // validation checks typically needed (e.g., check_balance: true for PSM)
+ }}
+ }}
+ """
+ # --- End Revised Prompt ---
+
+ # Call LLM with prompt - Assuming analyze_dataset_for_method provides the llm object
+ # For now, this internal call still uses the placeholder without passing llm
+ # This needs to be updated if analyze_dataset_for_method is intended to use a passed llm
+ response = call_llm_with_json_output(None, prompt) # Passing None for llm temporarily
+
+ # Process and validate response
+ # This step might involve ensuring the structure is correct,
+ # parameters are valid types, etc.
+ processed_response = process_llm_response(response, method)
+
+ return processed_response
+
+
+def llm_identify_temporal_and_unit_vars(
+ column_names: List[str],
+ column_dtypes: Dict[str, str],
+ dataset_description: str,
+ dataset_summary: str,
+ heuristic_time_candidates: Optional[List[str]] = None, # These are no longer used in the revised prompt
+ heuristic_id_candidates: Optional[List[str]] = None, # These are no longer used in the revised prompt
+ query: str = "No query provided.",
+ llm: Optional[BaseChatModel] = None
+) -> Dict[str, Optional[str]]:
+ """Uses LLM to identify the primary time:
+
+ Args:
+ column_names: List of all column names.
+ column_dtypes: Dictionary mapping column names to string representation of data types.
+ dataset_description: Textual description of the dataset.
+ dataset_summary: Summary of the dataset
+ heuristic_time_candidates: Optional list of columns identified as time vars by heuristics (currently unused by prompt).
+ heuristic_id_candidates: Optional list of columns identified as unit ID vars by heuristics (currently unused by prompt).
+ llm: The language model client instance.
+
+ Returns:
+ A dictionary with keys 'time_variable' and 'unit_variable',
+ whose values are the identified column names or None.
+ """
+ if not llm:
+ logger.warning("LLM client not provided for temporal/unit identification. Returning None.")
+ return {"time_variable": None, "unit_variable": None}
+
+ logger.info("Attempting LLM identification of time and unit variables...")
+
+ # Construct the prompt (revised based on user feedback in conversation)
+ prompt = f"""
+You are a data analysis expert tasked with determining whether a dataset supports a Difference-in-Differences (DiD) or Two-Way Fixed Effects (TWFE) design to answer the following query:
+{query}
+
+You are given the following information:
+
+Dataset Description:
+{dataset_description}
+
+Columns and Data Types:
+{column_dtypes}
+
+First, based on the above information, check if any columns represent information about the time/periods associated directly with intervention application. It could be either:
+1. A variable that represents **time periods associated with the intervention**. This must satisfy one of the following:
+ - A binary indicator showing pre/post-intervention status,
+ - A discrete or continuous variable that records **when units were observed**, which can be aligned with treatment application periods.
+
+ Do **not** select generic time-related variables that merely describe time as a feature, such as **'date of birth'**, **'year of graduation'**, 'week of sign-up', **'years of schooling'** unless they directly represent **observation times relevant to treatment**.
+
+2. A variable that represents the **unit of observation** (e.g., individual, region, school) — the entity over which we compare treated vs. untreated groups across time.
+
+Return ONLY a valid JSON object with this structure and no surrounding explanation:
+
+{{
+ "time_variable": "",
+ "unit_variable": ""
+}}
+"""
+
+
+ parsed_response = None
+ try:
+ llm_response_obj = llm.invoke(prompt)
+ response_content = ""
+ if hasattr(llm_response_obj, 'content'):
+ response_content = llm_response_obj.content
+ elif isinstance(llm_response_obj, str): # Some LLMs might return str directly
+ response_content = llm_response_obj
+ else:
+ logger.warning(f"LLM response object type not recognized for content extraction: {type(llm_response_obj)}")
+
+ if response_content:
+ # Attempt to strip markdown ```json ... ``` if present
+ if response_content.strip().startswith("```json"):
+ response_content = response_content.strip()[7:-3].strip()
+ elif response_content.strip().startswith("```"):
+ response_content = response_content.strip()[3:-3].strip()
+
+ parsed_response = json.loads(response_content)
+ else:
+ logger.warning("LLM invocation returned no content.")
+
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to decode JSON from LLM response for time/unit vars: {e}. Response content: '{response_content[:500]}...'") # Log snippet
+ except Exception as e:
+ logger.error(f"Error during LLM invocation or processing for time/unit vars: {e}", exc_info=True)
+
+ # Process the response
+ if parsed_response and isinstance(parsed_response, dict):
+ time_var = parsed_response.get("time_variable")
+ unit_var = parsed_response.get("unit_variable")
+
+ # Basic validation: ensure returned names are actual columns or None
+ if time_var is not None and time_var not in column_names:
+ logger.warning(f"LLM identified time variable '{time_var}' not found in columns. Setting to None.")
+ time_var = None
+ if unit_var is not None and unit_var not in column_names:
+ logger.warning(f"LLM identified unit variable '{unit_var}' not found in columns. Setting to None.")
+ unit_var = None
+
+ logger.info(f"LLM identified time='{time_var}', unit='{unit_var}'")
+ return {"time_variable": time_var, "unit_variable": unit_var}
+ else:
+ logger.warning("LLM call failed or returned invalid/unparsable JSON for time/unit identification.")
+ return {"time_variable": None, "unit_variable": None}
\ No newline at end of file
diff --git a/main/run_cais.py b/main/run_cais.py
new file mode 100644
index 0000000000000000000000000000000000000000..c027ab156f18ed56e1a91f5b7f1df16afdf0797f
--- /dev/null
+++ b/main/run_cais.py
@@ -0,0 +1,108 @@
+## This file runs the CAIS pipeline for a list of queries provided in a CSV file
+
+import os, re, io, time, json, logging, contextlib, textwrap
+from typing import Dict, Any
+import pandas as pd
+import argparse
+from auto_causal.agent import run_causal_analysis
+
+# Constants
+RATE_LIMIT_SECONDS = 2
+
+def run_cais(desc, question, df):
+ """
+ A wrapper function to run the causal analysis pipeline
+ Args:
+ desc (str): Description of the dataset
+ question (str): Natural language query associated with the dataset
+ df (str): Path to the csv file assocated with the dataset
+
+ Returns:
+ dict: Results from the CAIS pipeline
+ """
+
+ return run_causal_analysis(query=question, dataset_path=df, dataset_description=desc)
+
+def parse_args():
+
+ parser = argparse.ArgumentParser(description="Run batch causal analysis.")
+ parser.add_argument("-m", "--metadata_path", type=str, required=True,
+ help="Path to the CSV file with queries, descriptions, and file names etc")
+ parser.add_argument("-d", "--data_dir", type=str, required=True,
+ help="Path to the folder containing the data in CSV format")
+ parser.add_argument("-o", "--output_dir", type=str, required=True,
+ help="Path to the folder where the output is saved output")
+ parser.add_argument("-n", "--output_name", type=str, default="cais_results.json",)
+ parser.add_argument("-l", "--llm_name", type=str, required=True,
+ help="Name of the LLM used to be used")
+ return parser.parse_args()
+
+def main():
+
+ args = parse_args()
+ metadata_path = args.metadata_path
+ data_dir = args.data_dir
+ output_dir = args.output_dir
+ output_name = args.output_name
+ os.environ["LLM_MODEL"] = args.llm_name
+ print("[main] Starting batch processing…")
+
+ if not os.path.exists(metadata_path):
+ logging.error(f"Meta file not found: {metadata_path}")
+ return
+
+ meta_df = pd.read_csv(metadata_path)
+ print(f"[main] Loaded metadata CSV with {len(meta_df)} rows.")
+
+ results: Dict[int, Dict[str, Any]] = {}
+
+ for idx, row in meta_df.iterrows():
+ data_path = os.path.join(data_dir, str(row["data_files"]))
+ print(f"\n[main] Row {idx+1}/{len(meta_df)} → Dataset: {data_path}")
+
+ try:
+ res = run_cais(desc=row["data_description"], question=row["natural_language_query"],
+ df=data_path)
+
+ # Format result according to specified structure
+ formatted_result = {
+ "query": row["natural_language_query"],
+ "method": row["method"],
+ "answer": row["answer"],
+ "dataset_description": row["data_description"],
+ "dataset_path": data_path,
+ "keywords": row.get("keywords", "Causality, Average treatment effect"),
+ "final_result": {
+ "method": res['results']['results'].get("method_used"),
+ "causal_effect": res['results']['results'].get("effect_estimate"),
+ "standard_deviation": res['results']['results'].get("standard_error"),
+ "treatment_variable": res['results']['variables'].get("treatment_variable", None),
+ "outcome_variable": res['results']['variables'].get("outcome_variable", None),
+ "covariates": res['results']['variables'].get("covariates", []),
+ "instrument_variable": res['results']['variables'].get("instrument_variable", None),
+ "running_variable": res['results']['variables'].get("running_variable", None),
+ "temporal_variable": res['results']['variables'].get("time_variable", None),
+ "statistical_test_results": res.get("summary", ""),
+ "explanation_for_model_choice": res.get("explanation", ""),
+ "regression_equation": res.get("regression_equation", "")
+ }
+ }
+ results[idx] = formatted_result
+ print(f"[main] Formatted result for row {idx+1}:", formatted_result)
+
+ except Exception as e:
+ logging.error(f"[{idx+1}] Error: {e}")
+ results[idx] = {"answer": str(e)}
+
+ time.sleep(RATE_LIMIT_SECONDS)
+
+ os.makedirs(output_dir, exist_ok=True)
+ output_json = os.path.join(output_dir, output_name)
+ if not output_json.endswith(".json"):
+ output_json += ".json"
+ with open(output_json, "w") as f:
+ json.dump(results, f, indent=4)
+ print(f"[main] Done. Predictions saved to {output_json}")
+
+if __name__ == "__main__":
+ main()
diff --git a/reference_files/decision_tree.pdf b/reference_files/decision_tree.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..7f9a6b477a292253842415227a3293572e79f897
--- /dev/null
+++ b/reference_files/decision_tree.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8eee4685d7859f4ada9328143754124fd3d1e4b78727da8bfdeed1fce6a7464c
+size 71393
diff --git a/reproduce_results/create_context/create_context_did_canonical.sh b/reproduce_results/create_context/create_context_did_canonical.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d375850c62f08c4c123a86dbb0220e024d57cfd9
--- /dev/null
+++ b/reproduce_results/create_context/create_context_did_canonical.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for all the Canonical DiD synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="did_canonical"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_did_twfe.sh b/reproduce_results/create_context/create_context_did_twfe.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9d445d24a235dd0059fc4eb09236df31ad877d47
--- /dev/null
+++ b/reproduce_results/create_context/create_context_did_twfe.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for all the synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="did_twfe"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_front_door.sh b/reproduce_results/create_context/create_context_front_door.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2740cb54e2b0a14639e29919f10fb450b836d843
--- /dev/null
+++ b/reproduce_results/create_context/create_context_front_door.sh
@@ -0,0 +1,11 @@
+source reproduce_results/settings.sh
+METHOD="frontdoor"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py \
+ -mp ${METADATA_FOLDER} \
+ -d ${DATA_FOLDER} \
+ -o ${OUTPUT_FOLDER} \
+ -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_iv.sh b/reproduce_results/create_context/create_context_iv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f81d7ef6b63060ae7af9b3a34dc5f657391df1d
--- /dev/null
+++ b/reproduce_results/create_context/create_context_iv.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for all the IV synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="iv"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_iv_encouragement.sh b/reproduce_results/create_context/create_context_iv_encouragement.sh
new file mode 100644
index 0000000000000000000000000000000000000000..09d46d0cd6923b6edff381857403bf5ab0bb7ce2
--- /dev/null
+++ b/reproduce_results/create_context/create_context_iv_encouragement.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for all the IV encouragement synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="iv_encouragement"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_multi_rct.sh b/reproduce_results/create_context/create_context_multi_rct.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1ce5c3ba193faa48a323862f66e69858ce5ca161
--- /dev/null
+++ b/reproduce_results/create_context/create_context_multi_rct.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for multi-RCT synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="multi_rct"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_observational.sh b/reproduce_results/create_context/create_context_observational.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e20158d606fbd6be8bd6dc5a96d082eee8a3667f
--- /dev/null
+++ b/reproduce_results/create_context/create_context_observational.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for observational synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="observational"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_rct.sh b/reproduce_results/create_context/create_context_rct.sh
new file mode 100644
index 0000000000000000000000000000000000000000..217db16e86005d25a81e4ebbc95be8c1d89033e6
--- /dev/null
+++ b/reproduce_results/create_context/create_context_rct.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for RCT synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="rct"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context/create_context_rdd.sh b/reproduce_results/create_context/create_context_rdd.sh
new file mode 100644
index 0000000000000000000000000000000000000000..31e2e80532520755d0b513343674ad7f3cf9db11
--- /dev/null
+++ b/reproduce_results/create_context/create_context_rdd.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for RDD synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="rdd"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+OUTPUT_FOLDER="${BASE_FOLDER}/${METHOD}/description"
+
+python main/generate_context.py -mp ${METADATA_FOLDER} -d ${DATA_FOLDER} -o ${OUTPUT_FOLDER} -m ${METHOD}
diff --git a/reproduce_results/create_context_all.sh b/reproduce_results/create_context_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a2607bf1b5fa0c58911ab0aef929c67bae34d8a9
--- /dev/null
+++ b/reproduce_results/create_context_all.sh
@@ -0,0 +1,35 @@
+#!/bin/sh
+
+# create_descriptions.sh
+# This script generates the column labels, backstory, and causal query for all the synthetic datasets.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+echo "Generating context for RCT Data"
+bash reproduce_results/create_context/create_context_rct.sh
+
+echo "Generating context for Multi-RCT Data"
+bash reproduce_results/create_context/create_context_multi_rct.sh
+
+echo "Generating context for Front_Door Data"
+bash reproduce_results/create_context/create_context_front_door.sh
+
+echo "Generating context for Observational Data"
+bash reproduce_results/create_context/create_context_observational.sh
+
+echo "Generating context for Canonical DiD Data"
+bash reproduce_results/create_context/create_context_did_canonical.sh
+
+echo "Generating context for TWFE DiD Data"
+bash reproduce_results/create_context/create_context_did_twfe.sh
+
+echo "Generating context for IV Data"
+bash reproduce_results/create_context/create_context_iv.sh
+
+echo "Generating context for IV-Encouragement Data"
+bash reproduce_results/create_context/create_context_iv_encouragement.sh
+
+echo "Generating context for RDD Data"
+bash reproduce_results/create_context/create_context_rdd.sh
+
diff --git a/reproduce_results/create_data/create_did_canonical_data.sh b/reproduce_results/create_data/create_did_canonical_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4df7cfb25d587bb4aad2a1732676e0152c2fcce4
--- /dev/null
+++ b/reproduce_results/create_data/create_did_canonical_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="did_canonical"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY_OTHERS} -mc ${N_CONTINUOUS_DID_CANONICAL} -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_data/create_did_twfe_data.sh b/reproduce_results/create_data/create_did_twfe_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..27ba593dc63168f71093b6acf3c9b115a849e2ea
--- /dev/null
+++ b/reproduce_results/create_data/create_did_twfe_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="did_twfe"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY_OTHERS} -mc ${N_CONTINUOUS_DID_TWFE} -np ${MAX_PERIODS} -o ${DEFAULT_OBS_TWFE}
diff --git a/reproduce_results/create_data/create_front_door_data.sh b/reproduce_results/create_data/create_front_door_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5cb9e3c1cabbafe86a0b8f4a3189657fe936155e
--- /dev/null
+++ b/reproduce_results/create_data/create_front_door_data.sh
@@ -0,0 +1,13 @@
+source reproduce_results/settings.sh
+METHOD="frontdoor"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py \
+ -md ${METADATA_FOLDER} \
+ -d ${DATA_FOLDER} \
+ -m ${METHOD} \
+ -s ${DEFAULT_SIZE} \
+ -mb ${N_BINARY} \
+ -mc ${N_CONTINUOUS_FRONTDOOR} \
+ -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_data/create_iv_data.sh b/reproduce_results/create_data/create_iv_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9f0db20105682e2498851734c69f7a8644be869d
--- /dev/null
+++ b/reproduce_results/create_data/create_iv_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="iv"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY} -mc ${N_CONTINUOUS_IV} -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_data/create_iv_encouragement_data.sh b/reproduce_results/create_data/create_iv_encouragement_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d109788ebbf4a08cb57705f98131730c145caa1
--- /dev/null
+++ b/reproduce_results/create_data/create_iv_encouragement_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="iv_encouragement"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY_OTHERS} -mc ${N_CONTINUOUS_IV_ENCOURAGEMENT} -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_data/create_multi_rct_data.sh b/reproduce_results/create_data/create_multi_rct_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4f6db6fb8a67916560b5a9f7057c74557aa85945
--- /dev/null
+++ b/reproduce_results/create_data/create_multi_rct_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="multi_rct"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY} -mc ${N_CONTINUOUS_MULTI} -nt ${MAX_TREATMENTS} -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_data/create_observational_data.sh b/reproduce_results/create_data/create_observational_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a45d4343f8ffe8fb9716cf8a1512063a94ff4d29
--- /dev/null
+++ b/reproduce_results/create_data/create_observational_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="observational"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY} -mc ${N_CONTINUOUS} -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_data/create_rct_data.sh b/reproduce_results/create_data/create_rct_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e27960546d6f6aa518a634e80d2b12e197ea227c
--- /dev/null
+++ b/reproduce_results/create_data/create_rct_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="rct"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY} -mc ${N_CONTINUOUS} -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_data/create_rdd_data.sh b/reproduce_results/create_data/create_rdd_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7a8d30fd70e1aa7d8e72efab3ebaeaf806570f27
--- /dev/null
+++ b/reproduce_results/create_data/create_rdd_data.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+# create_descriptions.sh
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+source reproduce_results/settings.sh
+METHOD="rdd"
+METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata"
+DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+
+python main/generate_synthetic.py -md ${METADATA_FOLDER} -d ${DATA_FOLDER} -m ${METHOD} -s ${DEFAULT_SIZE} -mb ${N_BINARY_OTHERS} -mc ${N_CONTINUOUS_RDD} -c ${CUTOFF} -o ${DEFAULT_OBS}
diff --git a/reproduce_results/create_synthetic_data_all.sh b/reproduce_results/create_synthetic_data_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..50f686d72dced411fc76a2ccff186c0ff560b9ff
--- /dev/null
+++ b/reproduce_results/create_synthetic_data_all.sh
@@ -0,0 +1,36 @@
+#!/bin/sh
+
+# create_synthetic_data_all.sh
+# This scripts generates all the synthetic data
+#
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+
+
+echo "Generating RCT Data"
+bash reproduce_results/create_data/create_rct_data.sh
+
+echo "Generating Multi-RCT Data"
+bash reproduce_results/create_data/create_multi_rct_data.sh
+
+echo "Generating Front_Door Data"
+bash reproduce_results/create_data/create_front_door_data.sh
+
+echo "Generating Observational Data"
+bash reproduce_results/create_data/create_observational_data.sh
+
+echo "Generating Canonical DiD Data"
+bash reproduce_results/create_data/create_did_canonical_data.sh
+
+echo "Generating TWFE DiD Data"
+bash reproduce_results/create_data/create_did_twfe_data.sh
+
+echo "Generating IV Data"
+bash reproduce_results/create_data/create_iv_data.sh
+
+echo "Generating IV-Encouragement Data"
+bash reproduce_results/create_data/create_iv_encouragement_data.sh
+
+echo "Generating RDD Data"
+bash reproduce_results/create_data/create_rdd_data.sh
diff --git a/reproduce_results/finalize_synthetic_dataset.sh b/reproduce_results/finalize_synthetic_dataset.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ec1468ea26301a78182f2d8e57ee56c3356cf47c
--- /dev/null
+++ b/reproduce_results/finalize_synthetic_dataset.sh
@@ -0,0 +1,25 @@
+#!/bin/sh
+
+# finalize_synthetic_dataset.sh
+# This scripts puts together the results of generate_context.py and generate_synthetic.py. It renames the columns of the synthetic data files, and saves the resulting csv file. Additionally, it creates a summary csv file including the key information needed to run and evaluate the tests on the synthetic data.
+#
+# Created by Sawal Acharya on 5/14/25.
+#
+source reproduce_results/settings.sh
+
+for METHOD in rct multi_rct did_canonical did_twfe iv iv_encouragement rdd observational; do
+ METADATA_FOLDER="${BASE_FOLDER}/${METHOD}/metadata/${METHOD}.json"
+ INPUT_DATA_FOLDER="${BASE_FOLDER}/${METHOD}/data"
+ OUTPUT_PATH="${BASE_FOLDER}/data_info"
+ DESCRIPTION_PATH="${BASE_FOLDER}/${METHOD}/description/${METHOD}.json"
+ OUTPUT_DATA_FOLDER="${BASE_FOLDER}/synthetic_data"
+
+ python main/finalize_data.py \
+ -md "$METADATA_FOLDER" \
+ -id "$INPUT_DATA_FOLDER" \
+ -m "$METHOD" \
+ -o "$OUTPUT_PATH" \
+ -de "$DESCRIPTION_PATH" \
+ -od "$OUTPUT_DATA_FOLDER"
+done
+
diff --git a/reproduce_results/log_config.ini b/reproduce_results/log_config.ini
new file mode 100644
index 0000000000000000000000000000000000000000..65488e1b8f25da870e2de05b68b136af3b315d0c
--- /dev/null
+++ b/reproduce_results/log_config.ini
@@ -0,0 +1,125 @@
+[loggers]
+keys=root,observational_data_logger,did_data_logger,iv_data_logger,rct_data_logger,rdd_data_logger,multi_rct_data_logger, description_logger, runs_logger
+
+[handlers]
+keys=consoleHandler,obsHandler,didHandler,ivHandler,rctHandler,rddHandler,multiRCTHandler, descriptionHandler, runsHandler
+
+[formatters]
+keys=simpleFormatter,complexFormatter
+
+# ===== Loggers =====
+[logger_root]
+level=INFO
+handlers=consoleHandler
+
+[logger_observational_data_logger]
+level=DEBUG
+handlers=consoleHandler,obsHandler
+qualname=observational_data_logger
+propagate=0
+
+[logger_did_data_logger]
+level=DEBUG
+handlers=consoleHandler,didHandler
+qualname=did_data_logger
+propagate=0
+
+[logger_iv_data_logger]
+level=DEBUG
+handlers=consoleHandler,ivHandler
+qualname=iv_data_logger
+propagate=0
+
+[logger_rct_data_logger]
+level=DEBUG
+handlers=consoleHandler,rctHandler
+qualname=rct_data_logger
+propagate=0
+
+[logger_rdd_data_logger]
+level=DEBUG
+handlers=consoleHandler,rddHandler
+qualname=rdd_data_logger
+propagate=0
+
+[logger_multi_rct_data_logger]
+level=DEBUG
+handlers=consoleHandler,multiRCTHandler
+qualname=multi_rct_data_logger
+propagate=0
+
+[logger_description_logger]
+level=DEBUG
+handlers=consoleHandler,descriptionHandler
+qualname=description_logger
+propagate=0
+
+[logger_runs_logger]
+level=DEBUG
+handlers=consoleHandler,runsHandler
+qualname=runs_logger
+propagate=0
+
+# ===== Handlers =====
+[handler_consoleHandler]
+class=StreamHandler
+level=DEBUG
+formatter=simpleFormatter
+args=(sys.stdout,)
+
+[handler_obsHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/observational_data.log', 'midnight', 1, 5)
+
+[handler_didHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/did_data.log', 'midnight', 1, 5)
+
+[handler_ivHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/iv_data.log', 'midnight', 1, 5)
+
+[handler_rctHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/rct_data.log', 'midnight', 1, 5)
+
+[handler_rddHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/rdd_data.log', 'midnight', 1, 5)
+
+[handler_multiRCTHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/multi_rct_data.log', 'midnight', 1, 5)
+
+[handler_descriptionHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/description.log', 'midnight', 1, 5)
+
+[handler_runsHandler]
+class=logging.handlers.TimedRotatingFileHandler
+level=DEBUG
+formatter=complexFormatter
+args=('logs/runs.log', 'midnight', 1, 5)
+
+# ===== Formatters =====
+[formatter_simpleFormatter]
+format=%(asctime)s [%(levelname)s] - %(message)s
+datefmt=%Y-%m-%d %H:%M:%S
+
+[formatter_complexFormatter]
+format=%(asctime)s [%(levelname)s] [%(module)s (%(lineno)d)] - %(message)s
+datefmt=%Y-%m-%d %H:%M:%S
diff --git a/reproduce_results/readme.md b/reproduce_results/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..afbafd09e26d015db44f35040e0b88f47844a8dc
--- /dev/null
+++ b/reproduce_results/readme.md
@@ -0,0 +1,75 @@
+# Synthetic Data Generation Instructions
+
+## Step 1: Configure Parameters
+
+1. Go to the `reproduce_results` folder
+2. Open `settings.sh` and configure the hyperparameters
+
+## Step 2: Generate Synthetic Data
+### For a Single Method
+- Go to the home directory (Do not run from reproduce results)
+- To generate data for a specific method (e.g., RCT), run the following bash script:
+ ```bash
+ bash reproduce_results/create_data/create_rct_data.sh
+ ```
+
+
+**Output**
+
+***Note*** The results are described with respect to the default parameters in settings.sh. They may vary if the names are modified in settings.sh. `
+- Datasets will be saved to: `samples/synthetic/rct/data/`
+- A metadata file will be created at: `samples/synthetic/rct/metadata/rct.json`
+- The metadata file contains the following information about the synthetic data:
+ - True effects
+ - Number of observations
+ - Number of continuous covariates
+ - Number of binary covariates
+
+### For All Methods
+
+To generate synthetic data for all methods in one go:
+```bash
+bash reproduce_results/create_synthetic_data_all.sh
+```
+
+## Step 3: Generate Contextual Information
+
+### For a Single Method
+
+1. Go to the home directory
+2. To generate column labels, backstory, and query for datasets related to a specific method (e.g., RCT), run:
+ ```bash
+ bash reproduce_results/create_context/create_context_rct.sh
+ ```
+
+**Output:** GPT generated information will be saved to: `samples/synthetic/rct/description/rct.json`
+
+### For All Methods
+
+To generate contextual information for all methods at once:
+```bash
+bash reproduce_results/create_context_all.sh
+```
+
+## Step 4: Generate Summary Files
+- Go to the home directory
+- Then run the following command:
+ ```bash
+ bash reproduce_results/finalize_synthetic_dataset.sh
+ ```
+
+### Output Files
+
+The script generates two types of output files:
+
+1. **CAIS Input Files**
+ - They contain all information needed to run CAIS on the synthetic dataset. A separate file is created for each method (rct_info.csv for RCT). Files are saved to `reproduce_results/samples/synthetic/data_info`
+
+2. **Renamed Dataset Files**
+ - Original columns (X1, X2, ..., Y, D) are renamed with real-world variable names generated by GPT in the previous step. The files are saved in `reproduce_results/samples/synthetic/synthetic_data`
+
+## Sample Results
+
+Example outputs can be found in the `samples/synthetic` directory.
+
+
diff --git a/reproduce_results/run_agent.py b/reproduce_results/run_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0cd99c5ed49c7b3e89f4757710d4a716e3040be
--- /dev/null
+++ b/reproduce_results/run_agent.py
@@ -0,0 +1,90 @@
+import os, re, io, time, json, logging, contextlib, textwrap
+from typing import Dict, Any
+import pandas as pd
+import argparse
+from auto_causal.agent import run_causal_analysis
+
+# Constants
+RATE_LIMIT_SECONDS = 2
+
+def run_caia(desc, question, df):
+ return run_causal_analysis(query=question, dataset_path=df, dataset_description=desc)
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Run batch causal analysis.")
+ parser.add_argument("--csv_path", type=str, required=True, help="CSV file with queries, descriptions, and file names.")
+ parser.add_argument("--data_folder", type=str, required=True, help="Folder containing data CSVs.")
+ parser.add_argument("--data_category", type=str, required=True, help="Dataset category (e.g., real, qrdata, synthetic).")
+ parser.add_argument("--output_folder", type=str, required=True, help="Folder to save output.")
+ parser.add_argument("--llm_name", type=str, required=True, help="Name of the LLM used.")
+ return parser.parse_args()
+
+def main():
+
+ args = parse_args()
+ csv_meta = args.csv_meta
+ data_dir = args.data_dir
+ output_json = args.output_json
+ os.environ["LLM_MODEL"] = args.llm_name
+ print("[main] Starting batch processing…")
+
+ if not os.path.exists(csv_meta):
+ logging.error(f"Meta file not found: {csv_meta}")
+ return
+
+ meta_df = pd.read_csv(csv_meta)
+ print(f"[main] Loaded metadata CSV with {len(meta_df)} rows.")
+
+ results: Dict[int, Dict[str, Any]] = {}
+
+ for idx, row in meta_df.iterrows():
+ data_path = os.path.join(data_dir, str(row["data_files"]))
+ print(f"\n[main] Row {idx+1}/{len(meta_df)} → Dataset: {data_path}")
+
+ try:
+ res = run_caia(
+ desc=row["data_description"],
+ question=row["natural_language_query"],
+ df=data_path,
+ )
+
+ # Format result according to specified structure
+ formatted_result = {
+ "query": row["natural_language_query"],
+ "method": row["method"],
+ "answer": row["answer"],
+ "dataset_description": row["data_description"],
+ "dataset_path": data_path,
+ "keywords": row.get("keywords", "Causality, Average treatment effect"),
+ "final_result": {
+ "method": res['results']['results'].get("method_used"),
+ "causal_effect": res['results']['results'].get("effect_estimate"),
+ "standard_deviation": res['results']['results'].get("standard_error"),
+ "treatment_variable": res['results']['variables'].get("treatment_variable", None),
+ "outcome_variable": res['results']['variables'].get("outcome_variable", None),
+ "covariates": res['results']['variables'].get("covariates", []),
+ "instrument_variable": res['results']['variables'].get("instrument_variable", None),
+ "running_variable": res['results']['variables'].get("running_variable", None),
+ "temporal_variable": res['results']['variables'].get("time_variable", None),
+ "statistical_test_results": res.get("summary", ""),
+ "explanation_for_model_choice": res.get("explanation", ""),
+ "regression_equation": res.get("regression_equation", "")
+ }
+ }
+ results[idx] = formatted_result
+ print(type(res))
+ print(res)
+ print(f"[main] Formatted result for row {idx+1}:", formatted_result)
+ except Exception as e:
+ logging.error(f"[{idx+1}] Error: {e}")
+ results[idx] = {"answer": str(e)}
+
+ time.sleep(RATE_LIMIT_SECONDS)
+
+ os.makedirs(os.path.dirname(output_json), exist_ok=True)
+ with open(output_json, "w") as f:
+ json.dump(results, f, indent=2)
+ print(f"[main] Done. Predictions saved to {output_json}")
+
+if __name__ == "__main__":
+ main()
diff --git a/reproduce_results/settings.sh b/reproduce_results/settings.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b46a43be36b7dcfced6e95b69de11324605a2691
--- /dev/null
+++ b/reproduce_results/settings.sh
@@ -0,0 +1,45 @@
+## Here we set all the hyperparams for the synthetic data and define base directories
+
+export BASE_FOLDER="reproduce_results/samples/synthetic"
+
+## dataset sizes
+export RCT_SIZE=10
+export MULTI_RCT_SIZE=5
+export FRONTDOOR_SIZE=5
+export CANONICAL_DID_SIZE=5
+export TWFE_DID_SIZE=5
+export OBSERVATIONAL_SIZE=5
+export IV_SIZE=5
+export ENCOURAGEMENT_SIZE=5
+export RDD_SIZE=5
+export DEFAULT_SIZE=2
+
+## number of observations
+export MIN_OBS=300
+export MAX_OBS=500
+export DEFAULT_OBS=1000
+export DEFAULT_OBS_TWFE=100
+export MIN_OBS_TWFE=50
+export MAX_OBS_TWFE=100
+
+## maximum number of treatments for multi RCT
+export MAX_TREATMENTS=5
+
+## maximum number of periods for TWFE
+export MAX_PERIODS=10
+
+## maximum number of covariates
+export N_CONTINUOUS=5
+export N_CONTINUOUS_MULTI=2
+export N_CONTINUOUS_FRONTDOOR=3
+export N_CONTINUOUS_DID_CANONICAL=2
+export N_CONTINUOUS_DID_TWFE=2
+export N_CONTINUOUS_IV=4
+export N_CONTINUOUS_IV_ENCOURAGEMENT=3
+export N_CONTINUOUS_RDD=2
+
+export N_BINARY=4
+export N_BINARY_OTHERS=3
+
+## cutoff for RDD
+export CUTOFF=25
diff --git a/requirement.txt b/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..00a9de9e4a4670a7b9a121c740e62218e43bdefb
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1,51 @@
+pandas
+vertexai
+openai
+google-cloud-aiplatform
+docker
+jupyter_client
+ipykernel
+pyzmq
+requests
+python-dotenv
+statsmodels==0.14.4
+seaborn==0.13.2
+shap==0.43.0
+shapely==2.0.7
+pydantic==2.10.6
+pydantic_core==2.27.2
+pydot==3.0.4
+Pygments==2.19.1
+pyparsing==3.2.1
+pytest==8.3.5
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+pytz==2025.1
+PyYAML==6.0.2
+pyzmq==26.3.0
+regex==2024.11.6
+requests==2.32.3
+langchain-anthropic==0.3.10
+langchain-core==0.3.66
+langchain-openai==0.3.9
+langchain-text-splitters==0.3.8
+langchain-together==0.3.0
+langchain-google-genai
+langchain-deepseek
+langchainhub==0.1.21
+langsmith==0.4.1
+lightgbm==4.6.0
+llvmlite==0.44.0
+markdown-it-py==3.0.0
+ipython==8.34.0
+jedi==0.19.2
+Jinja2==3.1.6
+jiter==0.9.0
+joblib==1.4.2
+jsonpatch==1.33
+dowhy==0.12
+econml==0.15.1
+matplotlib==3.10.1
+langchain==0.3.26
+
+
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f2ae985f3dd7506b3d131538718ec63d851416
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,126 @@
+import os
+import subprocess
+import sys
+from setuptools import setup, find_packages
+from setuptools.command.install import install
+
+# --- Conda Environment and Dependency Installation ---
+
+def get_conda_path():
+ """Tries to find the path to the conda executable."""
+ # Common locations for conda executable
+ # 1. In the system's PATH
+ try:
+ conda_path = subprocess.check_output("which conda", shell=True).strip().decode('utf-8')
+ if conda_path: return conda_path
+ except subprocess.CalledProcessError:
+ pass # Not in PATH
+
+ # 2. Common installation directories
+ possible_paths = [
+ os.path.expanduser("~/anaconda3/bin/conda"),
+ os.path.expanduser("~/miniconda3/bin/conda"),
+ "/opt/anaconda3/bin/conda",
+ "/opt/miniconda3/bin/conda",
+ ]
+ for path in possible_paths:
+ if os.path.exists(path):
+ return path
+ return None
+
+def conda_env_exists(env_name):
+ """Check if a conda environment with the given name already exists."""
+ try:
+ envs = subprocess.check_output("conda env list", shell=True).decode('utf-8')
+ return any(line.startswith(env_name + ' ') for line in envs.splitlines())
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ return False
+
+class CondaInstallCommand(install):
+ """Custom command to create Conda env and install dependencies before package installation."""
+ description = "Create Conda environment and install dependencies, then install the package."
+
+ def run(self):
+ env_name = "causal-agent"
+ conda_path = get_conda_path()
+
+ if conda_path:
+ print(f"--- Found Conda at: {conda_path} ---")
+
+ if not conda_env_exists(env_name):
+ print(f"--- Creating Conda environment: {env_name} ---")
+ try:
+ # Create the environment with a specific Python version
+ subprocess.check_call(f"{conda_path} create -n {env_name} python=3.10 --yes", shell=True)
+ except subprocess.CalledProcessError as e:
+ print(f"Error creating conda environment: {e}", file=sys.stderr)
+ sys.exit(1)
+ else:
+ print(f"--- Conda environment '{env_name}' already exists. Skipping creation. ---")
+
+ print(f"--- Installing dependencies from requirement.txt into '{env_name}' ---")
+ try:
+ # Command to run pip install within the conda environment
+ pip_install_cmd = f"{conda_path} run -n {env_name} pip install -r requirement.txt"
+ subprocess.check_call(pip_install_cmd, shell=True)
+ print("--- Dependencies installed successfully. ---")
+ except subprocess.CalledProcessError as e:
+ print(f"Error installing dependencies: {e}", file=sys.stderr)
+ sys.exit(1)
+ else:
+ print("--- Conda not found. Skipping environment creation. ---")
+ print("--- Please ensure you have created an environment and installed dependencies manually. ---")
+
+ # Proceed with the standard installation
+ super().run()
+
+
+# --- Standard Setup Configuration ---
+
+# Read the contents of your requirements file
+try:
+ with open('requirement.txt') as f:
+ requirements = f.read().splitlines()
+except FileNotFoundError:
+ print("requirement.txt not found. Please ensure it is in the root directory.", file=sys.stderr)
+ requirements = []
+
+# Read README for long description
+try:
+ with open('README.md', encoding='utf-8') as f:
+ long_description = f.read()
+except FileNotFoundError:
+ long_description = 'A library for automated causal inference.'
+
+
+setup(
+ name='auto_causal',
+ version='0.1.0',
+ author='Vishal Verma',
+ author_email='vishal.verma@andrew.cmu.edu',
+ description='A library for automated causal inference',
+ long_description=long_description,
+ long_description_content_type='text/markdown',
+ url='https://github.com/causalNLP/causal-agent',
+ packages=find_packages(exclude=['tests', 'tests.*']),
+ install_requires=requirements,
+ classifiers=[
+ 'Development Status :: 3 - Alpha',
+ 'Intended Audience :: Developers',
+ 'Intended Audience :: Science/Research',
+ 'License :: OSI Approved :: MIT License',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
+ 'Operating System :: OS Independent',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Scientific/Engineering :: Information Analysis',
+ ],
+ python_requires='>=3.8',
+ include_package_data=True,
+ zip_safe=False,
+ cmdclass={
+ 'install': CondaInstallCommand,
+ }
+)
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fcc178a52c0de8197fb5bb416299b1959ee1002
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Main test suite
\ No newline at end of file
diff --git a/tests/auto_causal/__init__.py b/tests/auto_causal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f768aa517ba5f192db4bdb0f776caab6c40dfee0
--- /dev/null
+++ b/tests/auto_causal/__init__.py
@@ -0,0 +1 @@
+# Tests for auto_causal module
\ No newline at end of file
diff --git a/tests/auto_causal/components/__init__.py b/tests/auto_causal/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..334a4eb1fb9f0b9aea272b0b4c8e51c53a39316a
--- /dev/null
+++ b/tests/auto_causal/components/__init__.py
@@ -0,0 +1 @@
+# Tests for auto_causal components
\ No newline at end of file
diff --git a/tests/auto_causal/components/test_dataset_analyzer.py b/tests/auto_causal/components/test_dataset_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e344381d4ab9e674d882126de28a939e943ef8d
--- /dev/null
+++ b/tests/auto_causal/components/test_dataset_analyzer.py
@@ -0,0 +1,127 @@
+import unittest
+import os
+import pandas as pd
+import numpy as np
+
+# Import the function to test
+from auto_causal.components.dataset_analyzer import analyze_dataset
+
+# Helper to create dummy dataset files
+def create_dummy_csv_for_analysis(path, data_dict):
+ df = pd.DataFrame(data_dict)
+ df.to_csv(path, index=False)
+ return path
+
+class TestDatasetAnalyzer(unittest.TestCase):
+
+ def setUp(self):
+ '''Set up dummy data paths and create files.'''
+ self.test_files = []
+ # Basic data
+ self.basic_data_path = "analyzer_test_basic.csv"
+ create_dummy_csv_for_analysis(self.basic_data_path, {
+ 'treatment': [0, 1, 0, 1, 0, 1],
+ 'outcome': [10, 12, 11, 13, 9, 14],
+ 'cov1': ['A', 'B', 'A', 'B', 'A', 'B'],
+ 'numeric_cov': [1.1, 2.2, 1.3, 2.5, 1.0, 2.9]
+ })
+ self.test_files.append(self.basic_data_path)
+
+ # Panel data
+ self.panel_data_path = "analyzer_test_panel.csv"
+ create_dummy_csv_for_analysis(self.panel_data_path, {
+ 'unit': [1, 1, 2, 2],
+ 'year': [2000, 2001, 2000, 2001],
+ 'treat': [0, 1, 0, 0],
+ 'value': [5, 6, 7, 7.5]
+ })
+ self.test_files.append(self.panel_data_path)
+
+ # Data with potential instrument
+ self.iv_data_path = "analyzer_test_iv.csv"
+ create_dummy_csv_for_analysis(self.iv_data_path, {
+ 'Z_assigned': [0, 1, 0, 1],
+ 'D_actual': [0, 0, 0, 1],
+ 'Y_outcome': [10, 11, 12, 15]
+ })
+ self.test_files.append(self.iv_data_path)
+
+ # Data with discontinuity
+ self.rdd_data_path = "analyzer_test_rdd.csv"
+ create_dummy_csv_for_analysis(self.rdd_data_path, {
+ 'running_var': [-1.5, -0.5, 0.5, 1.5, -1.1, 0.8],
+ 'outcome_rdd': [4, 5, 10, 11, 4.5, 10.5]
+ })
+ self.test_files.append(self.rdd_data_path)
+
+ def tearDown(self):
+ '''Clean up dummy files.'''
+ for f in self.test_files:
+ if os.path.exists(f):
+ os.remove(f)
+
+ def test_analyze_basic_structure(self):
+ '''Test the basic structure and keys of the summarized output.'''
+ result = analyze_dataset(self.basic_data_path)
+
+ self.assertIsInstance(result, dict)
+ self.assertNotIn("error", result, f"Analysis failed: {result.get('error')}")
+
+ expected_keys = [
+ "dataset_info", "columns", "potential_treatments", "potential_outcomes",
+ "temporal_structure_detected", "panel_data_detected",
+ "potential_instruments_detected", "discontinuities_detected"
+ ]
+ # Check old detailed keys are NOT present
+ unexpected_keys = [
+ "column_types", "column_categories", "missing_values", "correlations",
+ "discontinuities", "variable_relationships", "column_type_summary",
+ "missing_value_summary", "discontinuity_summary", "relationship_summary"
+ ]
+
+ for key in expected_keys:
+ self.assertIn(key, result, f"Expected key '{key}' missing.")
+ for key in unexpected_keys:
+ self.assertNotIn(key, result, f"Unexpected key '{key}' present.")
+
+ # Check some types
+ self.assertIsInstance(result["columns"], list)
+ self.assertIsInstance(result["potential_treatments"], list)
+ self.assertIsInstance(result["potential_outcomes"], list)
+ self.assertIsInstance(result["temporal_structure_detected"], bool)
+ self.assertIsInstance(result["panel_data_detected"], bool)
+ self.assertIsInstance(result["potential_instruments_detected"], bool)
+ self.assertIsInstance(result["discontinuities_detected"], bool)
+
+ def test_analyze_panel_data(self):
+ '''Test detection of panel data structure.'''
+ result = analyze_dataset(self.panel_data_path)
+ self.assertTrue(result["temporal_structure_detected"])
+ self.assertTrue(result["panel_data_detected"])
+ self.assertIn('year', result["columns"]) # Check columns list is correct
+ self.assertIn('unit', result["columns"])
+
+ def test_analyze_iv_data(self):
+ '''Test detection of potential IV.'''
+ result = analyze_dataset(self.iv_data_path)
+ self.assertTrue(result["potential_instruments_detected"])
+
+ def test_analyze_rdd_data(self):
+ '''Test detection of potential discontinuity.'''
+ # Note: Our summarized output only has a boolean flag.
+ # The internal detection logic might be complex, but output is simple.
+ result = analyze_dataset(self.rdd_data_path)
+ # This depends heavily on the thresholds in detect_discontinuities
+ # It might be False if the dummy data doesn't trigger it reliably
+ # self.assertTrue(result["discontinuities_detected"])
+ # For now, just check the key exists
+ self.assertIn("discontinuities_detected", result)
+
+ def test_analyze_file_not_found(self):
+ '''Test handling of non-existent file.'''
+ result = analyze_dataset("non_existent_file.csv")
+ self.assertIn("error", result)
+ self.assertIn("not found", result["error"])
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/components/test_decision_tree.py b/tests/auto_causal/components/test_decision_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca4a9e6250e933c87e171c9667a2c0b5833ccbe7
--- /dev/null
+++ b/tests/auto_causal/components/test_decision_tree.py
@@ -0,0 +1,119 @@
+import pytest
+
+# Import the function to test and constants
+from auto_causal.components.decision_tree import (
+ select_method,
+ METHOD_ASSUMPTIONS, # Import assumptions map
+ REGRESSION_ADJUSTMENT, LINEAR_REGRESSION, LINEAR_REGRESSION_COV,
+ DIFF_IN_DIFF, REGRESSION_DISCONTINUITY, PROPENSITY_SCORE_MATCHING,
+ INSTRUMENTAL_VARIABLE
+)
+
+# --- Test Data Fixtures (Optional, but good practice) ---
+# Using simple dicts for now
+
+@pytest.fixture
+def base_variables():
+ return {
+ "treatment_variable": "T",
+ "outcome_variable": "Y",
+ "covariates": ["X1", "X2"],
+ "time_variable": None,
+ "group_variable": None,
+ "instrument_variable": None,
+ "running_variable": None,
+ "cutoff_value": None
+ }
+
+@pytest.fixture
+def base_dataset_analysis():
+ return {
+ "temporal_structure": False
+ # Add other keys as needed by specific tests, e.g., potential_instruments
+ }
+
+# --- Test Cases ---
+
+def test_no_covariates(base_dataset_analysis, base_variables):
+ """Test: No covariates provided -> Regression Adjustment"""
+ variables = base_variables.copy()
+ variables["covariates"] = []
+ result = select_method(base_dataset_analysis, variables, is_rct=False)
+ assert result["selected_method"] == REGRESSION_ADJUSTMENT
+ assert "no covariates" in result["method_justification"].lower()
+ assert result["method_assumptions"] == METHOD_ASSUMPTIONS[REGRESSION_ADJUSTMENT]
+
+def test_rct_no_covariates(base_dataset_analysis, base_variables):
+ """Test: RCT, no covariates -> Linear Regression"""
+ variables = base_variables.copy()
+ variables["covariates"] = [] # Explicitly empty
+ # Even though the first check catches empty covariates, test RCT path specifically
+ result = select_method(base_dataset_analysis, variables, is_rct=True)
+ # The initial check for no covariates takes precedence
+ assert result["selected_method"] == REGRESSION_ADJUSTMENT
+ # assert result["selected_method"] == LINEAR_REGRESSION # This won't be reached
+
+def test_rct_with_covariates(base_dataset_analysis, base_variables):
+ """Test: RCT with covariates -> Linear Regression with Covariates"""
+ variables = base_variables.copy()
+ result = select_method(base_dataset_analysis, variables, is_rct=True)
+ assert result["selected_method"] == LINEAR_REGRESSION_COV
+ assert "rct" in result["method_justification"].lower()
+ assert "covariates are provided" in result["method_justification"].lower()
+ assert result["method_assumptions"] == METHOD_ASSUMPTIONS[LINEAR_REGRESSION_COV]
+
+def test_observational_temporal(base_dataset_analysis, base_variables):
+ """Test: Observational, temporal structure -> DiD"""
+ variables = base_variables.copy()
+ variables["time_variable"] = "time"
+ variables["group_variable"] = "unit" # Often needed for DiD context
+ dataset_analysis = base_dataset_analysis.copy()
+ dataset_analysis["temporal_structure"] = True
+ result = select_method(dataset_analysis, variables, is_rct=False)
+ assert result["selected_method"] == DIFF_IN_DIFF
+ assert "temporal structure" in result["method_justification"].lower()
+ assert result["method_assumptions"] == METHOD_ASSUMPTIONS[DIFF_IN_DIFF]
+
+def test_observational_rdd(base_dataset_analysis, base_variables):
+ """Test: Observational, RDD vars present -> RDD"""
+ variables = base_variables.copy()
+ variables["running_variable"] = "score"
+ variables["cutoff_value"] = 50
+ result = select_method(base_dataset_analysis, variables, is_rct=False)
+ assert result["selected_method"] == REGRESSION_DISCONTINUITY
+ assert "running variable" in result["method_justification"].lower()
+ assert "cutoff" in result["method_justification"].lower()
+ assert result["method_assumptions"] == METHOD_ASSUMPTIONS[REGRESSION_DISCONTINUITY]
+
+def test_observational_iv(base_dataset_analysis, base_variables):
+ """Test: Observational, IV present -> IV"""
+ variables = base_variables.copy()
+ variables["instrument_variable"] = "Z"
+ result = select_method(base_dataset_analysis, variables, is_rct=False)
+ assert result["selected_method"] == INSTRUMENTAL_VARIABLE
+ assert "instrumental variable" in result["method_justification"].lower()
+ assert result["method_assumptions"] == METHOD_ASSUMPTIONS[INSTRUMENTAL_VARIABLE]
+
+def test_observational_confounders_default_psm(base_dataset_analysis, base_variables):
+ """Test: Observational, confounders, no other design -> PSM (default)"""
+ variables = base_variables.copy() # Has covariates by default
+ # Ensure no other conditions are met
+ dataset_analysis = base_dataset_analysis.copy()
+ dataset_analysis["temporal_structure"] = False
+ variables["time_variable"] = None
+ variables["running_variable"] = None
+ variables["instrument_variable"] = None
+
+ result = select_method(dataset_analysis, variables, is_rct=False)
+ assert result["selected_method"] == PROPENSITY_SCORE_MATCHING
+ assert "observed confounders" in result["method_justification"].lower()
+ assert "selected as the default method" in result["method_justification"].lower()
+ assert result["method_assumptions"] == METHOD_ASSUMPTIONS[PROPENSITY_SCORE_MATCHING]
+
+# Note: A specific test for the final fallback (Reg Adjustment for observational
+# with covariates but somehow no other method fits) might be hard to trigger
+# given the current logic defaults to PSM if covariates exist and IV/RDD/DiD don't apply.
+# The initial 'no covariates' test effectively covers the main Reg Adjustment path.
+
+if __name__ == '__main__':
+ pytest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/components/test_decision_tree_llm.py b/tests/auto_causal/components/test_decision_tree_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc54219d1c238b39a63cb9fd94ebf9b628cdfd4f
--- /dev/null
+++ b/tests/auto_causal/components/test_decision_tree_llm.py
@@ -0,0 +1,267 @@
+import unittest
+from unittest.mock import patch, MagicMock
+import json
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import HumanMessage, AIMessage
+
+from auto_causal.components.decision_tree_llm import DecisionTreeLLMEngine
+from auto_causal.components.decision_tree import (
+ METHOD_ASSUMPTIONS,
+ CORRELATION_ANALYSIS,
+ DIFF_IN_DIFF,
+ INSTRUMENTAL_VARIABLE,
+ LINEAR_REGRESSION,
+ PROPENSITY_SCORE_MATCHING,
+ REGRESSION_DISCONTINUITY,
+ DIFF_IN_MEANS
+)
+
+class TestDecisionTreeLLMEngine(unittest.TestCase):
+
+ def setUp(self):
+ self.engine = DecisionTreeLLMEngine(verbose=False)
+ self.mock_dataset_analysis = {
+ "temporal_structure": {"has_temporal_structure": True, "time_variables": ["year"]},
+ "potential_instruments": ["Z1"],
+ "running_variable_analysis": {"is_candidate": False}
+ }
+ self.mock_variables = {
+ "treatment_variable": "T",
+ "outcome_variable": "Y",
+ "covariates": ["X1", "X2"],
+ "time_variable": "year",
+ "instrument_variable": "Z1",
+ "treatment_variable_type": "binary"
+ }
+ self.mock_llm = MagicMock(spec=BaseChatModel)
+
+ def _create_mock_llm_response(self, response_dict):
+ ai_message = AIMessage(content=json.dumps(response_dict))
+ self.mock_llm.invoke = MagicMock(return_value=ai_message)
+
+ def _create_mock_llm_raw_response(self, raw_content_str):
+ ai_message = AIMessage(content=raw_content_str)
+ self.mock_llm.invoke = MagicMock(return_value=ai_message)
+
+ def test_select_method_rct_no_covariates_llm_selects_diff_in_means(self):
+ self._create_mock_llm_response({
+ "selected_method": DIFF_IN_MEANS,
+ "method_justification": "LLM: RCT with no covariates, DiM is appropriate.",
+ "alternative_methods": []
+ })
+ rct_variables = self.mock_variables.copy()
+ rct_variables["covariates"] = []
+ result = self.engine.select_method(
+ self.mock_dataset_analysis, rct_variables, is_rct=True, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], DIFF_IN_MEANS)
+ self.assertEqual(result["method_justification"], "LLM: RCT with no covariates, DiM is appropriate.")
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[DIFF_IN_MEANS])
+ self.mock_llm.invoke.assert_called_once()
+
+ def test_select_method_rct_with_covariates_llm_selects_linear_regression(self):
+ self._create_mock_llm_response({
+ "selected_method": LINEAR_REGRESSION,
+ "method_justification": "LLM: RCT with covariates, Linear Regression for precision.",
+ "alternative_methods": []
+ })
+ result = self.engine.select_method(
+ self.mock_dataset_analysis, self.mock_variables, is_rct=True, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], LINEAR_REGRESSION)
+ self.assertEqual(result["method_justification"], "LLM: RCT with covariates, Linear Regression for precision.")
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[LINEAR_REGRESSION])
+
+ def test_select_method_observational_temporal_llm_selects_did(self):
+ self._create_mock_llm_response({
+ "selected_method": DIFF_IN_DIFF,
+ "method_justification": "LLM: Observational with temporal data, DiD selected.",
+ "alternative_methods": [INSTRUMENTAL_VARIABLE]
+ })
+ result = self.engine.select_method(
+ self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
+ self.assertEqual(result["method_justification"], "LLM: Observational with temporal data, DiD selected.")
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[DIFF_IN_DIFF])
+ self.assertEqual(result["alternative_methods"], [INSTRUMENTAL_VARIABLE])
+
+ def test_select_method_observational_instrument_llm_selects_iv(self):
+ # Modify dataset analysis to not strongly suggest DiD
+ no_temporal_analysis = self.mock_dataset_analysis.copy()
+ no_temporal_analysis["temporal_structure"] = {"has_temporal_structure": False}
+
+ self._create_mock_llm_response({
+ "selected_method": INSTRUMENTAL_VARIABLE,
+ "method_justification": "LLM: Observational with instrument, IV selected.",
+ "alternative_methods": []
+ })
+ result = self.engine.select_method(
+ no_temporal_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], INSTRUMENTAL_VARIABLE)
+ self.assertEqual(result["method_justification"], "LLM: Observational with instrument, IV selected.")
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[INSTRUMENTAL_VARIABLE])
+
+ def test_select_method_observational_running_var_llm_selects_rdd(self):
+ rdd_analysis = self.mock_dataset_analysis.copy()
+ rdd_analysis["temporal_structure"] = {"has_temporal_structure": False} # Make DiD less likely
+ rdd_variables = self.mock_variables.copy()
+ rdd_variables["instrument_variable"] = None # Make IV less likely
+ rdd_variables["running_variable"] = "age"
+ rdd_variables["cutoff_value"] = 65
+
+ self._create_mock_llm_response({
+ "selected_method": REGRESSION_DISCONTINUITY,
+ "method_justification": "LLM: Running var and cutoff, RDD selected.",
+ "alternative_methods": []
+ })
+ result = self.engine.select_method(
+ rdd_analysis, rdd_variables, is_rct=False, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], REGRESSION_DISCONTINUITY)
+ self.assertEqual(result["method_justification"], "LLM: Running var and cutoff, RDD selected.")
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[REGRESSION_DISCONTINUITY])
+
+ def test_select_method_observational_covariates_llm_selects_psm(self):
+ psm_analysis = {"temporal_structure": {"has_temporal_structure": False}}
+ psm_variables = {
+ "treatment_variable": "T", "outcome_variable": "Y", "covariates": ["X1", "X2"],
+ "treatment_variable_type": "binary"
+ }
+ self._create_mock_llm_response({
+ "selected_method": PROPENSITY_SCORE_MATCHING,
+ "method_justification": "LLM: Observational with covariates, PSM.",
+ "alternative_methods": []
+ })
+ result = self.engine.select_method(
+ psm_analysis, psm_variables, is_rct=False, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], PROPENSITY_SCORE_MATCHING)
+ self.assertEqual(result["method_justification"], "LLM: Observational with covariates, PSM.")
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[PROPENSITY_SCORE_MATCHING])
+
+ def test_select_method_no_llm_provided_defaults_to_correlation(self):
+ result = self.engine.select_method(
+ self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=None
+ )
+ self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
+ self.assertIn("LLM client not provided", result["method_justification"])
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])
+
+ def test_select_method_llm_returns_malformed_json_defaults_to_correlation(self):
+ self._create_mock_llm_raw_response("This is not a valid JSON")
+ result = self.engine.select_method(
+ self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
+ self.assertIn("LLM response was not valid JSON", result["method_justification"])
+ self.assertIn("This is not a valid JSON", result["method_justification"])
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])
+
+ def test_select_method_llm_returns_unknown_method_defaults_to_correlation(self):
+ self._create_mock_llm_response({
+ "selected_method": "SUPER_NOVEL_METHOD_X",
+ "method_justification": "LLM thinks this is best.",
+ "alternative_methods": []
+ })
+ result = self.engine.select_method(
+ self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
+ self.assertIn("LLM output was problematic (selected: SUPER_NOVEL_METHOD_X)", result["method_justification"])
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])
+
+ def test_select_method_llm_call_raises_exception_defaults_to_correlation(self):
+ self.mock_llm.invoke = MagicMock(side_effect=Exception("LLM API Error"))
+ result = self.engine.select_method(
+ self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
+ )
+ self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
+ self.assertIn("An unexpected error occurred during LLM method selection.", result["method_justification"])
+ self.assertIn("LLM API Error", result["method_justification"])
+ self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])
+
+ def test_prompt_construction_content(self):
+ actual_prompt_generated = [] # List to capture the prompt
+
+ # Store the original method before patching
+ original_construct_prompt = self.engine._construct_prompt
+
+ def side_effect_for_construct_prompt(dataset_analysis, variables, is_rct):
+ # Call the original _construct_prompt method using the stored original
+ # self.engine is the instance, so it's implicitly passed if original_construct_prompt is bound
+ # However, to be explicit and safe, if we treat original_construct_prompt as potentially unbound:
+ prompt = original_construct_prompt(dataset_analysis, variables, is_rct)
+ actual_prompt_generated.append(prompt)
+ return prompt
+
+ with patch.object(self.engine, '_construct_prompt', side_effect=side_effect_for_construct_prompt) as mock_construct_prompt:
+ self._create_mock_llm_response({ # Need a mock response for the select_method to run
+ "selected_method": DIFF_IN_DIFF, "method_justification": "Test", "alternative_methods": []
+ })
+ self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)
+
+ mock_construct_prompt.assert_called_once_with(self.mock_dataset_analysis, self.mock_variables, False)
+
+ self.assertTrue(actual_prompt_generated, "Prompt was not generated or captured by side_effect")
+ prompt_string = actual_prompt_generated[0]
+
+ self.assertIn("You are an expert in causal inference.", prompt_string)
+ self.assertIn(json.dumps(self.mock_dataset_analysis, indent=2), prompt_string)
+ self.assertIn(json.dumps(self.mock_variables, indent=2), prompt_string)
+ self.assertIn("Is the data from a Randomized Controlled Trial (RCT)? No", prompt_string)
+ self.assertIn(f"- {DIFF_IN_DIFF}", prompt_string) # Check if method descriptions are there
+ self.assertIn(f"- {INSTRUMENTAL_VARIABLE}", prompt_string)
+ self.assertIn("Output your final decision as a JSON object", prompt_string)
+
+ def test_llm_response_with_triple_backticks_json(self):
+ raw_response = """
+Some conversational text before the JSON.
+```json
+{
+ "selected_method": "difference_in_differences",
+ "method_justification": "LLM reasoned and selected DiD.",
+ "alternative_methods": ["instrumental_variable"]
+}
+```
+And some text after.
+ """
+ self._create_mock_llm_raw_response(raw_response)
+ result = self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)
+ self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
+ self.assertEqual(result["method_justification"], "LLM reasoned and selected DiD.")
+
+ def test_llm_response_with_triple_backticks_only(self):
+ raw_response = """
+```
+{
+ "selected_method": "difference_in_differences",
+ "method_justification": "LLM reasoned and selected DiD with only triple backticks.",
+ "alternative_methods": ["instrumental_variable"]
+}
+```
+ """
+ self._create_mock_llm_raw_response(raw_response)
+ result = self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)
+ self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
+ self.assertEqual(result["method_justification"], "LLM reasoned and selected DiD with only triple backticks.")
+
+
+ def test_llm_response_plain_json(self):
+ raw_response = """
+{
+ "selected_method": "difference_in_differences",
+ "method_justification": "LLM reasoned and selected DiD plain JSON.",
+ "alternative_methods": ["instrumental_variable"]
+}
+ """
+ self._create_mock_llm_raw_response(raw_response)
+ result = self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)
+ self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
+ self.assertEqual(result["method_justification"], "LLM reasoned and selected DiD plain JSON.")
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/components/test_input_parser.py b/tests/auto_causal/components/test_input_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..78eb14cc5a4b405e96e93afedeeb6bf435ca4af5
--- /dev/null
+++ b/tests/auto_causal/components/test_input_parser.py
@@ -0,0 +1,100 @@
+import pytest
+import os
+import pandas as pd
+
+# Import the refactored parse_input function
+from auto_causal.components import input_parser
+
+# Check if OpenAI API key is available, skip if not
+api_key_present = bool(os.environ.get("OPENAI_API_KEY"))
+skip_if_no_key = pytest.mark.skipif(not api_key_present, reason="OPENAI_API_KEY environment variable not set")
+
+@skip_if_no_key
+def test_parse_input_with_real_llm():
+ """Tests the parse_input function invoking the actual LLM.
+
+ Note: This test requires the OPENAI_API_KEY environment variable to be set
+ and will make a real API call.
+ """
+ # --- Test Case 1: Effect query with dataset and constraint ---
+ query1 = "analyze the effect of 'Minimum Wage Increase' on 'Unemployment Rate' using data/county_data.csv where year > 2010"
+
+ # Provide some dummy dataset context
+ dataset_info1 = {
+ 'columns': ['County', 'Year', 'Minimum Wage Increase', 'Unemployment Rate', 'Population'],
+ 'column_types': {'County': 'object', 'Year': 'int64', 'Minimum Wage Increase': 'int64', 'Unemployment Rate': 'float64', 'Population': 'int64'},
+ 'sample_rows': [
+ {'County': 'A', 'Year': 2009, 'Minimum Wage Increase': 0, 'Unemployment Rate': 5.5, 'Population': 10000},
+ {'County': 'A', 'Year': 2011, 'Minimum Wage Increase': 1, 'Unemployment Rate': 6.0, 'Population': 10200}
+ ]
+ }
+
+ # Create a dummy data file for path checking (relative to workspace root)
+ dummy_file_path = "data/county_data.csv"
+ os.makedirs(os.path.dirname(dummy_file_path), exist_ok=True)
+ with open(dummy_file_path, 'w') as f:
+ f.write("County,Year,Minimum Wage Increase,Unemployment Rate,Population\n")
+ f.write("A,2009,0,5.5,10000\n")
+ f.write("A,2011,1,6.0,10200\n")
+
+ result1 = input_parser.parse_input(query=query1, dataset_info=dataset_info1)
+
+ # Clean up dummy file
+ if os.path.exists(dummy_file_path):
+ os.remove(dummy_file_path)
+ # Try removing the directory if empty
+ try:
+ os.rmdir(os.path.dirname(dummy_file_path))
+ except OSError:
+ pass # Ignore if directory is not empty or other error
+
+ # Assertions for Test Case 1
+ assert result1 is not None
+ assert result1['original_query'] == query1
+ assert result1['query_type'] == "EFFECT_ESTIMATION"
+ assert result1['dataset_path'] == dummy_file_path # Check if path extraction worked
+
+ # Check variables (allowing for some LLM interpretation flexibility)
+ assert 'treatment' in result1['extracted_variables']
+ assert 'outcome' in result1['extracted_variables']
+ # Check if the core variable names are present in the extracted lists
+ assert any('Minimum Wage Increase' in t for t in result1['extracted_variables'].get('treatment', []))
+ assert any('Unemployment Rate' in o for o in result1['extracted_variables'].get('outcome', []))
+
+ # Check constraints
+ assert isinstance(result1['constraints'], list)
+ # Check if a constraint related to 'year > 2010' was captured (LLM might phrase it differently)
+ assert any('year' in c.lower() and '2010' in c for c in result1.get('constraints', [])), "Constraint 'year > 2010' not found or not parsed correctly."
+
+ # --- Test Case 2: Counterfactual without dataset path ---
+ query2 = "What would sales have been if we hadn't run the 'Summer Sale' campaign?"
+ dataset_info2 = {
+ 'columns': ['Date', 'Sales', 'Summer Sale', 'Competitor Activity'],
+ 'column_types': { 'Date': 'datetime64[ns]', 'Sales': 'float64', 'Summer Sale': 'int64', 'Competitor Activity': 'float64'}
+ }
+
+ result2 = input_parser.parse_input(query=query2, dataset_info=dataset_info2)
+
+ # Assertions for Test Case 2
+ assert result2 is not None
+ assert result2['query_type'] == "COUNTERFACTUAL"
+ assert result2['dataset_path'] is None # No path mentioned or inferrable here
+ assert any('Summer Sale' in t for t in result2['extracted_variables'].get('treatment', []))
+ assert any('Sales' in o for o in result2['extracted_variables'].get('outcome', []))
+ assert not result2['constraints'] # No constraints expected
+
+ # --- Test Case 3: Simple query, LLM might fail validation? ---
+ # This tests if the retry/failure mechanism logs warnings but doesn't crash
+ # (Assuming LLM might struggle to extract treatment/outcome from just "sales vs ads")
+ query3 = "sales vs ads"
+ dataset_info3 = {
+ 'columns': ['sales', 'ads'],
+ 'column_types': {'sales': 'float', 'ads': 'float'}
+ }
+ result3 = input_parser.parse_input(query=query3, dataset_info=dataset_info3)
+ assert result3 is not None
+ # LLM might fail extraction; check default/fallback values
+ # Query type might default to OTHER or CORRELATION/DESCRIPTIVE
+ # Variables might be empty or partially filled
+ # This mainly checks that the function completes without error even if LLM fails
+ print(f"Result for ambiguous query: {result3}")
\ No newline at end of file
diff --git a/tests/auto_causal/components/test_state_manager.py b/tests/auto_causal/components/test_state_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..622be718e484980b9b9212dc41b38813ca898ecf
--- /dev/null
+++ b/tests/auto_causal/components/test_state_manager.py
@@ -0,0 +1,26 @@
+import unittest
+from auto_causal.components.state_manager import create_workflow_state_update
+
+class TestStateManagerUtils(unittest.TestCase):
+
+ def test_create_workflow_state_update(self):
+ '''Test the workflow state update utility function.'''
+ current = "step_A"
+ flag = "step_A_done"
+ next_tool = "tool_B"
+ reason = "Reason for B"
+
+ expected_output = {
+ "workflow_state": {
+ "current_step": current,
+ flag: True,
+ "next_tool": next_tool,
+ "next_step_reason": reason
+ }
+ }
+
+ actual_output = create_workflow_state_update(current, flag, next_tool, reason)
+ self.assertDictEqual(actual_output, expected_output)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/methods/__init__.py b/tests/auto_causal/methods/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c42781d158f21a9e0f150436adfee9d5c44ff91e
--- /dev/null
+++ b/tests/auto_causal/methods/__init__.py
@@ -0,0 +1 @@
+# Tests for auto_causal methods
\ No newline at end of file
diff --git a/tests/auto_causal/methods/backdoor_adjustment/__init__.py b/tests/auto_causal/methods/backdoor_adjustment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_diagnostics.py b/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed9204493a6dc2969a0152ddca40d5d58a12d22d
--- /dev/null
+++ b/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_diagnostics.py
@@ -0,0 +1,90 @@
+import pytest
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from auto_causal.methods.backdoor_adjustment.diagnostics import run_backdoor_diagnostics
+
+# --- Fixture for confounded data ---
+@pytest.fixture
+def sample_confounded_data():
+ """Generates synthetic data with confounding for testing diagnostics."""
+ np.random.seed(789)
+ n_samples = 200
+ W1 = np.random.normal(0, 1, n_samples)
+ W2 = np.random.normal(2, 1, n_samples)
+ treatment_prob = 1 / (1 + np.exp(-(0.5 * W1 - 0.5)))
+ treatment = np.random.binomial(1, treatment_prob, n_samples)
+ true_effect = 3.0
+ error = np.random.normal(0, 1, n_samples)
+ outcome = 10 + true_effect * treatment + 2.0 * W1 - 1.0 * W2 + error
+
+ df = pd.DataFrame({
+ 'outcome': outcome,
+ 'treatment': treatment,
+ 'confounder1': W1,
+ 'confounder2': W2
+ })
+ return df
+
+def test_run_backdoor_diagnostics_success(sample_confounded_data):
+ """Tests the diagnostics function with real results."""
+ # Run a regression to get a real results object
+ df_analysis = sample_confounded_data.dropna()
+ treatment = 'treatment'
+ covariates = ['confounder1', 'confounder2']
+ X = df_analysis[[treatment] + covariates]
+ X = sm.add_constant(X)
+ y = df_analysis['outcome']
+ model = sm.OLS(y, X)
+ results = model.fit()
+
+ # Run diagnostics
+ diagnostics = run_backdoor_diagnostics(results, X)
+
+ assert isinstance(diagnostics, dict)
+ assert diagnostics["status"] == "Success"
+ assert "details" in diagnostics
+ details = diagnostics["details"]
+
+ # Check for key OLS diagnostic metrics
+ assert "r_squared" in details
+ assert "adj_r_squared" in details
+ assert "f_statistic" in details
+ assert "f_p_value" in details
+ assert "n_observations" in details
+ assert "degrees_of_freedom_resid" in details
+ assert "durbin_watson" in details
+
+ # Check normality test results
+ assert "residuals_normality_jb_stat" in details
+ assert "residuals_normality_jb_p_value" in details
+ assert "residuals_skewness" in details
+ assert "residuals_kurtosis" in details
+ assert "residuals_normality_status" in details
+
+ # Check homoscedasticity test results
+ assert "homoscedasticity_bp_lm_stat" in details
+ assert "homoscedasticity_bp_lm_p_value" in details
+ assert "homoscedasticity_status" in details
+
+ # Check multicollinearity proxy
+ assert "model_condition_number" in details
+ assert "multicollinearity_status" in details
+
+ # Check placeholder status
+ assert "linearity_check" in details
+ assert details["linearity_check"] == "Requires visual inspection (e.g., residual vs fitted plot)"
+
+ # Check types (basic)
+ assert isinstance(details["r_squared"], float)
+ assert isinstance(details["f_p_value"], float)
+ assert isinstance(details["n_observations"], int)
+
+def test_run_backdoor_diagnostics_failure():
+ """Test diagnostic failure mode (e.g., passing wrong object)."""
+ # Pass a non-results object
+ # Need a dummy X with matching columns expected by the function if it gets that far
+ dummy_X = pd.DataFrame({'const': [1], 'treatment': [0], 'cov1': [1]})
+ diagnostics = run_backdoor_diagnostics("not a results object", dummy_X)
+ assert diagnostics["status"] == "Failed"
+ assert "error" in diagnostics
diff --git a/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_estimator.py b/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cc721b3e56b1f52e731fa62d3c5caee39fca692
--- /dev/null
+++ b/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_estimator.py
@@ -0,0 +1,98 @@
+import pytest
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from statsmodels.iolib.summary import Summary
+from unittest.mock import patch, MagicMock
+
+from auto_causal.methods.backdoor_adjustment.estimator import estimate_effect
+
+# --- Fixtures ---
+
+@pytest.fixture
+def sample_confounded_data():
+ """Generates synthetic data with confounding."""
+ np.random.seed(789)
+ n_samples = 200
+ # Confounder affects both treatment and outcome
+ W1 = np.random.normal(0, 1, n_samples)
+ W2 = np.random.normal(2, 1, n_samples)
+ # Treatment depends on confounder W1
+ treatment_prob = 1 / (1 + np.exp(-(0.5 * W1 - 0.5)))
+ treatment = np.random.binomial(1, treatment_prob, n_samples)
+ # Outcome depends on treatment and confounders W1, W2
+ true_effect = 3.0
+ error = np.random.normal(0, 1, n_samples)
+ outcome = 10 + true_effect * treatment + 2.0 * W1 - 1.0 * W2 + error
+
+ df = pd.DataFrame({
+ 'outcome': outcome,
+ 'treatment': treatment,
+ 'confounder1': W1,
+ 'confounder2': W2,
+ 'irrelevant_var': np.random.rand(n_samples) # Not a confounder
+ })
+ return df
+
+# --- Test Cases ---
+
+@patch('auto_causal.methods.backdoor_adjustment.estimator.run_backdoor_diagnostics')
+@patch('auto_causal.methods.backdoor_adjustment.estimator.interpret_backdoor_results')
+def test_estimate_effect_basic(mock_interpret, mock_diagnostics, sample_confounded_data):
+ """Test basic execution with a valid adjustment set."""
+ mock_diagnostics.return_value = {"status": "Success", "details": {}}
+ mock_interpret.return_value = "LLM Interpretation"
+ adjustment_set = ['confounder1', 'confounder2']
+
+ results = estimate_effect(sample_confounded_data, 'treatment', 'outcome', adjustment_set)
+
+ assert 'effect_estimate' in results
+ assert 'p_value' in results
+ assert 'confidence_interval' in results
+ assert 'standard_error' in results
+ assert 'formula' in results
+ assert 'model_summary' in results
+ assert 'diagnostics' in results
+ assert 'interpretation' in results
+ assert 'method_used' in results
+
+ # Check if effect estimate is reasonably close to the true effect (3.0)
+ assert abs(results['effect_estimate'] - 3.0) < 1.0
+ assert "outcome ~ treatment + confounder1 + confounder2 + const" in results['formula']
+ assert results['method_used'] == 'Backdoor Adjustment (OLS)'
+ assert isinstance(results['model_summary'], Summary)
+
+ mock_diagnostics.assert_called_once()
+ mock_interpret.assert_called_once()
+
+def test_estimate_effect_missing_treatment(sample_confounded_data):
+ """Test error handling for missing treatment column."""
+ with pytest.raises(ValueError, match="Missing required columns.*:.*missing_treat"):
+ estimate_effect(sample_confounded_data, 'missing_treat', 'outcome', ['confounder1'])
+
+def test_estimate_effect_missing_outcome(sample_confounded_data):
+ """Test error handling for missing outcome column."""
+ with pytest.raises(ValueError, match="Missing required columns.*:.*missing_outcome"):
+ estimate_effect(sample_confounded_data, 'treatment', 'missing_outcome', ['confounder1'])
+
+def test_estimate_effect_missing_covariate(sample_confounded_data):
+ """Test error handling for missing covariate column in adjustment set."""
+ with pytest.raises(ValueError, match="Missing required columns.*:.*missing_cov"):
+ estimate_effect(sample_confounded_data, 'treatment', 'outcome', ['confounder1', 'missing_cov'])
+
+def test_estimate_effect_empty_covariates(sample_confounded_data):
+ """Test error handling when covariate list is empty."""
+ with pytest.raises(ValueError, match="Backdoor Adjustment requires a non-empty list of covariates"):
+ estimate_effect(sample_confounded_data, 'treatment', 'outcome', [])
+ with pytest.raises(ValueError, match="Backdoor Adjustment requires a non-empty list of covariates"):
+ estimate_effect(sample_confounded_data, 'treatment', 'outcome', None) # type: ignore
+
+def test_estimate_effect_nan_data():
+ """Test handling of data with NaNs resulting in empty analysis set."""
+ df_nan = pd.DataFrame({
+ 'outcome': [np.nan, 2, 3, 4], # Ensure all rows have NaN in required cols
+ 'treatment': [0, np.nan, 1, 1],
+ 'covariate1': [5, 6, np.nan, np.nan]
+ })
+ with pytest.raises(ValueError, match="No data remaining after dropping NaNs"):
+ estimate_effect(df_nan, 'treatment', 'outcome', ['covariate1'])
diff --git a/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_llm_assist.py b/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..679b711ebf3b0432f84344f6d95cf4bc9a00a272
--- /dev/null
+++ b/tests/auto_causal/methods/backdoor_adjustment/test_backdoor_llm_assist.py
@@ -0,0 +1,146 @@
+import pytest
+from unittest.mock import MagicMock, patch
+import pandas as pd
+from auto_causal.methods.backdoor_adjustment.llm_assist import (
+ identify_backdoor_set,
+ interpret_backdoor_results
+)
+
+# Patch target for the helper function where it's used
+LLM_ASSIST_MODULE = "auto_causal.methods.backdoor_adjustment.llm_assist"
+
+@pytest.fixture
+def mock_llm():
+ """Fixture for a basic mock LLM object."""
+ return MagicMock()
+
+@pytest.fixture
+def mock_ols_results():
+ """Creates a mock statsmodels OLS results object."""
+ results = MagicMock()
+ treatment_var = 'treatment'
+ covs = ['confounder1', 'confounder2']
+ results.params = pd.Series({'const': 10.0, treatment_var: 3.1, **{c: i*0.5 for i, c in enumerate(covs)}})
+ results.pvalues = pd.Series({'const': 0.1, treatment_var: 0.02, **{c: 0.1 for c in covs}})
+ conf_int_df = pd.DataFrame([[2.1, 4.1]], index=[treatment_var], columns=[0, 1])
+ results.conf_int.return_value = conf_int_df
+ results.rsquared = 0.6
+ results.rsquared_adj = 0.55
+ return results
+
+@pytest.fixture
+def mock_backdoor_diagnostics():
+ """Creates a mock diagnostics dictionary."""
+ return {
+ "status": "Success",
+ "details": {
+ 'r_squared': 0.6,
+ 'residuals_normality_status': "Non-Normal",
+ 'homoscedasticity_status': "Homoscedastic",
+ 'multicollinearity_status': "Low"
+ }
+ }
+
+# --- Tests for identify_backdoor_set ---
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_identify_backdoor_set_llm_success(mock_call_json, mock_llm):
+ """Test successful identification of backdoor set via LLM."""
+ mock_call_json.return_value = {"suggested_backdoor_set": ["W1", "W2"]}
+ df_cols = ['Y', 'T', 'W1', 'W2', 'X']
+ treatment = 'T'
+ outcome = 'Y'
+ query = "Effect of T on Y?"
+
+ result = identify_backdoor_set(df_cols, treatment, outcome, query=query, llm=mock_llm)
+
+ assert result == ["W1", "W2"]
+ mock_call_json.assert_called_once()
+ # Check prompt content
+ call_args, _ = mock_call_json.call_args
+ prompt = call_args[1]
+ assert "'T' on 'Y'" in prompt
+ assert "Available variables" in prompt
+ assert "'W1', 'W2', 'X'" in prompt # Check potential confounders list
+ assert "Return ONLY a valid JSON" in prompt
+ assert "suggested_backdoor_set" in prompt
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_identify_backdoor_set_llm_combines_existing(mock_call_json, mock_llm):
+ """Test that LLM suggestions are combined with existing covariates."""
+ mock_call_json.return_value = {"suggested_backdoor_set": ["W2", "W3"]}
+ df_cols = ['Y', 'T', 'W1', 'W2', 'W3']
+ existing = ["W1", "W2"] # User provided W1, W2
+
+ result = identify_backdoor_set(df_cols, 'T', 'Y', existing_covariates=existing, llm=mock_llm)
+
+ # Order should be existing + suggested, with duplicates removed
+ assert result == ["W1", "W2", "W3"]
+ mock_call_json.assert_called_once()
+
+def test_identify_backdoor_set_no_llm():
+ """Test behavior when no LLM is provided."""
+ df_cols = ['Y', 'T', 'W1']
+ existing = ["W1"]
+ result_with = identify_backdoor_set(df_cols, 'T', 'Y', existing_covariates=existing, llm=None)
+ assert result_with == ["W1"] # Returns only existing
+
+ result_without = identify_backdoor_set(df_cols, 'T', 'Y', llm=None)
+ assert result_without == [] # Returns empty list
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_identify_backdoor_set_llm_fail(mock_call_json, mock_llm):
+ """Test behavior when LLM call fails or returns bad format."""
+ mock_call_json.return_value = None
+ df_cols = ['Y', 'T', 'W1']
+ existing = ["W1"]
+ result = identify_backdoor_set(df_cols, 'T', 'Y', existing_covariates=existing, llm=mock_llm)
+ assert result == ["W1"] # Should return only existing on failure
+ mock_call_json.assert_called_once()
+
+# --- Tests for interpret_backdoor_results ---
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_interpret_backdoor_results_implementation(mock_call_json, mock_llm, mock_ols_results, mock_backdoor_diagnostics):
+ """Test the implemented backdoor results interpretation function."""
+ treatment_var = 'treatment'
+ covariates = ['confounder1', 'confounder2']
+ mock_interpretation_text = "After adjusting, treatment has a significant positive effect, assuming confounders are controlled."
+ mock_call_json.return_value = {"interpretation": mock_interpretation_text}
+
+ # --- Test with LLM ---
+ interp_with_llm = interpret_backdoor_results(
+ mock_ols_results,
+ mock_backdoor_diagnostics,
+ treatment_var,
+ covariates,
+ llm=mock_llm
+ )
+
+ assert interp_with_llm == mock_interpretation_text
+ mock_call_json.assert_called_once()
+ # Basic check on the prompt structure passed to the helper
+ call_args, _ = mock_call_json.call_args
+ prompt = call_args[1]
+ assert "Backdoor Adjustment (Regression) results" in prompt
+ assert "Results Summary:" in prompt
+ assert "Diagnostics Summary" in prompt
+ assert "Treatment Effect Estimate': '3.100" in prompt
+ assert "Adjustment Set (Covariates Used)': ['confounder1', 'confounder2']" in prompt
+ assert "relies heavily on the assumption" in prompt # Check assumption emphasis
+ assert "confounder1', 'confounder2" in prompt
+ assert "Residuals Normality Status': 'Non-Normal'" in prompt
+ assert "Return ONLY a valid JSON" in prompt
+
+ # --- Test LLM Call Failure ---
+ mock_call_json.reset_mock()
+ mock_call_json.return_value = None
+ interp_fail = interpret_backdoor_results(mock_ols_results, mock_backdoor_diagnostics, treatment_var, covariates, llm=mock_llm)
+ assert "LLM interpretation not available for Backdoor Adjustment" in interp_fail
+ mock_call_json.assert_called_once()
+
+ # --- Test without LLM ---
+ mock_call_json.reset_mock()
+ interp_no_llm = interpret_backdoor_results(mock_ols_results, mock_backdoor_diagnostics, treatment_var, covariates, llm=None)
+ assert "LLM interpretation not available for Backdoor Adjustment" in interp_no_llm
+ mock_call_json.assert_not_called()
diff --git a/tests/auto_causal/methods/diff_in_means/__init__.py b/tests/auto_causal/methods/diff_in_means/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/auto_causal/methods/diff_in_means/test_dim_diagnostics.py b/tests/auto_causal/methods/diff_in_means/test_dim_diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a7180c15c61e4aa7f5033900a6e6d575501ee7f
--- /dev/null
+++ b/tests/auto_causal/methods/diff_in_means/test_dim_diagnostics.py
@@ -0,0 +1,73 @@
+import pytest
+import pandas as pd
+import numpy as np
+from auto_causal.methods.diff_in_means.diagnostics import run_dim_diagnostics
+
+# --- Fixture ---
+@pytest.fixture
+def sample_stats_data():
+ """Generates data for testing diagnostic stats."""
+ df = pd.DataFrame({
+ 'outcome': [10, 12, 11, 20, 25, 22, 100], # Group 0: 10, 11, 12; Group 1: 20, 22, 25; Group 2: 100
+ 'treatment': [ 0, 1, 0, 1, 1, 1, 2],
+ 'another': [1,1,1,1,1,1,1]
+ })
+ return df
+
+# --- Test Cases ---
+
+def test_run_dim_diagnostics_success(sample_stats_data):
+ """Test successful calculation of group stats."""
+ df = sample_stats_data[sample_stats_data['treatment'].isin([0, 1])] # Filter to binary
+ results = run_dim_diagnostics(df, 'treatment', 'outcome')
+
+ assert results['status'] == "Success"
+ assert "details" in results
+ details = results['details']
+
+ assert "control_group_stats" in details
+ assert "treated_group_stats" in details
+
+ control = details['control_group_stats']
+ treated = details['treated_group_stats']
+
+ assert control['count'] == 2
+ assert treated['count'] == 4
+ assert np.isclose(control['mean'], 10.5)
+ assert np.isclose(treated['mean'], 19.75)
+ assert np.isclose(control['std'], np.std([10, 11], ddof=1)) # Pandas uses ddof=1
+ assert np.isclose(treated['std'], np.std([12, 20, 25, 22], ddof=1))
+ assert "variance_homogeneity_status" in details
+ assert details['variance_homogeneity_status'] == "Potentially Unequal (ratio > 4 or < 0.25)"
+
+def test_run_dim_diagnostics_empty_group(sample_stats_data):
+ """Test warning when one group is empty."""
+ df_one_group = sample_stats_data[sample_stats_data['treatment'] == 1]
+ results = run_dim_diagnostics(df_one_group, 'treatment', 'outcome')
+
+ assert results['status'] == "Warning - Empty Group(s)"
+ assert results['details']['control_group_stats']['count'] == 0
+ assert results['details']['treated_group_stats']['count'] == 4
+
+def test_run_dim_diagnostics_key_error(sample_stats_data):
+ """Test that function runs successfully even with non-0/1 levels, but only calculates for 0/1."""
+ # Use data with treatment = 2 present
+ results = run_dim_diagnostics(sample_stats_data, 'treatment', 'outcome')
+
+ # Expect success because groups 0 and 1 exist and stats can be calculated for them
+ assert results['status'] == "Success"
+ details = results['details']
+ assert 'control_group_stats' in details
+ assert 'treated_group_stats' in details
+ assert details['control_group_stats']['count'] == 2 # Should still find group 0
+ assert details['treated_group_stats']['count'] == 4 # Should still find group 1
+
+def test_run_dim_diagnostics_zero_variance(sample_stats_data):
+ """Test handling when one group has zero variance."""
+ df_zero_var = pd.DataFrame({
+ 'outcome': [10, 10, 20, 25, 22],
+ 'treatment': [ 0, 0, 1, 1, 1],
+ })
+ results = run_dim_diagnostics(df_zero_var, 'treatment', 'outcome')
+ assert results['status'] == "Success"
+ assert results['details']['variance_homogeneity_status'] == "Could not calculate (zero variance in a group)"
diff --git a/tests/auto_causal/methods/diff_in_means/test_dim_estimator.py b/tests/auto_causal/methods/diff_in_means/test_dim_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e39c754c744d7e37fbef5f2cae44999901db30d8
--- /dev/null
+++ b/tests/auto_causal/methods/diff_in_means/test_dim_estimator.py
@@ -0,0 +1,100 @@
+import pytest
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from unittest.mock import patch, MagicMock
+
+from auto_causal.methods.diff_in_means.estimator import estimate_effect
+
+# --- Fixtures ---
+
+@pytest.fixture
+def sample_rct_data():
+ """Generates simple synthetic RCT data."""
+ np.random.seed(456)
+ n_samples = 150
+ treatment_effect = 5.0
+ treatment = np.random.binomial(1, 0.5, n_samples)
+ error = np.random.normal(0, 2, n_samples)
+ # Simple outcome model: baseline + treatment effect + noise
+ outcome = 20.0 + treatment_effect * treatment + error
+
+ df = pd.DataFrame({
+ 'outcome': outcome,
+ 'treatment': treatment,
+ 'ignored_covariate': np.random.rand(n_samples)
+ })
+ return df
+
+# --- Test Cases ---
+
+@patch('auto_causal.methods.diff_in_means.estimator.run_dim_diagnostics')
+@patch('auto_causal.methods.diff_in_means.estimator.interpret_dim_results')
+def test_estimate_effect_basic(mock_interpret, mock_diagnostics, sample_rct_data):
+ """Test basic execution and output structure."""
+ mock_diagnostics.return_value = {"status": "Success", "details": {'control_group_stats': {}, 'treated_group_stats': {}}}
+ mock_interpret.return_value = "LLM Interpretation"
+
+ results = estimate_effect(sample_rct_data, 'treatment', 'outcome')
+
+ assert 'effect_estimate' in results
+ assert 'p_value' in results
+ assert 'confidence_interval' in results
+ assert 'standard_error' in results
+ assert 'formula' in results
+ assert 'model_summary' in results
+ assert 'diagnostics' in results
+ assert 'interpretation' in results
+ assert 'method_used' in results
+
+ # Check if effect estimate is reasonably close to the true effect (5.0)
+ assert abs(results['effect_estimate'] - 5.0) < 1.0 # Allow some margin
+ assert results['formula'] == "outcome ~ treatment + const"
+ assert results['method_used'] == 'Difference in Means (OLS)'
+ assert isinstance(results['model_summary'], sm.iolib.summary.Summary)
+
+ mock_diagnostics.assert_called_once()
+ mock_interpret.assert_called_once()
+
+def test_estimate_effect_ignores_kwargs(sample_rct_data):
+ """Test that extra kwargs (like covariates) are ignored."""
+ # Should run without error and produce same results as basic test
+ with patch('auto_causal.methods.diff_in_means.estimator.run_dim_diagnostics') as mock_diag, \
+ patch('auto_causal.methods.diff_in_means.estimator.interpret_dim_results') as mock_interp:
+ results = estimate_effect(sample_rct_data, 'treatment', 'outcome', covariates=['ignored_covariate'])
+
+ assert results['formula'] == "outcome ~ treatment + const"
+ assert abs(results['effect_estimate'] - 5.0) < 1.0
+ mock_diag.assert_called_once()
+ mock_interp.assert_called_once()
+
+def test_estimate_effect_missing_treatment(sample_rct_data):
+ """Test error handling for missing treatment column."""
+ with pytest.raises(ValueError, match="Missing required columns:.*missing_treat.*"):
+ estimate_effect(sample_rct_data, 'missing_treat', 'outcome')
+
+def test_estimate_effect_missing_outcome(sample_rct_data):
+ """Test error handling for missing outcome column."""
+ with pytest.raises(ValueError, match="Missing required columns:.*missing_outcome.*"):
+ estimate_effect(sample_rct_data, 'treatment', 'missing_outcome')
+
+def test_estimate_effect_non_binary_treatment(sample_rct_data):
+ """Test warning for non-binary treatment column."""
+ df_non_binary = sample_rct_data.copy()
+ df_non_binary['treatment'] = np.random.randint(0, 3, size=len(df_non_binary)) # 0, 1, 2
+
+ with pytest.warns(UserWarning, match="Treatment column 'treatment' contains values other than 0 and 1"):
+ # We still expect it to run the OLS under the hood
+ with patch('auto_causal.methods.diff_in_means.estimator.run_dim_diagnostics'), \
+ patch('auto_causal.methods.diff_in_means.estimator.interpret_dim_results'):
+ results = estimate_effect(df_non_binary, 'treatment', 'outcome')
+ assert 'effect_estimate' in results # Check it still produced output
+
+def test_estimate_effect_nan_data():
+ """Test handling of data with NaNs resulting in empty analysis set."""
+ df_nan = pd.DataFrame({
+ 'outcome': [np.nan, 2, np.nan], # Ensure all rows have NaN in required cols
+ 'treatment': [0, np.nan, 1],
+ })
+ with pytest.raises(ValueError, match="No data remaining after dropping NaNs"):
+ estimate_effect(df_nan, 'treatment', 'outcome')
diff --git a/tests/auto_causal/methods/diff_in_means/test_dim_llm_assist.py b/tests/auto_causal/methods/diff_in_means/test_dim_llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..e24f13cbfa66738a20f512600474595f17db4dab
--- /dev/null
+++ b/tests/auto_causal/methods/diff_in_means/test_dim_llm_assist.py
@@ -0,0 +1,77 @@
+import pytest
+from unittest.mock import MagicMock, patch
+import pandas as pd
+from auto_causal.methods.diff_in_means.llm_assist import interpret_dim_results
+
+# Patch target for the helper function where it's used
+LLM_ASSIST_MODULE = "auto_causal.methods.diff_in_means.llm_assist"
+
+@pytest.fixture
+def mock_llm():
+ """Fixture for a basic mock LLM object."""
+ return MagicMock()
+
+@pytest.fixture
+def mock_dim_ols_results():
+ """Creates a mock statsmodels OLS results object for DiM."""
+ results = MagicMock()
+ treatment_var = 'treatment'
+ results.params = pd.Series({'const': 20.0, treatment_var: 5.1})
+ results.pvalues = pd.Series({'const': 0.001, treatment_var: 0.03})
+ # Mock conf_int() to return a DataFrame-like object accessible by .loc
+ conf_int_df = pd.DataFrame([[0.5, 9.7]], index=[treatment_var], columns=[0, 1])
+ results.conf_int.return_value = conf_int_df
+ return results
+
+@pytest.fixture
+def mock_dim_diagnostics():
+ """Creates a mock diagnostics dictionary for DiM."""
+ return {
+ "status": "Success",
+ "details": {
+ 'control_group_stats': {'mean': 20.1, 'std': 2.0, 'count': 75},
+ 'treated_group_stats': {'mean': 25.2, 'std': 2.2, 'count': 75},
+ 'variance_homogeneity_status': "Likely Similar"
+ }
+ }
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_interpret_dim_results_implementation(mock_call_json, mock_llm, mock_dim_ols_results, mock_dim_diagnostics):
+ """Test the implemented DiM results interpretation function."""
+ treatment_var = 'treatment'
+ mock_interpretation_text = "The treatment group had a significantly higher average outcome."
+ mock_call_json.return_value = {"interpretation": mock_interpretation_text}
+
+ # --- Test with LLM ---
+ interp_with_llm = interpret_dim_results(
+ mock_dim_ols_results,
+ mock_dim_diagnostics,
+ treatment_var,
+ llm=mock_llm
+ )
+
+ assert interp_with_llm == mock_interpretation_text
+ mock_call_json.assert_called_once()
+ # Basic check on the prompt structure passed to the helper
+ call_args, call_kwargs = mock_call_json.call_args
+ prompt = call_args[1] # Second argument is the prompt string
+ assert "Difference in Means results" in prompt
+ assert "Results Summary:" in prompt
+ assert "Effect Estimate (Difference in Means)': '5.100" in prompt # Check formatting
+ assert "Control Group Mean Outcome': '20.100" in prompt
+ assert "Treated Group Mean Outcome': '25.200" in prompt
+ assert "Return ONLY a valid JSON" in prompt
+
+ # --- Test LLM Call Failure ---
+ mock_call_json.reset_mock()
+ mock_call_json.return_value = None # Simulate LLM helper failure
+ interp_fail = interpret_dim_results(mock_dim_ols_results, mock_dim_diagnostics, treatment_var, llm=mock_llm)
+ assert "LLM interpretation not available for Difference in Means" in interp_fail
+ mock_call_json.assert_called_once() # Ensure it was still called
+
+ # --- Test without LLM ---
+ mock_call_json.reset_mock()
+ interp_no_llm = interpret_dim_results(mock_dim_ols_results, mock_dim_diagnostics, treatment_var, llm=None)
+ assert isinstance(interp_no_llm, str)
+ assert "LLM interpretation not available for Difference in Means" in interp_no_llm
+ mock_call_json.assert_not_called() # Ensure helper wasn't called
diff --git a/tests/auto_causal/methods/difference_in_differences/test_did_diagnostics.py b/tests/auto_causal/methods/difference_in_differences/test_did_diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..60a1bb05ca05196eec1edbd9f9d72b9bca2ede95
--- /dev/null
+++ b/tests/auto_causal/methods/difference_in_differences/test_did_diagnostics.py
@@ -0,0 +1,37 @@
+import pytest
+import pandas as pd
+import numpy as np
+from auto_causal.methods.difference_in_differences.diagnostics import validate_parallel_trends
+
+# Fixture (can reuse from estimator tests if needed, or define simpler one)
+@pytest.fixture
+def sample_did_data_diag():
+ df = pd.DataFrame({
+ 'time': [1, 2, 3, 1, 2, 3],
+ 'unit': ['A', 'A', 'A', 'B', 'B', 'B'],
+ 'outcome': [10, 11, 12, 15, 17, 19],
+ 'group': [0, 0, 0, 1, 1, 1] # A is control, B is treated
+ })
+ return df
+
+# Test Cases
+def test_validate_parallel_trends_placeholder(sample_did_data_diag):
+ """Tests the placeholder parallel trends validation function."""
+ # For placeholder, specific args don't matter much
+ results = validate_parallel_trends(
+ sample_did_data_diag,
+ time_var='time',
+ outcome='outcome',
+ group_indicator_col='group',
+ treatment_period_start=3, # Example treatment start
+ dataset_description=None # Add arg
+ )
+
+ assert isinstance(results, dict)
+ # Check for specific placeholder values if they are defined
+ assert results.get('valid') is True # Function defaults to True when test cannot be run
+ # Check the actual detail message returned when test fails on this data
+ assert "Insufficient pre-treatment data or variation" in results.get('details', "")
+
+# Add tests here if/when parallel trends validation is implemented
+# e.g., test_parallel_trends_pass, test_parallel_trends_fail
\ No newline at end of file
diff --git a/tests/auto_causal/methods/difference_in_differences/test_did_estimator.py b/tests/auto_causal/methods/difference_in_differences/test_did_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe7fa419891eceddcf2d726b69550e65219a5b0d
--- /dev/null
+++ b/tests/auto_causal/methods/difference_in_differences/test_did_estimator.py
@@ -0,0 +1,198 @@
+import pytest
+import pandas as pd
+import numpy as np
+import statsmodels.formula.api as smf
+from unittest.mock import patch, MagicMock
+
+# Module containing the function to test
+ESTIMATOR_MODULE = "auto_causal.methods.difference_in_differences.estimator"
+
+# Import the function to test AFTER defining the module path
+from auto_causal.methods.difference_in_differences.estimator import estimate_effect
+
+# --- Fixtures ---
+
+@pytest.fixture
+def sample_did_data():
+ """Generates synthetic panel data suitable for DiD testing."""
+ np.random.seed(2024)
+ n_units = 50
+ n_periods = 10
+ treatment_start_time = 5 # Treatment starts in period 5
+ true_effect = 7.0
+
+ units = np.arange(n_units)
+ periods = np.arange(n_periods)
+
+ # Create panel structure
+ panel_index = pd.MultiIndex.from_product([units, periods], names=['unit_id', 'time_period'])
+ df = pd.DataFrame(index=panel_index).reset_index()
+
+ # Assign treatment group (first half of units)
+ df['group'] = (df['unit_id'] < n_units // 2).astype(int)
+
+ # Create post-treatment indicator - REMOVE this, let estimator call the helper
+ # df['post'] = (df['time_period'] >= treatment_start_time).astype(int)
+
+ # Create interaction term (true treatment effect applied here)
+ # Need 'post' for this, so create it temporarily or adjust outcome formula
+ # Let's adjust outcome formula to not rely on pre-calculated interaction
+ # df['did_interaction'] = df['group'] * df['post']
+ df['is_post_treatment'] = (df['time_period'] >= treatment_start_time).astype(int)
+
+ # Create covariates
+ df['covariate1'] = np.random.normal(5, 1, size=len(df))
+ # Time-varying covariate
+ df['covariate2'] = df['time_period'] * 0.2 + np.random.normal(0, 0.5, size=len(df))
+
+ # Unit and time fixed effects
+ unit_fe = np.random.normal(0, 3, n_units)
+ time_fe = np.random.normal(0, 2, n_periods)
+ df['unit_fe_val'] = df['unit_id'].map(dict(enumerate(unit_fe)))
+ df['time_fe_val'] = df['time_period'].map(dict(enumerate(time_fe)))
+
+ # Generate outcome
+ error = np.random.normal(0, 1, len(df))
+ df['outcome'] = (10 +
+ true_effect * df['group'] * df['is_post_treatment'] + # Use group * post directly
+ df['unit_fe_val'] +
+ df['time_fe_val'] +
+ 0.5 * df['covariate1'] +
+ -0.3 * df['covariate2'] +
+ error)
+
+ # Use 'group' as the main treatment indicator for some tests
+ df['treatment'] = df['group']
+
+ return df
+
+# --- Test Cases ---
+
+# Mock all imported helper functions from estimator module
+@patch(f'{ESTIMATOR_MODULE}.identify_time_variable')
+@patch(f'{ESTIMATOR_MODULE}.determine_treatment_period')
+@patch(f'{ESTIMATOR_MODULE}.identify_treatment_group')
+@patch(f'{ESTIMATOR_MODULE}.create_post_indicator')
+@patch(f'{ESTIMATOR_MODULE}.validate_parallel_trends')
+@patch(f'{ESTIMATOR_MODULE}.format_did_results') # Also mock the formatter
+@patch(f'{ESTIMATOR_MODULE}.smf.ols') # Mock statsmodels itself
+def test_estimate_effect_twfe_no_covariates(
+ mock_ols, mock_formatter, mock_validate, mock_create_post,
+ mock_id_group, mock_det_period, mock_id_time,
+ sample_did_data
+):
+ """Test basic TWFE DiD estimation without covariates."""
+ # Setup mocks for helpers
+ mock_id_time.return_value = 'time_period'
+ mock_det_period.return_value = 5 # Treatment start time
+ mock_id_group.return_value = 'unit_id' # The ID variable for FE/clustering
+ mock_create_post.return_value = (sample_did_data['time_period'] >= 5).astype(int)
+ mock_validate.return_value = {'valid': True, 'details': 'Mocked validation'}
+
+ # Setup mock for statsmodels results
+ mock_fit = MagicMock()
+ interaction_term_formula = "Q('did_interaction')" # Key based on created col name
+ mock_fit.params = pd.Series({interaction_term_formula: 7.1, 'other_coef': 1.0})
+ mock_fit.bse = pd.Series({interaction_term_formula: 0.5, 'other_coef': 0.1})
+ mock_fit.pvalues = pd.Series({interaction_term_formula: 0.001, 'other_coef': 0.1})
+ conf_int_df = pd.DataFrame([[6.1, 8.1]], index=[interaction_term_formula], columns=[0, 1])
+ mock_fit.conf_int.return_value = conf_int_df
+ mock_fit.summary.return_value = "Mock Summary"
+ mock_ols.return_value.fit.return_value = mock_fit
+
+ # Setup mock for formatter to check its input
+ mock_formatter.return_value = {"effect_estimate": 7.1, "method_used": "DiD.TWFE"} # Dummy return
+
+ # Call the function (treatment='group' which is the binary indicator)
+ results = estimate_effect(
+ sample_did_data,
+ treatment='group',
+ outcome='outcome',
+ covariates=[],
+ dataset_description={}
+ )
+
+ # Assertions
+ mock_id_time.assert_called_once()
+ mock_det_period.assert_called_once()
+ mock_id_group.assert_called_once()
+ mock_create_post.assert_called_once()
+ mock_validate.assert_called_once()
+ mock_ols.assert_called_once()
+ mock_formatter.assert_called_once()
+
+ # Check formula passed to OLS
+ call_args, call_kwargs = mock_ols.call_args
+ formula_used = call_kwargs['formula']
+ assert "Q('outcome') ~ Q('did_interaction') + C(unit_id) + C(time_period)" == formula_used
+
+ # Check clustering variable
+ fit_call_args, fit_call_kwargs = mock_ols.return_value.fit.call_args
+ assert fit_call_kwargs['cov_type'] == 'cluster'
+ assert 'groups' in fit_call_kwargs['cov_kwds']
+ # Check if the correct grouping column (unit_id) was used for clustering
+ assert fit_call_kwargs['cov_kwds']['groups'].name == 'unit_id'
+
+ # Check arguments passed to formatter
+ format_call_args, format_call_kwargs = mock_formatter.call_args
+ assert format_call_args[0] == mock_fit # Check results object
+ assert format_call_args[1] == interaction_term_formula # Check interaction term key
+ assert format_call_args[2]["parallel_trends"]['valid'] is True # Check diagnostics
+ assert format_call_kwargs['parameters']['time_var'] == 'time_period'
+ assert format_call_kwargs['parameters']['group_var'] == 'unit_id'
+ assert format_call_kwargs['parameters']['treatment_indicator'] == 'group' # Identified correctly
+ assert format_call_kwargs['parameters']['covariates'] == []
+
+ # Check final output (dummy from formatter mock)
+ assert results['effect_estimate'] == 7.1
+ assert results['method_used'] == 'DiD.TWFE'
+
+# Add more tests: with covariates, missing columns, variable identification scenarios etc.
+# Example for missing column:
+def test_estimate_effect_missing_outcome(sample_did_data):
+ with pytest.raises(ValueError, match="Outcome variable 'missing_outcome' not found"):
+ # Need to mock helpers even for early exit tests if they are called before check
+ with patch(f'{ESTIMATOR_MODULE}.identify_time_variable', return_value='time_period'), \
+ patch(f'{ESTIMATOR_MODULE}.identify_treatment_group', return_value='unit_id'), \
+ patch(f'{ESTIMATOR_MODULE}.determine_treatment_period', return_value=5):
+ estimate_effect(sample_did_data, treatment='group', outcome='missing_outcome', covariates=[], dataset_description={})
+
+# Example for variable identification test
+@patch(f'{ESTIMATOR_MODULE}.identify_time_variable')
+@patch(f'{ESTIMATOR_MODULE}.determine_treatment_period')
+@patch(f'{ESTIMATOR_MODULE}.identify_treatment_group')
+@patch(f'{ESTIMATOR_MODULE}.create_post_indicator')
+@patch(f'{ESTIMATOR_MODULE}.validate_parallel_trends')
+@patch(f'{ESTIMATOR_MODULE}.format_did_results')
+@patch(f'{ESTIMATOR_MODULE}.smf.ols')
+def test_treatment_col_identification(
+ mock_ols, mock_formatter, mock_validate, mock_create_post,
+ mock_id_group, mock_det_period, mock_id_time,
+ sample_did_data
+):
+ """Test identification of the binary treatment indicator."""
+ mock_id_time.return_value = 'time_period'
+ mock_det_period.return_value = 5
+ mock_id_group.return_value = 'unit_id'
+ mock_create_post.return_value = (sample_did_data['time_period'] >= 5).astype(int)
+ mock_validate.return_value = {'valid': True}
+ mock_ols.return_value.fit.return_value = MagicMock() # Don't need detailed results mock
+ mock_formatter.return_value = {}
+
+ # Scenario 1: 'treatment' arg IS the binary indicator ('group' column)
+ estimate_effect(sample_did_data, treatment='group', outcome='outcome', covariates=[], dataset_description={})
+ format_call_args, format_call_kwargs = mock_formatter.call_args
+ assert format_call_kwargs['parameters']['treatment_indicator'] == 'group'
+ mock_formatter.reset_mock()
+
+ # Scenario 2: 'treatment' arg is NOT binary, but 'group' col exists and IS binary
+ # Keep the original 'group' column, just add the non-binary unit_id_str
+ df_modified = sample_did_data.copy()
+ # df_modified = sample_did_data.rename(columns={'group': 'treatment'}) # DON'T rename
+ df_modified['unit_id_str'] = df_modified['unit_id'].astype(str) # Make unit_id non-binary
+ mock_id_group.return_value = 'unit_id_str' # LLM identifies non-binary unit_id
+ # Call estimate_effect with the non-binary unit_id_str as the treatment argument
+ estimate_effect(df_modified, treatment='unit_id_str', outcome='outcome', covariates=[], dataset_description={})
+ format_call_args, format_call_kwargs = mock_formatter.call_args
+ # Should correctly identify the *original* 'group' column via Priority 2 logic
+ assert format_call_kwargs['parameters']['treatment_indicator'] == 'group'
\ No newline at end of file
diff --git a/tests/auto_causal/methods/difference_in_differences/test_did_llm_assist.py b/tests/auto_causal/methods/difference_in_differences/test_did_llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..a10ed66d05e19ad89be8323984dbbe8678e60067
--- /dev/null
+++ b/tests/auto_causal/methods/difference_in_differences/test_did_llm_assist.py
@@ -0,0 +1,139 @@
+import pytest
+from unittest.mock import MagicMock, patch
+import pandas as pd
+
+# Functions to test (assuming they exist in llm_assist)
+from auto_causal.methods.difference_in_differences.llm_assist import (
+ identify_time_variable,
+ determine_treatment_period,
+ identify_treatment_group,
+ interpret_did_results
+)
+
+# Patch target for the helper function if LLM calls are made
+LLM_ASSIST_MODULE = "auto_causal.methods.difference_in_differences.llm_assist"
+
+@pytest.fixture
+def mock_llm():
+ """Fixture for a basic mock LLM object."""
+ return MagicMock()
+
+@pytest.fixture
+def mock_did_results():
+ """Creates a mock DiD results dictionary (output of estimate_effect)."""
+ # This should match the structure returned by difference_in_differences/estimator.py
+ return {
+ 'effect_estimate': 7.1,
+ 'p_value': 0.001,
+ 'confidence_interval': [6.1, 8.1],
+ 'effect_se': 0.5, # Added SE based on format_did_results
+ 'method_used': 'Statsmodels TWFE DiD (Fallback)', # Example method used
+ 'method_details': 'DiD via Statsmodels TWFE (C() Notation)',
+ 'parameters': {
+ 'time_var': 'time_period',
+ 'group_var': 'unit_id',
+ 'treatment_indicator': 'group',
+ 'treatment_period_start': 5,
+ 'covariates': ['cov1']
+ },
+ 'details': "Mock statsmodels summary..." # Placeholder for summary string
+ }
+
+@pytest.fixture
+def mock_did_diagnostics():
+ """Creates a mock DiD diagnostics dictionary."""
+ # This should match the structure returned by difference_in_differences/diagnostics.py
+ return {
+ "status": "Success (Partial Implementation)", # Example status
+ "details": {
+ 'parallel_trends': {'valid': True, 'details': 'Mocked validation', 'p_value': 0.6}
+ }
+ }
+
+@pytest.fixture
+def sample_dataframe():
+ """Simple DataFrame for testing identify functions."""
+ return pd.DataFrame({
+ 'year': [2010, 2011, 2012, 2010, 2011, 2012],
+ 'state_id': [1, 1, 1, 2, 2, 2],
+ 'value': [10, 11, 12, 15, 16, 17],
+ 'treated_state': [0, 0, 0, 1, 1, 1]
+ })
+
+# Test Cases
+
+def test_identify_time_variable_heuristic(sample_dataframe):
+ """Test heuristic identification of time variable."""
+ # Should identify 'year' based on name
+ time_var = identify_time_variable(sample_dataframe, dataset_description={})
+ assert time_var == 'year'
+
+ # Test with no obvious time column
+ df_no_time = sample_dataframe.rename(columns={'year': 'col_a'})
+ time_var_none = identify_time_variable(df_no_time, dataset_description={})
+ assert time_var_none is None
+
+# TODO: Add tests for LLM fallback in identify_time_variable if implemented
+
+def test_determine_treatment_period_heuristic(sample_dataframe):
+ """Test heuristic determination of treatment period."""
+ # Heuristic based on median time (2011), expects first period after median (2012)
+ period = determine_treatment_period(sample_dataframe, time_var='year', treatment='treated_state', dataset_description={})
+ assert period == 2012
+
+ # Test with odd number of periods
+ df_odd = pd.DataFrame({'year': [1, 2, 3, 4, 5], 'treatment': [0,0,1,1,1]})
+ period_odd = determine_treatment_period(df_odd, 'year', 'treatment', dataset_description={})
+ assert period_odd == 4
+ # Let's re-check the heuristic: median index for [1,2,3,4,5] is 2 (value 3). Correct.
+ # The placeholder assumes treatment starts *at* the median period for non-numeric.
+
+# TODO: Add tests for LLM fallback in determine_treatment_period if implemented
+
+def test_identify_treatment_group_placeholder(sample_dataframe):
+ """Test the placeholder function for treatment group identification."""
+ # Placeholder assumes treatment_var is the group_var
+ group_var = identify_treatment_group(sample_dataframe, treatment_var='treated_state', dataset_description={})
+ assert group_var == 'treated_state'
+
+ group_var_id = identify_treatment_group(sample_dataframe, treatment_var='state_id', dataset_description={})
+ # Heuristic finds 'treated_state' as potential ID when 'state_id' is non-binary
+ assert group_var_id == 'treated_state'
+
+# TODO: Add tests for interpret_did_results if implemented
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_interpret_did_results_implementation(mock_call_json, mock_llm, mock_did_results, mock_did_diagnostics):
+ """Test the implemented DiD interpretation function."""
+ # Mock necessary inputs are now provided by fixtures
+ mock_dataset_desc_str = "This is a mock dataset about smoking."
+ mock_interpretation_text = "DiD shows a significant positive effect..."
+ mock_call_json.return_value = {"interpretation": mock_interpretation_text}
+
+ # Pass the correct fixtures
+ interp = interpret_did_results(mock_did_results, mock_did_diagnostics, mock_dataset_desc_str, llm=mock_llm)
+
+ assert interp == mock_interpretation_text
+ mock_call_json.assert_called_once()
+ call_args, _ = mock_call_json.call_args
+ prompt = call_args[1]
+ assert "DiD results" in prompt
+ assert "Estimation Results Summary:" in prompt
+ assert "Effect Estimate': '7.100" in prompt
+ # Update assertion based on new mock diagnostics structure
+ assert "Parallel Trends Assumption Status': 'Passed (Placeholder)" in prompt
+ assert "Dataset Context Provided:\nThis is a mock dataset about smoking." in prompt
+
+ # --- Test LLM Call Failure ---
+ mock_call_json.reset_mock()
+ mock_call_json.return_value = None # Simulate LLM helper failure
+ interp_fail = interpret_did_results(mock_did_results, mock_did_diagnostics, mock_dataset_desc_str, llm=mock_llm)
+ assert "LLM interpretation not available for DiD" in interp_fail
+ mock_call_json.assert_called_once() # Ensure it was still called
+
+ # --- Test without LLM ---
+ mock_call_json.reset_mock()
+ interp_no_llm = interpret_did_results(mock_did_results, mock_did_diagnostics, mock_dataset_desc_str, llm=None)
+ assert isinstance(interp_no_llm, str)
+ assert "LLM interpretation not available for DiD" in interp_no_llm
+ mock_call_json.assert_not_called() # Ensure helper wasn't called
\ No newline at end of file
diff --git a/tests/auto_causal/methods/difference_in_differences/test_estimator.py b/tests/auto_causal/methods/difference_in_differences/test_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..165d00a1511727befcb2fe1818ccf0cabc937df2
--- /dev/null
+++ b/tests/auto_causal/methods/difference_in_differences/test_estimator.py
@@ -0,0 +1,137 @@
+import pytest
+import pandas as pd
+import numpy as np
+
+# Import the function to test
+from auto_causal.methods.difference_in_differences import estimator as did_estimator
+# Import placeholder diagnostics to check if they are called
+from auto_causal.methods.difference_in_differences import diagnostics as did_diagnostics
+# Import placeholder llm assists to check if they are called
+from auto_causal.methods.difference_in_differences import llm_assist as did_llm_assist
+
+from unittest.mock import patch # For testing placeholder calls
+
+# --- Synthetic Data Generation ---
+def generate_synthetic_did_data(n_units=50, n_periods=10, treatment_start_period=6,
+ treatment_effect=5.0, seed=42):
+ """Generates synthetic panel data suitable for DiD.
+
+ Features:
+ - Units with fixed effects.
+ - Time periods with fixed effects.
+ - A subset of units treated after a specific period.
+ - Observed time-invariant covariate (optional).
+ - Known true treatment effect.
+ """
+ np.random.seed(seed)
+
+ units = range(n_units)
+ periods = range(n_periods)
+
+ df = pd.DataFrame([(u, p) for u in units for p in periods], columns=['unit', 'time'])
+
+ # Unit fixed effects
+ unit_effects = pd.DataFrame({'unit': units, 'unit_fe': np.random.normal(0, 2, n_units)})
+ df = pd.merge(df, unit_effects, on='unit')
+
+ # Time fixed effects
+ time_effects = pd.DataFrame({'time': periods, 'time_fe': np.random.normal(0, 1.5, n_periods)})
+ df = pd.merge(df, time_effects, on='time')
+
+ # Treatment group (e.g., half the units)
+ treated_units = range(n_units // 2)
+ df['group'] = df['unit'].apply(lambda u: 1 if u in treated_units else 0)
+
+ # Treatment indicator (post * treated_group)
+ df['post'] = (df['time'] >= treatment_start_period).astype(int)
+ df['treatment'] = df['group'] * df['post']
+
+ # Outcome model
+ df['outcome'] = (10 +
+ df['unit_fe'] +
+ df['time_fe'] +
+ df['treatment'] * treatment_effect +
+ np.random.normal(0, 1, len(df))) # Noise
+
+ # Select relevant columns for clarity - remove X1
+ df = df[['unit', 'time', 'group', 'treatment', 'outcome']]
+ # Log info about generated data
+ print("\n--- Generated Synthetic Data Info ---")
+ print("Head:\n", df.head())
+ print("\nInfo:")
+ df.info()
+ print("\nDescribe:\n", df.describe())
+ print("-------------------------------------")
+ return df
+
+# --- Test Class ---
+class TestDifferenceInDifferences:
+
+ def test_did_estimate_effect_synthetic(self):
+ """Test the end-to-end estimate_effect with synthetic DiD data."""
+ # Arrange
+ true_effect = 7.0
+ df = generate_synthetic_did_data(n_units=100, n_periods=12, treatment_start_period=8,
+ treatment_effect=true_effect, seed=123)
+
+ # treatment_var should be the ACTUAL 0/1 treatment status indicator
+ treatment_var = 'treatment'
+ outcome_var = 'outcome'
+ time_var = 'time'
+ group_var = 'unit' # The identifier for the panel unit
+ covariates = [] # Was ['X1']
+
+ # Act
+ results = did_estimator.estimate_effect(
+ df=df,
+ treatment=treatment_var,
+ outcome=outcome_var,
+ covariates=covariates,
+ # Pass specific required args for DiD via kwargs
+ time_var=time_var,
+ group_var=group_var,
+ # Explicitly pass the correct treatment start period
+ treatment_period_start=8
+ )
+
+ # Assert
+ assert results is not None
+ assert "error" not in results # Check for errors
+
+ assert 'effect_estimate' in results
+ assert 'effect_se' in results
+ assert 'confidence_interval' in results
+ assert 'diagnostics' in results
+ assert 'parameters' in results
+ assert 'details' in results # Should contain statsmodels summary
+
+ # Check estimate value
+ estimated_effect = results['effect_estimate']
+ assert estimated_effect is not None
+ assert abs(estimated_effect - true_effect) < 1.0 # Tolerance for estimation noise
+
+ # Check SE and CI
+ assert results['effect_se'] is not None and results['effect_se'] > 0
+ assert results['confidence_interval'] is not None
+ assert results['confidence_interval'][0] < estimated_effect < results['confidence_interval'][1]
+
+ # Check diagnostics were actually run and included
+ assert 'parallel_trends' in results['diagnostics']
+ # Assert that the actual validation passed (should for this synthetic data)
+ assert results['diagnostics']['parallel_trends'].get('error') is None
+ assert results['diagnostics']['parallel_trends']['valid'] == True
+ assert results['diagnostics']['parallel_trends']['p_value'] > 0.05 # Check p-value indication
+
+ # Check parameters reflect inputs/defaults
+ assert results['parameters']['time_var'] == time_var
+ assert results['parameters']['group_var'] == group_var
+ assert results['parameters']['treatment_indicator'] == treatment_var
+ assert results['parameters']['covariates'] == covariates
+ assert results['parameters']['estimation_method'] == "Statsmodels OLS with TWFE (C() Notation) and Clustered SE"
+ assert 'treatment_period_start' in results['parameters'] # Check it was determined
+ assert 'interaction_term' in results['parameters']
+
+ # TODO: Add test case where time_var/group_var are not passed,
+ # mock llm_assist functions, and assert they are called.
+ # TODO: Add test case for data where parallel trends validation *should* fail (if implemented).
+ # TODO: Add test case for placebo test validation (if implemented).
\ No newline at end of file
diff --git a/tests/auto_causal/methods/instrumental_variable/test_iv_diagnostics.py b/tests/auto_causal/methods/instrumental_variable/test_iv_diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..94015185569162002d2c65dba40d21aab7a8aaf8
--- /dev/null
+++ b/tests/auto_causal/methods/instrumental_variable/test_iv_diagnostics.py
@@ -0,0 +1,127 @@
+import pytest
+import pandas as pd
+import numpy as np
+from auto_causal.methods.instrumental_variable.diagnostics import (
+ calculate_first_stage_f_statistic,
+ run_overidentification_test
+)
+from statsmodels.sandbox.regression.gmm import IV2SLS
+
+# Fixture for basic IV data
+@pytest.fixture
+def iv_data():
+ np.random.seed(42)
+ n = 1000
+ Z1 = np.random.normal(0, 1, n) # Instrument 1
+ Z2 = np.random.normal(0, 1, n) # Instrument 2 (for over-ID test)
+ W = np.random.normal(0, 1, n) # Exogenous Covariate
+ U = np.random.normal(0, 1, n) # Unobserved Confounder
+
+ # Strong Instrument Case
+ T_strong = 0.5 * Z1 + 0.5 * W + 0.5 * U + np.random.normal(0, 1, n)
+ Y_strong = 2.0 * T_strong + 1.0 * W + 1.0 * U + np.random.normal(0, 1, n)
+
+ # Weak Instrument Case
+ T_weak = 0.05 * Z1 + 0.5 * W + 0.5 * U + np.random.normal(0, 1, n)
+ Y_weak = 2.0 * T_weak + 1.0 * W + 1.0 * U + np.random.normal(0, 1, n)
+
+ df_strong = pd.DataFrame({'Y': Y_strong, 'T': T_strong, 'Z1': Z1, 'Z2': Z2, 'W': W, 'U': U})
+ df_weak = pd.DataFrame({'Y': Y_weak, 'T': T_weak, 'Z1': Z1, 'Z2': Z2, 'W': W, 'U': U})
+
+ return df_strong, df_weak
+
+
+def test_calculate_first_stage_f_statistic_strong(iv_data):
+ df_strong, _ = iv_data
+ f_stat, p_val = calculate_first_stage_f_statistic(
+ df=df_strong, treatment='T', instruments=['Z1'], covariates=['W']
+ )
+ assert f_stat is not None
+ assert p_val is not None
+ assert f_stat > 10 # Expect strong instrument
+ assert p_val < 0.01 # Expect significance
+
+def test_calculate_first_stage_f_statistic_weak(iv_data):
+ _, df_weak = iv_data
+ f_stat, p_val = calculate_first_stage_f_statistic(
+ df=df_weak, treatment='T', instruments=['Z1'], covariates=['W']
+ )
+ assert f_stat is not None
+ assert p_val is not None
+ # Note: With random noise, weak instrument test might occasionally pass 10, but should be low
+ assert f_stat < 15 # Check it's not extremely high
+ # P-value might still be significant if sample size is large
+
+def test_calculate_first_stage_f_statistic_no_instruments(caplog):
+ """Test graceful handling when no instruments are provided."""
+ df = pd.DataFrame({'T': [1, 2], 'W': [3, 4]})
+ # Should now return (None, None) and log a warning, not raise Exception
+ # with pytest.raises(Exception): # OLD assertion
+ # calculate_first_stage_f_statistic(
+ # df=df, treatment='T', instruments=[], covariates=['W']
+ # )
+ f_stat, p_val = calculate_first_stage_f_statistic(
+ df=df, treatment='T', instruments=[], covariates=['W']
+ )
+ assert f_stat is None
+ assert p_val is None
+ assert "No instruments provided" in caplog.text # Check log message
+
+
+def test_run_overidentification_test_applicable(iv_data):
+ df_strong, _ = iv_data
+ # Need to run statsmodels IV first to get results object
+ df_copy = df_strong.copy()
+ df_copy['intercept'] = 1
+ endog = df_copy['Y']
+ exog_vars = ['intercept', 'W', 'T']
+ instrument_vars = ['intercept', 'W', 'Z1', 'Z2'] # Z1, Z2 are instruments
+
+ iv_model = IV2SLS(endog=endog, exog=df_copy[exog_vars], instrument=df_copy[instrument_vars])
+ sm_results = iv_model.fit()
+
+ stat, p_val, status = run_overidentification_test(
+ sm_results=sm_results,
+ df=df_strong,
+ treatment='T',
+ outcome='Y',
+ instruments=['Z1', 'Z2'],
+ covariates=['W']
+ )
+
+ assert "Test successful" in status
+ assert stat is not None
+ assert p_val is not None
+ assert stat >= 0
+ # In this correctly specified model, we expect the test to NOT reject H0 (p > 0.05)
+ assert p_val > 0.05
+
+def test_run_overidentification_test_not_applicable(iv_data):
+ df_strong, _ = iv_data
+ # Only one instrument
+ stat, p_val, status = run_overidentification_test(
+ sm_results=None, # Not needed if not applicable
+ df=df_strong,
+ treatment='T',
+ outcome='Y',
+ instruments=['Z1'],
+ covariates=['W']
+ )
+ assert stat is None
+ assert p_val is None
+ assert "not applicable" in status.lower()
+
+def test_run_overidentification_test_no_sm_results(iv_data):
+ df_strong, _ = iv_data
+ # More than one instrument, but no sm_results provided
+ stat, p_val, status = run_overidentification_test(
+ sm_results=None,
+ df=df_strong,
+ treatment='T',
+ outcome='Y',
+ instruments=['Z1', 'Z2'],
+ covariates=['W']
+ )
+ assert stat is None
+ assert p_val is None
+ assert "object not available" in status.lower()
\ No newline at end of file
diff --git a/tests/auto_causal/methods/instrumental_variable/test_iv_estimator.py b/tests/auto_causal/methods/instrumental_variable/test_iv_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d9425ef2fb5204b653d4f2b8e054eba2a0dcef2
--- /dev/null
+++ b/tests/auto_causal/methods/instrumental_variable/test_iv_estimator.py
@@ -0,0 +1,139 @@
+import pytest
+import pandas as pd
+import numpy as np
+from auto_causal.methods.instrumental_variable.estimator import estimate_effect, build_iv_graph_gml
+
+# Consistent random state for reproducibility
+SEED = 42
+TRUE_EFFECT = 2.0
+
+@pytest.fixture
+def synthetic_iv_data():
+ """Generates synthetic data suitable for IV estimation."""
+ np.random.seed(SEED)
+ n_samples = 2000
+ # Instrument (relevant and exogenous)
+ Z = np.random.normal(loc=5, scale=1, size=n_samples)
+ # Observed Covariate
+ W = np.random.normal(loc=2, scale=1, size=n_samples)
+ # Unobserved Confounder
+ U = np.random.normal(loc=0, scale=1, size=n_samples)
+
+ # Treatment model: T = alpha*Z + beta*W + gamma*U + error
+ # Ensure Z has a reasonably strong effect on T (relevance)
+ T = 0.6 * Z + 0.4 * W + 0.7 * U + np.random.normal(loc=0, scale=1, size=n_samples)
+
+ # Outcome model: Y = TRUE_EFFECT*T + delta*W + eta*U + error
+ # Z should NOT directly affect Y (exclusion restriction)
+ Y = TRUE_EFFECT * T + 0.3 * W + 0.9 * U + np.random.normal(loc=0, scale=1, size=n_samples)
+
+ df = pd.DataFrame({
+ 'outcome': Y,
+ 'treatment': T,
+ 'instrument': Z,
+ 'covariate': W,
+ 'unobserved': U # Keep for reference, but should not be used in estimation
+ })
+ return df
+
+
+def test_build_iv_graph_gml():
+ """Tests the GML graph construction."""
+ gml = build_iv_graph_gml(
+ treatment='T', outcome='Y', instruments=['Z1', 'Z2'], covariates=['W1', 'W2']
+ )
+ assert 'node [ id "T" label "T" ]' in gml
+ assert 'node [ id "Y" label "Y" ]' in gml
+ assert 'node [ id "Z1" label "Z1" ]' in gml
+ assert 'node [ id "Z2" label "Z2" ]' in gml
+ assert 'node [ id "W1" label "W1" ]' in gml
+ assert 'node [ id "W2" label "W2" ]' in gml
+ assert 'node [ id "U" label "U" ]' in gml # Unobserved confounder
+
+ assert 'edge [ source "Z1" target "T" ]' in gml
+ assert 'edge [ source "Z2" target "T" ]' in gml
+ assert 'edge [ source "W1" target "T" ]' in gml
+ assert 'edge [ source "W2" target "T" ]' in gml
+ assert 'edge [ source "W1" target "Y" ]' in gml
+ assert 'edge [ source "W2" target "Y" ]' in gml
+ assert 'edge [ source "T" target "Y" ]' in gml
+ assert 'edge [ source "U" target "T" ]' in gml
+ assert 'edge [ source "U" target "Y" ]' in gml
+
+ assert 'edge [ source "Z1" target "Y" ]' not in gml # Exclusion
+ assert 'edge [ source "Z2" target "Y" ]' not in gml # Exclusion
+
+def test_estimate_effect_dowhy_path(synthetic_iv_data):
+ """Tests the IV estimation using the primary DoWhy path."""
+ df = synthetic_iv_data
+ results = estimate_effect(
+ df=df,
+ treatment='treatment',
+ outcome='outcome',
+ instrument='instrument',
+ covariates=['covariate']
+ )
+
+ print("DoWhy Path Results:", results)
+ assert results is not None
+ assert 'error' not in results
+ assert results['method_used'] == 'dowhy'
+ assert results['effect_estimate'] == pytest.approx(TRUE_EFFECT, abs=0.2) # Allow some tolerance
+ assert 'diagnostics' in results
+ assert results['diagnostics']['first_stage_f_statistic'] > 10
+ assert results['diagnostics']['is_instrument_weak'] is False
+ assert results['diagnostics']['overid_test_applicable'] is False # Only 1 instrument
+
+def test_estimate_effect_statsmodels_fallback(synthetic_iv_data):
+ """Tests the IV estimation using the statsmodels fallback path."""
+ df = synthetic_iv_data
+ results = estimate_effect(
+ df=df,
+ treatment='treatment',
+ outcome='outcome',
+ instrument='instrument',
+ covariates=['covariate'],
+ force_statsmodels=True # Force skipping DoWhy
+ )
+
+ print("Statsmodels Path Results:", results)
+ assert results is not None
+ assert 'error' not in results
+ assert results['method_used'] == 'statsmodels'
+ assert results['effect_estimate'] == pytest.approx(TRUE_EFFECT, abs=0.2)
+ assert 'diagnostics' in results
+ assert results['diagnostics']['first_stage_f_statistic'] > 10
+ assert results['diagnostics']['is_instrument_weak'] is False
+ assert results['diagnostics']['overid_test_applicable'] is False
+
+def test_estimate_effect_missing_column():
+ """Tests error handling for missing columns."""
+ df = pd.DataFrame({'outcome': [1, 2], 'instrument': [3, 4]})
+ results = estimate_effect(
+ df=df,
+ treatment='treatment', # Missing
+ outcome='outcome',
+ instrument='instrument',
+ covariates=[]
+ )
+ assert 'error' in results
+ assert "Missing required columns" in results['error']
+
+def test_estimate_effect_no_instrument():
+ """Tests error handling when no instrument is provided."""
+ df = pd.DataFrame({'outcome': [1, 2], 'treatment': [3, 4]})
+ results = estimate_effect(
+ df=df,
+ treatment='treatment',
+ outcome='outcome',
+ instrument=[], # Empty
+ covariates=[]
+ )
+ assert 'error' in results
+ assert "Instrument variable(s) must be provided" in results['error']
+
+# TODO: Add tests for:
+# - Cases where DoWhy fails and fallback *should* occur
+# - Overidentification test results when applicable (using >1 instrument in synthetic data)
+# - More complex graph structures if needed
+# - Handling of NaNs
\ No newline at end of file
diff --git a/tests/auto_causal/methods/instrumental_variable/test_iv_llm_assist.py b/tests/auto_causal/methods/instrumental_variable/test_iv_llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..96c245fbe69e78d71943dc281aeb0e04a7655350
--- /dev/null
+++ b/tests/auto_causal/methods/instrumental_variable/test_iv_llm_assist.py
@@ -0,0 +1,137 @@
+import pytest
+from unittest.mock import patch, MagicMock
+
+# Import functions to test
+from auto_causal.methods.instrumental_variable.llm_assist import (
+ identify_instrument_variable,
+ validate_instrument_assumptions_qualitative,
+ interpret_iv_results
+)
+# Assume BaseChatModel is importable for type hinting if needed elsewhere,
+# but for mocking we don't strictly need it here.
+# from langchain.chat_models.base import BaseChatModel
+
+# Assume shared helpers are in this location
+# LLM_HELPERS_PATH = "causalscientist.auto_causal.utils.llm_helpers"
+# Correct patch target is where the function is *used*
+LLM_ASSIST_PATH = "auto_causal.methods.instrumental_variable.llm_assist"
+
+@pytest.fixture
+def mock_llm():
+ """Fixture to create a mock LLM object."""
+ return MagicMock() # Basic mock, can be configured in tests
+
+# --- Tests for identify_instrument_variable --- #
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_identify_instrument_variable_success(mock_call_json, mock_llm):
+ """Test successful identification of instruments."""
+ mock_call_json.return_value = {"potential_instruments": ["Z1", "Z2"]}
+ df_cols = ["Y", "T", "Z1", "Z2", "W"]
+ query = "What is the effect of T on Y using Z1 and Z2 as instruments?"
+
+ result = identify_instrument_variable(df_cols, query, llm=mock_llm)
+
+ assert result == ["Z1", "Z2"]
+ mock_call_json.assert_called_once()
+ # TODO: Optionally add assertion on the prompt passed to mock_call_json
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_identify_instrument_variable_llm_fail(mock_call_json, mock_llm):
+ """Test when LLM call fails or returns bad format."""
+ mock_call_json.return_value = None # Simulate failure
+ df_cols = ["Y", "T", "Z1", "Z2", "W"]
+ query = "What is the effect of T on Y using Z1 and Z2 as instruments?"
+
+ result = identify_instrument_variable(df_cols, query, llm=mock_llm)
+
+ assert result == [] # Expect empty list on failure
+ mock_call_json.assert_called_once()
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_identify_instrument_variable_no_llm(mock_call_json):
+ """Test behavior when no LLM is provided."""
+ df_cols = ["Y", "T", "Z1", "Z2", "W"]
+ query = "What is the effect of T on Y using Z1 and Z2 as instruments?"
+
+ result = identify_instrument_variable(df_cols, query, llm=None)
+
+ assert result == []
+ mock_call_json.assert_not_called() # LLM helper should not be called
+
+# --- Tests for validate_instrument_assumptions_qualitative --- #
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_validate_assumptions_success(mock_call_json, mock_llm):
+ """Test successful qualitative validation."""
+ mock_response = {"exclusion_assessment": "Plausible", "exogeneity_assessment": "Likely holds"}
+ mock_call_json.return_value = mock_response
+
+ result = validate_instrument_assumptions_qualitative(
+ treatment='T', outcome='Y', instrument=['Z1'], covariates=['W'], query="Test query", llm=mock_llm
+ )
+
+ assert result == mock_response
+ mock_call_json.assert_called_once()
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_validate_assumptions_llm_fail(mock_call_json, mock_llm):
+ """Test qualitative validation when LLM fails."""
+ mock_call_json.return_value = None # Simulate failure
+
+ result = validate_instrument_assumptions_qualitative(
+ treatment='T', outcome='Y', instrument=['Z1'], covariates=['W'], query="Test query", llm=mock_llm
+ )
+
+ assert result == {"exclusion_assessment": "LLM Check Failed", "exogeneity_assessment": "LLM Check Failed"}
+ mock_call_json.assert_called_once()
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_validate_assumptions_no_llm(mock_call_json):
+ """Test qualitative validation when no LLM is provided."""
+ result = validate_instrument_assumptions_qualitative(
+ treatment='T', outcome='Y', instrument=['Z1'], covariates=['W'], query="Test query", llm=None
+ )
+
+ assert result == {"exclusion_assessment": "LLM Not Provided", "exogeneity_assessment": "LLM Not Provided"}
+ mock_call_json.assert_not_called()
+
+# --- Tests for interpret_iv_results --- #
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_interpret_results_success(mock_call_json, mock_llm):
+ """Test successful interpretation generation."""
+ mock_interpretation_text = "The effect is positive and significant."
+ # Simulate the JSON helper returning a dict containing the text
+ mock_call_json.return_value = {"interpretation": mock_interpretation_text}
+ sample_results = {'effect_estimate': 2.5, 'p_value': 0.01, 'confidence_interval': [1.0, 4.0], 'treatment_variable': 'T', 'outcome_variable': 'Y', 'method_used': 'dowhy'}
+ sample_diagnostics = {'first_stage_f_statistic': 50.0, 'weak_instrument_test_status': 'Strong', 'overid_test_applicable': False}
+
+ result = interpret_iv_results(sample_results, sample_diagnostics, llm=mock_llm)
+
+ assert result == mock_interpretation_text
+ mock_call_json.assert_called_once()
+ # TODO: Optionally add assertion on the prompt passed to mock_call_json
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_interpret_results_llm_fail(mock_call_json, mock_llm):
+ """Test interpretation generation when LLM fails."""
+ mock_call_json.return_value = None # Simulate failure
+ sample_results = {'effect_estimate': 2.5}
+ sample_diagnostics = {}
+
+ result = interpret_iv_results(sample_results, sample_diagnostics, llm=mock_llm)
+
+ assert "LLM interpretation could not be generated" in result
+ mock_call_json.assert_called_once()
+
+@patch(f"{LLM_ASSIST_PATH}.call_llm_with_json_output")
+def test_interpret_results_no_llm(mock_call_json):
+ """Test interpretation generation when no LLM is provided."""
+ sample_results = {'effect_estimate': 2.5}
+ sample_diagnostics = {}
+
+ result = interpret_iv_results(sample_results, sample_diagnostics, llm=None)
+
+ assert "LLM was not available" in result
+ mock_call_json.assert_not_called()
\ No newline at end of file
diff --git a/tests/auto_causal/methods/linear_regression/__init__.py b/tests/auto_causal/methods/linear_regression/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/auto_causal/methods/linear_regression/test_linear_regression_estimator.py b/tests/auto_causal/methods/linear_regression/test_linear_regression_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fbbc76ad99a9878655f4ea49b5c137e650fc6cd
--- /dev/null
+++ b/tests/auto_causal/methods/linear_regression/test_linear_regression_estimator.py
@@ -0,0 +1,122 @@
+import unittest
+import pandas as pd
+import numpy as np
+from auto_causal.methods.linear_regression.estimator import estimate_effect
+
+class TestLinearRegressionEstimator(unittest.TestCase):
+
+ def _generate_data(
+ self,
+ n=100,
+ treatment_type='binary_numeric', # binary_numeric, binary_categorical, multi_categorical, continuous
+ reference_level=None,
+ seed=42
+ ):
+ np.random.seed(seed)
+ data = pd.DataFrame()
+ data['X1'] = np.random.normal(0, 1, n)
+ data['X2'] = np.random.rand(n) * 10
+
+ if treatment_type == 'binary_numeric':
+ data['T'] = np.random.binomial(1, 0.5, n)
+ # Y = 10 + 5*T + 3*X1 - 2*X2 + err
+ data['Y'] = 10 + 5 * data['T'] + 3 * data['X1'] - 2 * data['X2'] + np.random.normal(0, 2, n)
+ elif treatment_type == 'binary_categorical':
+ data['T_cat'] = np.random.choice(['Control', 'Treated'], size=n, p=[0.5, 0.5])
+ treatment_map = {'Control': 0, 'Treated': 1}
+ data['Y'] = 10 + 5 * data['T_cat'].map(treatment_map) + 3 * data['X1'] - 2 * data['X2'] + np.random.normal(0, 2, n)
+ elif treatment_type == 'multi_categorical':
+ # Levels: A, B, C. Let C be reference if specified, otherwise Patsy picks one.
+ levels = ['A', 'B', 'C']
+ data['T_multi'] = np.random.choice(levels, size=n, p=[0.3, 0.3, 0.4])
+ # Y = 10 + (5 if T=A else 0) + (-3 if T=B else 0) + 3*X1 - 2*X2 + err (effects relative to C)
+ effect_A = 5
+ effect_B = -3
+ data['Y'] = 10 + \
+ data['T_multi'].apply(lambda x: effect_A if x == 'A' else (effect_B if x == 'B' else 0)) + \
+ 3 * data['X1'] - 2 * data['X2'] + np.random.normal(0, 2, n)
+ elif treatment_type == 'continuous':
+ data['T_cont'] = np.random.normal(5, 2, n)
+ data['Y'] = 10 + 2 * data['T_cont'] + 3 * data['X1'] - 2 * data['X2'] + np.random.normal(0, 2, n)
+
+ return data
+
+ def test_binary_numeric_treatment(self):
+ df = self._generate_data(treatment_type='binary_numeric')
+ results = estimate_effect(df, treatment='T', outcome='Y', covariates=['X1', 'X2'])
+ self.assertIn('effect_estimate', results)
+ self.assertIsNotNone(results['effect_estimate'])
+ self.assertTrue('T' in results['formula'])
+ self.assertFalse('C(T' in results['formula'])
+ self.assertIsNone(results.get('estimated_effects_by_level'))
+ self.assertAlmostEqual(results['effect_estimate'], 5, delta=1.0) # Check if close to true effect
+
+ def test_binary_categorical_treatment(self):
+ df = self._generate_data(treatment_type='binary_categorical')
+ # preprocess_data will likely convert T_cat to 0/1 based on first value as reference if not specified
+ # The estimator then sees it as numeric 0/1 unless it stays category and C() is used.
+ # Current logic in estimator for C(T) is based on dtype and nunique.
+ results = estimate_effect(df, treatment='T_cat', outcome='Y', covariates=['X1', 'X2'])
+ self.assertIn('effect_estimate', results)
+ self.assertIsNotNone(results['effect_estimate'])
+ # Expect C(T_cat) if T_cat is object/category dtype and has 2 unique values.
+ self.assertIn(f"C({df['T_cat'].name})", results['formula'])
+ self.assertIsNone(results.get('estimated_effects_by_level'))
+ self.assertAlmostEqual(results['effect_estimate'], 5, delta=1.5) # Wider delta due to encoding/ref choice
+
+ def test_multi_categorical_treatment_with_reference(self):
+ reference = 'C'
+ df = self._generate_data(treatment_type='multi_categorical', reference_level=reference)
+ results = estimate_effect(
+ df,
+ treatment='T_multi',
+ outcome='Y',
+ covariates=['X1', 'X2'],
+ treatment_reference_level=reference,
+ column_mappings={ # Simulate that preprocess_data did not alter T_multi's type
+ 'T_multi': {'original_dtype': 'object', 'transformed_as': 'original'}
+ }
+ )
+ self.assertIn('estimated_effects_by_level', results)
+ self.assertIsNotNone(results['estimated_effects_by_level'])
+ self.assertEqual(results['reference_level_used'], reference)
+ self.assertTrue(f"C(T_multi, Treatment(reference='{reference}'))" in results['formula'])
+ self.assertIn('A', results['estimated_effects_by_level'])
+ self.assertIn('B', results['estimated_effects_by_level'])
+ self.assertNotIn('C', results['estimated_effects_by_level']) # Reference level should not have its own effect listed
+ self.assertAlmostEqual(results['estimated_effects_by_level']['A']['estimate'], 5, delta=1.5)
+ self.assertAlmostEqual(results['estimated_effects_by_level']['B']['estimate'], -3, delta=1.5)
+ self.assertIsNone(results['effect_estimate']) # Main effect is None for multi-level
+
+ def test_multi_categorical_treatment_no_reference(self):
+ df = self._generate_data(treatment_type='multi_categorical')
+ results = estimate_effect(
+ df,
+ treatment='T_multi',
+ outcome='Y',
+ covariates=['X1', 'X2'],
+ column_mappings={ # Simulate that preprocess_data did not alter T_multi's type
+ 'T_multi': {'original_dtype': 'object', 'transformed_as': 'original'}
+ }
+ )
+ # Without explicit reference, Patsy picks one (usually first alphabetically: A)
+ # The output structure for 'estimated_effects_by_level' would have effects relative to this implicit ref.
+ # The current linear_regression_estimator.py when no ref is given AND it's categorical AND >2 levels
+ # might not populate estimated_effects_by_level clearly. It falls to single effect logic.
+ # This test highlights that area. For now, we check if formula uses C() and some effect is found.
+ self.assertTrue(f"C(T_multi)" in results['formula'] or "T_multi[T." in results['formula'])
+ self.assertIsNotNone(results['effect_estimate']) # It will pick one of the level effects
+ # A more detailed check would be needed for which specific levels are present vs implicit reference.
+
+ def test_continuous_treatment(self):
+ df = self._generate_data(treatment_type='continuous')
+ results = estimate_effect(df, treatment='T_cont', outcome='Y', covariates=['X1', 'X2'])
+ self.assertIn('effect_estimate', results)
+ self.assertIsNotNone(results['effect_estimate'])
+ self.assertTrue('T_cont' in results['formula'])
+ self.assertFalse('C(T_cont' in results['formula'])
+ self.assertIsNone(results.get('estimated_effects_by_level'))
+ self.assertAlmostEqual(results['effect_estimate'], 2, delta=1.0)
+
+if __name__ == '__main__':
+ unittest.main(argv=['first-arg-is-ignored'], exit=False)
\ No newline at end of file
diff --git a/tests/auto_causal/methods/linear_regression/test_lr_diagnostics.py b/tests/auto_causal/methods/linear_regression/test_lr_diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfdee2fde6eec21449221b88c46a752deb5080d6
--- /dev/null
+++ b/tests/auto_causal/methods/linear_regression/test_lr_diagnostics.py
@@ -0,0 +1,88 @@
+import pytest
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from auto_causal.methods.linear_regression.diagnostics import run_lr_diagnostics
+
+# Reuse the sample data fixture from estimator tests
+@pytest.fixture
+def sample_data():
+ """Generates simple synthetic data for testing LR."""
+ np.random.seed(42)
+ n_samples = 100
+ treatment_effect = 2.0
+ X1 = np.random.normal(0, 1, n_samples)
+ X2 = np.random.normal(5, 2, n_samples)
+ treatment = np.random.binomial(1, 0.5, n_samples)
+ error = np.random.normal(0, 1, n_samples)
+ outcome = 1.0 + treatment_effect * treatment + 0.5 * X1 - 1.5 * X2 + error
+
+ df = pd.DataFrame({
+ 'outcome': outcome,
+ 'treatment': treatment,
+ 'covariate1': X1,
+ 'covariate2': X2
+ })
+ return df
+
+def test_run_lr_diagnostics_implementation(sample_data):
+ """Tests the implemented diagnostics function with real results."""
+ # Run a regression to get a real results object
+ df_analysis = sample_data.dropna()
+ covariates = ['covariate1', 'covariate2']
+ X = df_analysis[['treatment'] + covariates]
+ X = sm.add_constant(X)
+ y = df_analysis['outcome']
+ model = sm.OLS(y, X)
+ results = model.fit()
+
+ # Run diagnostics
+ diagnostics = run_lr_diagnostics(results, X)
+
+ assert isinstance(diagnostics, dict)
+ assert diagnostics["status"] == "Success"
+ assert "details" in diagnostics
+ details = diagnostics["details"]
+
+ # Check for key diagnostic metrics
+ assert "r_squared" in details
+ assert "adj_r_squared" in details
+ assert "f_statistic" in details
+ assert "f_p_value" in details
+ assert "n_observations" in details
+ assert "degrees_of_freedom_resid" in details
+
+ # Check normality test results
+ assert "residuals_normality_jb_stat" in details
+ assert "residuals_normality_jb_p_value" in details
+ assert "residuals_skewness" in details
+ assert "residuals_kurtosis" in details
+ assert "residuals_normality_status" in details
+ assert isinstance(details["residuals_normality_status"], str)
+
+ # Check homoscedasticity test results
+ assert "homoscedasticity_bp_lm_stat" in details
+ assert "homoscedasticity_bp_lm_p_value" in details
+ assert "homoscedasticity_bp_f_stat" in details
+ assert "homoscedasticity_bp_f_p_value" in details
+ assert "homoscedasticity_status" in details
+ assert isinstance(details["homoscedasticity_status"], str)
+
+ # Check placeholder statuses
+ assert "linearity_check" in details
+ assert "multicollinearity_check" in details
+ assert details["linearity_check"] == "Requires visual inspection (e.g., residual vs fitted plot)"
+ assert details["multicollinearity_check"] == "Not Implemented (Requires VIF)"
+
+ # Check types (basic)
+ assert isinstance(details["r_squared"], float)
+ assert isinstance(details["f_p_value"], float)
+ assert isinstance(details["n_observations"], int)
+
+def test_run_lr_diagnostics_failure():
+ """Test diagnostic failure mode (e.g., passing wrong object)."""
+ # Pass a non-results object
+ diagnostics = run_lr_diagnostics("not a results object", pd.DataFrame({'const': [1]}))
+ assert diagnostics["status"] == "Failed"
+ assert "error" in diagnostics
+
diff --git a/tests/auto_causal/methods/linear_regression/test_lr_estimator.py b/tests/auto_causal/methods/linear_regression/test_lr_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f3cf625972c9143f6f8c484352a90a84aeefee
--- /dev/null
+++ b/tests/auto_causal/methods/linear_regression/test_lr_estimator.py
@@ -0,0 +1,115 @@
+import pytest
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from statsmodels import iolib
+from auto_causal.methods.linear_regression.estimator import estimate_effect
+
+# --- Fixtures ---
+
+@pytest.fixture
+def sample_data():
+ """Generates simple synthetic data for testing LR."""
+ np.random.seed(42)
+ n_samples = 100
+ treatment_effect = 2.0
+ X1 = np.random.normal(0, 1, n_samples)
+ X2 = np.random.normal(5, 2, n_samples)
+ treatment = np.random.binomial(1, 0.5, n_samples)
+ error = np.random.normal(0, 1, n_samples)
+ outcome = 1.0 + treatment_effect * treatment + 0.5 * X1 - 1.5 * X2 + error
+
+ df = pd.DataFrame({
+ 'outcome': outcome,
+ 'treatment': treatment,
+ 'covariate1': X1,
+ 'covariate2': X2,
+ 'other_col': np.random.rand(n_samples) # Unused column
+ })
+ return df
+
+# --- Test Cases ---
+
+def test_estimate_effect_no_covariates(sample_data):
+ """Test estimating effect without covariates."""
+ results = estimate_effect(sample_data, 'treatment', 'outcome')
+
+ assert 'effect_estimate' in results
+ assert 'p_value' in results
+ assert 'confidence_interval' in results
+ assert 'standard_error' in results
+ assert 'formula' in results
+ assert 'model_summary' in results
+ assert 'diagnostics' in results # Placeholder check
+ assert 'interpretation' in results # Placeholder check
+ assert 'method_used' in results
+
+ # Check if effect estimate is reasonably close (simple check)
+ assert abs(results['effect_estimate'] - 2.0) < 0.5
+ assert 'treatment' in results['formula']
+ assert 'covariate1' not in results['formula']
+ assert results['method_used'] == 'Linear Regression (OLS)'
+
+def test_estimate_effect_with_covariates(sample_data):
+ """Test estimating effect with covariates."""
+ covariates = ['covariate1', 'covariate2']
+ results = estimate_effect(sample_data, 'treatment', 'outcome', covariates)
+
+ assert 'effect_estimate' in results
+ assert 'p_value' in results
+ assert 'confidence_interval' in results
+ assert 'standard_error' in results
+
+ # Check if effect estimate is reasonably close to the true effect (2.0)
+ assert abs(results['effect_estimate'] - 2.0) < 0.5
+ assert 'treatment' in results['formula']
+ assert 'covariate1' in results['formula']
+ assert 'covariate2' in results['formula']
+ assert results['method_used'] == 'Linear Regression (OLS)'
+ # Check summary type (basic check)
+ assert isinstance(results['model_summary'], sm.iolib.summary.Summary)
+
+def test_estimate_effect_missing_treatment(sample_data):
+ """Test error handling for missing treatment column."""
+ with pytest.raises(ValueError, match="Missing required columns:.*missing_treat.*"):
+ estimate_effect(sample_data, 'missing_treat', 'outcome')
+
+def test_estimate_effect_missing_outcome(sample_data):
+ """Test error handling for missing outcome column."""
+ with pytest.raises(ValueError, match="Missing required columns:.*missing_outcome.*"):
+ estimate_effect(sample_data, 'treatment', 'missing_outcome')
+
+def test_estimate_effect_missing_covariate(sample_data):
+ """Test error handling for missing covariate column."""
+ with pytest.raises(ValueError, match="Missing required columns:.*missing_cov.*"):
+ estimate_effect(sample_data, 'treatment', 'outcome', ['covariate1', 'missing_cov'])
+
+def test_estimate_effect_nan_data():
+ """Test handling of data with NaNs resulting in empty analysis set."""
+ df_nan = pd.DataFrame({
+ 'outcome': [1, np.nan, 3],
+ 'treatment': [0, np.nan, 1],
+ 'covariate1': [np.nan, 6, 7] # Add NaN here to ensure row 0 is dropped
+ })
+ # With this setup, row 0 has NaN in covariate1
+ # Row 1 has NaN in outcome and treatment
+ # Only row 2 is complete, but dropna() needs *all* specified cols
+ # to be non-NA. Let's ensure dropna removes all rows.
+ df_nan_all_removed = pd.DataFrame({
+ 'outcome': [1, np.nan, 3],
+ 'treatment': [0, 1, np.nan],
+ 'covariate1': [np.nan, 6, 7]
+ })
+ with pytest.raises(ValueError, match="No data remaining after dropping NaNs"):
+ estimate_effect(df_nan_all_removed, 'treatment', 'outcome', ['covariate1'])
+
+def test_formula_generation(sample_data):
+ """Test the formula string generation."""
+ # No covariates
+ results_no_cov = estimate_effect(sample_data, 'treatment', 'outcome')
+ assert results_no_cov['formula'] == "outcome ~ treatment + const"
+
+ # With covariates
+ results_with_cov = estimate_effect(sample_data, 'treatment', 'outcome', ['covariate1', 'covariate2'])
+ assert results_with_cov['formula'] == "outcome ~ treatment + covariate1 + covariate2 + const"
+
diff --git a/tests/auto_causal/methods/linear_regression/test_lr_llm_assist.py b/tests/auto_causal/methods/linear_regression/test_lr_llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..c954bdd3da922621dabd1c909073dc7d03ae41cf
--- /dev/null
+++ b/tests/auto_causal/methods/linear_regression/test_lr_llm_assist.py
@@ -0,0 +1,123 @@
+import pytest
+from unittest.mock import MagicMock, patch
+import pandas as pd
+from auto_causal.methods.linear_regression.llm_assist import (
+ suggest_lr_covariates,
+ interpret_lr_results
+)
+
+# Patch target for the helper function where it's used
+LLM_ASSIST_MODULE = "auto_causal.methods.linear_regression.llm_assist"
+
+@pytest.fixture
+def mock_llm():
+ """Fixture for a basic mock LLM object."""
+ return MagicMock()
+
+@pytest.fixture
+def mock_ols_results():
+ """Creates a mock statsmodels OLS results object with necessary attributes."""
+ results = MagicMock()
+ treatment_var = 'treatment'
+ results.params = pd.Series({'const': 1.0, treatment_var: 2.5, 'cov1': 0.5})
+ results.pvalues = pd.Series({'const': 0.5, treatment_var: 0.01, 'cov1': 0.1})
+ # Mock conf_int() to return a DataFrame-like object accessible by .loc
+ conf_int_df = pd.DataFrame([[2.0, 3.0]], index=[treatment_var], columns=[0, 1])
+ results.conf_int.return_value = conf_int_df
+ results.rsquared = 0.75
+ results.rsquared_adj = 0.70
+ return results
+
+@pytest.fixture
+def mock_diagnostics_success():
+ """Creates a mock diagnostics dictionary for successful checks."""
+ return {
+ "status": "Success",
+ "details": {
+ 'residuals_normality_jb_p_value': 0.6,
+ 'homoscedasticity_bp_lm_p_value': 0.5,
+ 'homoscedasticity_status': "Homoscedastic",
+ 'residuals_normality_status': "Normal"
+ }
+ }
+
+def test_suggest_lr_covariates_placeholder(mock_llm):
+ """Test the placeholder covariate suggestion function."""
+ df_cols = ['a', 'b', 't', 'y']
+ treatment = 't'
+ outcome = 'y'
+ query = "What is the effect of t on y?"
+
+ # Test without LLM
+ suggested_no_llm = suggest_lr_covariates(df_cols, treatment, outcome, query, llm=None)
+ assert suggested_no_llm == []
+
+ # Test with LLM (should still return empty list for placeholder)
+ suggested_with_llm = suggest_lr_covariates(df_cols, treatment, outcome, query, llm=mock_llm)
+ assert suggested_with_llm == []
+ # Ensure mock LLM wasn't actually called by the placeholder
+ mock_llm.assert_not_called()
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_interpret_lr_results_implementation(mock_call_json, mock_llm, mock_ols_results, mock_diagnostics_success):
+ """Test the implemented results interpretation function."""
+ treatment_var = 'treatment'
+ mock_interpretation_text = "The treatment had a positive and significant effect."
+ mock_call_json.return_value = {"interpretation": mock_interpretation_text}
+
+ # --- Test with LLM ---
+ interp_with_llm = interpret_lr_results(
+ mock_ols_results,
+ mock_diagnostics_success,
+ treatment_var,
+ llm=mock_llm
+ )
+
+ assert interp_with_llm == mock_interpretation_text
+ mock_call_json.assert_called_once()
+ # Basic check on the prompt structure passed to the helper
+ call_args, call_kwargs = mock_call_json.call_args
+ prompt = call_args[1] # Second argument is the prompt string
+ assert "Linear Regression (OLS) results" in prompt
+ assert "Model Results Summary:" in prompt
+ assert "Model Diagnostics Summary:" in prompt
+ assert treatment_var in prompt
+ assert "Treatment Effect Estimate': '2.500" in prompt # Check formatting
+ assert "Homoscedasticity Status': 'Homoscedastic" in prompt # Check diagnostics inclusion
+ assert "Return ONLY a valid JSON" in prompt
+
+ # --- Test LLM Call Failure ---
+ mock_call_json.reset_mock()
+ mock_call_json.return_value = None # Simulate LLM helper failure
+ interp_fail = interpret_lr_results(mock_ols_results, mock_diagnostics_success, treatment_var, llm=mock_llm)
+ assert "LLM interpretation not available" in interp_fail
+ mock_call_json.assert_called_once() # Ensure it was still called
+
+ # --- Test without LLM ---
+ mock_call_json.reset_mock()
+ interp_no_llm = interpret_lr_results(mock_ols_results, mock_diagnostics_success, treatment_var, llm=None)
+ assert isinstance(interp_no_llm, str)
+ assert "LLM interpretation not available" in interp_no_llm
+ mock_call_json.assert_not_called() # Ensure helper wasn't called
+
+# Test edge case where treatment var isn't in results (though estimator should prevent this)
+def test_interpret_lr_results_treatment_not_found(mock_llm):
+ """Test interpretation when treatment var is unexpectedly missing from results."""
+ mock_res = MagicMock()
+ mock_res.params = pd.Series({'const': 1.0})
+ mock_res.pvalues = pd.Series({'const': 0.5})
+ mock_res.rsquared = 0.1
+ mock_res.rsquared_adj = 0.05
+ # Mock conf_int to avoid error even if treatment isn't there
+ mock_res.conf_int.return_value = pd.DataFrame([[0.0, 2.0]], index=['const'], columns=[0, 1])
+
+ mock_diag = {"status": "Success", "details": {}}
+
+ # With LLM (should default gracefully)
+ interp = interpret_lr_results(mock_res, mock_diag, "missing_treatment", llm=mock_llm)
+ assert "LLM interpretation not available" in interp # Should hit default as LLM call won't work well
+
+ # Without LLM
+ interp_no_llm = interpret_lr_results(mock_res, mock_diag, "missing_treatment", llm=None)
+ assert "LLM interpretation not available" in interp_no_llm
+
diff --git a/tests/auto_causal/methods/propensity_score/__init__.py b/tests/auto_causal/methods/propensity_score/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..927b75568a83384494e21c03107fef6436d41bf4
--- /dev/null
+++ b/tests/auto_causal/methods/propensity_score/__init__.py
@@ -0,0 +1 @@
+# Tests for propensity score methods
\ No newline at end of file
diff --git a/tests/auto_causal/methods/propensity_score/test_matching.py b/tests/auto_causal/methods/propensity_score/test_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f4c4783f9ac46c2a48f0e24844735aa69b791f
--- /dev/null
+++ b/tests/auto_causal/methods/propensity_score/test_matching.py
@@ -0,0 +1,67 @@
+import unittest
+import pandas as pd
+import numpy as np
+from unittest.mock import patch, MagicMock
+
+# Import the function to test
+from auto_causal.methods.propensity_score.matching import estimate_effect
+
+class TestPropensityScoreMatching(unittest.TestCase):
+
+ def setUp(self):
+ '''Set up a dummy DataFrame for testing.'''
+ self.df = pd.DataFrame({
+ 'treatment': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+ 'outcome': [10, 12, 11, 13, 9, 14, 10, 15, 11, 16],
+ 'covariate1': [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
+ 'covariate2': [5.5, 6.5, 5.8, 6.2, 5.1, 6.8, 5.3, 6.1, 5.9, 6.3]
+ })
+ self.treatment = 'treatment'
+ self.outcome = 'outcome'
+ self.covariates = ['covariate1', 'covariate2']
+
+ @patch('auto_causal.methods.propensity_score.matching.get_llm_parameters')
+ @patch('auto_causal.methods.propensity_score.matching.determine_optimal_caliper')
+ @patch('auto_causal.methods.propensity_score.matching.select_propensity_model')
+ @patch('auto_causal.methods.propensity_score.matching.estimate_propensity_scores')
+ @patch('auto_causal.methods.propensity_score.matching.assess_balance')
+ def test_estimate_effect_structure_and_types(self, mock_assess_balance, mock_estimate_ps,
+ mock_select_model, mock_determine_caliper, mock_get_llm_params):
+ '''Test the basic structure and types of the estimate_effect output.'''
+ # Configure mocks
+ mock_get_llm_params.return_value = {"parameters": {"caliper": 0.5}, "validation": {}}
+ mock_determine_caliper.return_value = 0.5 # Ensure caliper is set if LLM misses
+ mock_select_model.return_value = 'logistic'
+ # Simulate propensity scores (needs same length as df)
+ mock_estimate_ps.return_value = np.random.uniform(0.1, 0.9, size=len(self.df))
+ # Simulate diagnostics output
+ mock_assess_balance.return_value = {
+ "balance_metrics": {'covariate1': 0.05, 'covariate2': 0.08},
+ "balance_achieved": True,
+ "problematic_covariates": [],
+ "plots": {}
+ }
+
+ # Call the function
+ result = estimate_effect(self.df, self.treatment, self.outcome, self.covariates, query="Test query")
+
+ # Assertions
+ self.assertIsInstance(result, dict)
+ expected_keys = ["effect_estimate", "effect_se", "confidence_interval",
+ "diagnostics", "method_details", "parameters"]
+ for key in expected_keys:
+ self.assertIn(key, result, f"Key '{key}' missing from result")
+
+ self.assertEqual(result["method_details"], "PS.Matching")
+ self.assertIsInstance(result["effect_estimate"], float)
+ self.assertIsInstance(result["effect_se"], float)
+ self.assertIsInstance(result["confidence_interval"], list)
+ self.assertEqual(len(result["confidence_interval"]), 2)
+ self.assertIsInstance(result["diagnostics"], dict)
+ self.assertIsInstance(result["parameters"], dict)
+ self.assertIn("caliper", result["parameters"])
+ self.assertIn("propensity_model", result["parameters"])
+ self.assertIn("balance_achieved", result["diagnostics"])
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/methods/propensity_score/test_ps_matching.py b/tests/auto_causal/methods/propensity_score/test_ps_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc5c89bc7a418d8fe6924406781e9de1b2e45932
--- /dev/null
+++ b/tests/auto_causal/methods/propensity_score/test_ps_matching.py
@@ -0,0 +1,121 @@
+import pytest
+import pandas as pd
+import numpy as np
+from unittest.mock import patch, MagicMock
+
+# Import the function to test
+from auto_causal.methods.propensity_score import matching as ps_matching
+
+# Helper function to generate synthetic data for PSM
+def generate_synthetic_psm_data(n_samples=1000, treatment_effect=5.0, seed=42):
+ """Generates synthetic data suitable for PSM testing.
+
+ Features:
+ - Binary treatment based on covariates (logistic function).
+ - Confounding: Covariates affect both treatment and outcome.
+ - Known true treatment effect.
+ """
+ np.random.seed(seed)
+
+ # Covariates
+ X1 = np.random.normal(0, 1, n_samples)
+ X2 = np.random.binomial(1, 0.6, n_samples)
+
+ # Propensity score (true probability of treatment)
+ logit_p = -0.5 + 1.5 * X1 - 0.8 * X2 + np.random.normal(0, 0.5, n_samples)
+ p_treatment = 1 / (1 + np.exp(-logit_p))
+
+ # Treatment assignment
+ treatment = np.random.binomial(1, p_treatment, n_samples)
+
+ # Outcome model (linear)
+ # Base outcome depends on covariates (confounding)
+ base_outcome = 10 + 2.0 * X1 + 3.0 * X2 + np.random.normal(0, 2, n_samples)
+ # Treatment adds a fixed effect
+ outcome = base_outcome + treatment * treatment_effect
+
+ data = pd.DataFrame({
+ 'X1': X1,
+ 'X2': X2,
+ 'treatment': treatment,
+ 'outcome': outcome
+ })
+ return data
+
+# Test Class
+class TestPropensityScoreMatching:
+
+ def test_estimate_effect_synthetic_data(self):
+ """Test the end-to-end estimate_effect with synthetic data."""
+ df = generate_synthetic_psm_data(n_samples=2000, treatment_effect=5.0, seed=123)
+ covariates = ['X1', 'X2']
+ treatment = 'treatment'
+ outcome = 'outcome'
+
+ # Run the matching estimator
+ # Use default parameters for now (caliper=0.2, n_neighbors=1, logistic model)
+ results = ps_matching.estimate_effect(df, treatment, outcome, covariates)
+
+ assert 'effect_estimate' in results
+ assert 'effect_se' in results
+ assert 'confidence_interval' in results
+ assert 'diagnostics' in results
+ assert 'parameters' in results
+
+ # Check if the estimated effect is reasonably close to the true effect
+ # Allow for some tolerance due to estimation noise
+ true_effect = 5.0
+ estimated_effect = results['effect_estimate']
+ assert abs(estimated_effect - true_effect) < 1.0 # Adjust tolerance as needed
+
+ # Check if standard error and CI are plausible
+ assert results['effect_se'] > 0
+ assert results['confidence_interval'][0] < estimated_effect
+ assert results['confidence_interval'][1] > estimated_effect
+
+ # Check diagnostics structure (based on current placeholder implementation)
+ assert 'balance_metrics' in results['diagnostics']
+ assert 'balance_achieved' in results['diagnostics']
+ assert 'plots' in results['diagnostics']
+ assert 'percent_treated_matched' in results['diagnostics']
+
+ @patch('auto_causal.methods.propensity_score.matching.get_llm_parameters')
+ @patch('auto_causal.methods.propensity_score.matching.determine_optimal_caliper')
+ @patch('auto_causal.methods.propensity_score.matching.select_propensity_model')
+ def test_llm_parameter_usage(self, mock_select_model, mock_determine_caliper, mock_get_llm_params):
+ """Test that LLM helper functions are called and their results are potentially used."""
+ df = generate_synthetic_psm_data(n_samples=100, seed=456) # Smaller sample for this test
+ covariates = ['X1', 'X2']
+ treatment = 'treatment'
+ outcome = 'outcome'
+ query = "What is the effect?"
+
+ # Configure mocks
+ # Simulate LLM providing some parameters
+ mock_get_llm_params.return_value = {
+ "parameters": {"caliper": 0.1, "n_neighbors": 2},
+ "validation": {}
+ }
+ # Ensure other helpers return defaults if LLM doesn't provide everything
+ mock_determine_caliper.return_value = 0.2 # Fallback if LLM doesn't provide caliper
+ mock_select_model.return_value = 'logistic' # Fallback model
+
+ # Call the function with a query to trigger LLM pathway
+ results = ps_matching.estimate_effect(df, treatment, outcome, covariates, query=query)
+
+ # Assertions
+ mock_get_llm_params.assert_called_once_with(df, query, "PS.Matching")
+ # determine_optimal_caliper should NOT be called if LLM provides the caliper
+ mock_determine_caliper.assert_not_called()
+ # select_propensity_model should NOT be called if get_llm_parameters provides it (it doesn't in this mock setup)
+ # --> actually, it WILL be called because the mock get_llm_parameters doesn't provide 'propensity_model_type'
+ mock_select_model.assert_called_once()
+
+ # Check if the parameters used in the results reflect the LLM suggestions
+ assert results['parameters']['caliper'] == 0.1
+ assert results['parameters']['n_neighbors'] == 2
+ assert results['parameters']['propensity_model'] == 'logistic' # Came from fallback
+
+ # TODO: Add tests for diagnostic outputs (checking balance improvement)
+ # TODO: Add tests for edge cases (e.g., no matches found)
+ # TODO: Add tests for different parameter inputs (e.g., specifying caliper directly)
\ No newline at end of file
diff --git a/tests/auto_causal/methods/propensity_score/test_ps_weighting.py b/tests/auto_causal/methods/propensity_score/test_ps_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42613bea82ff5c9c2a5b53c7da8dc31fb655e51
--- /dev/null
+++ b/tests/auto_causal/methods/propensity_score/test_ps_weighting.py
@@ -0,0 +1,103 @@
+import pytest
+import pandas as pd
+import numpy as np
+from unittest.mock import patch, MagicMock
+
+# Import the function to test
+from auto_causal.methods.propensity_score import weighting as ps_weighting
+
+# Reuse the synthetic data generation function from matching tests
+# (or redefine if weighting requires different data characteristics)
+from .test_ps_matching import generate_synthetic_psm_data
+
+# Test Class
+class TestPropensityScoreWeighting:
+
+ def test_estimate_effect_synthetic_data_ate(self):
+ """Test the end-to-end estimate_effect with synthetic data for ATE."""
+ df = generate_synthetic_psm_data(n_samples=2000, treatment_effect=5.0, seed=123)
+ covariates = ['X1', 'X2']
+ treatment = 'treatment'
+ outcome = 'outcome'
+
+ # Run the weighting estimator for ATE
+ results = ps_weighting.estimate_effect(df, treatment, outcome, covariates, weight_type='ATE')
+
+ assert 'effect_estimate' in results
+ assert 'effect_se' in results
+ assert 'confidence_interval' in results
+ assert 'diagnostics' in results
+ assert 'parameters' in results
+
+ # Check if the estimated effect is reasonably close to the true effect
+ # ATE might be slightly different from ATT depending on effect heterogeneity (none here)
+ true_effect = 5.0
+ estimated_effect = results['effect_estimate']
+ assert abs(estimated_effect - true_effect) < 1.0 # Adjust tolerance
+
+ # Check if standard error and CI are plausible
+ assert results['effect_se'] > 0
+ assert results['confidence_interval'][0] < estimated_effect
+ assert results['confidence_interval'][1] > estimated_effect
+
+ # Check diagnostics structure (based on current placeholder implementation)
+ assert 'min_weight' in results['diagnostics']
+ assert 'max_weight' in results['diagnostics']
+ assert 'effective_sample_size' in results['diagnostics']
+ assert 'propensity_score_model' in results['diagnostics']
+
+ def test_estimate_effect_synthetic_data_att(self):
+ """Test the end-to-end estimate_effect with synthetic data for ATT."""
+ df = generate_synthetic_psm_data(n_samples=2000, treatment_effect=5.0, seed=456)
+ covariates = ['X1', 'X2']
+ treatment = 'treatment'
+ outcome = 'outcome'
+
+ # Run the weighting estimator for ATT
+ results = ps_weighting.estimate_effect(df, treatment, outcome, covariates, weight_type='ATT')
+
+ # Check if the estimated effect is reasonably close to the true effect
+ true_effect = 5.0
+ estimated_effect = results['effect_estimate']
+ assert abs(estimated_effect - true_effect) < 1.0 # Adjust tolerance
+
+ @patch('auto_causal.methods.propensity_score.weighting.get_llm_parameters')
+ @patch('auto_causal.methods.propensity_score.weighting.determine_optimal_weight_type')
+ @patch('auto_causal.methods.propensity_score.weighting.determine_optimal_trim_threshold')
+ @patch('auto_causal.methods.propensity_score.weighting.select_propensity_model')
+ def test_llm_parameter_usage(self, mock_select_model, mock_determine_trim, mock_determine_weight, mock_get_llm_params):
+ """Test that LLM helper functions are called and their results are potentially used."""
+ df = generate_synthetic_psm_data(n_samples=100, seed=789) # Smaller sample
+ covariates = ['X1', 'X2']
+ treatment = 'treatment'
+ outcome = 'outcome'
+ query = "What is the ATT?"
+
+ # Configure mocks
+ mock_get_llm_params.return_value = {
+ "parameters": {"weight_type": "ATT", "trim_threshold": 0.01},
+ "validation": {}
+ }
+ mock_determine_weight.return_value = 'ATE' # Fallback
+ mock_determine_trim.return_value = None # Fallback (no trim)
+ mock_select_model.return_value = 'logistic' # Fallback
+
+ # Call the function
+ results = ps_weighting.estimate_effect(df, treatment, outcome, covariates, query=query)
+
+ # Assertions
+ mock_get_llm_params.assert_called_once_with(df, query, "PS.Weighting")
+ # Helpers should not be called if LLM provided the value
+ mock_determine_weight.assert_not_called()
+ mock_determine_trim.assert_not_called()
+ # Model selection will still be called as it wasn't in mock LLM params
+ mock_select_model.assert_called_once()
+
+ # Check parameters reflect LLM suggestions
+ assert results['parameters']['weight_type'] == 'ATT'
+ assert results['parameters']['trim_threshold'] == 0.01
+ assert results['parameters']['propensity_model'] == 'logistic'
+
+ # TODO: Add tests for weight trimming effects
+ # TODO: Add tests for diagnostic outputs (e.g., checking weight distribution stats)
+ # TODO: Add tests for edge cases (e.g., extreme weights)
\ No newline at end of file
diff --git a/tests/auto_causal/methods/propensity_score/test_weighting.py b/tests/auto_causal/methods/propensity_score/test_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..60e3fcb2b5a68939dbd86d1577932e94194e214d
--- /dev/null
+++ b/tests/auto_causal/methods/propensity_score/test_weighting.py
@@ -0,0 +1,69 @@
+import unittest
+import pandas as pd
+import numpy as np
+from unittest.mock import patch, MagicMock
+
+# Import the function to test
+from auto_causal.methods.propensity_score.weighting import estimate_effect
+
+class TestPropensityScoreWeighting(unittest.TestCase):
+
+ def setUp(self):
+ '''Set up a dummy DataFrame for testing.'''
+ self.df = pd.DataFrame({
+ 'treatment': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+ 'outcome': [10, 12, 11, 13, 9, 14, 10, 15, 11, 16],
+ 'covariate1': [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
+ 'covariate2': [5.5, 6.5, 5.8, 6.2, 5.1, 6.8, 5.3, 6.1, 5.9, 6.3]
+ })
+ self.treatment = 'treatment'
+ self.outcome = 'outcome'
+ self.covariates = ['covariate1', 'covariate2']
+
+ @patch('auto_causal.methods.propensity_score.weighting.get_llm_parameters')
+ @patch('auto_causal.methods.propensity_score.weighting.determine_optimal_weight_type')
+ @patch('auto_causal.methods.propensity_score.weighting.determine_optimal_trim_threshold')
+ @patch('auto_causal.methods.propensity_score.weighting.select_propensity_model')
+ @patch('auto_causal.methods.propensity_score.weighting.estimate_propensity_scores')
+ @patch('auto_causal.methods.propensity_score.weighting.assess_weight_distribution')
+ def test_estimate_effect_structure_and_types(self, mock_assess_weights, mock_estimate_ps,
+ mock_select_model, mock_determine_trim,
+ mock_determine_weight, mock_get_llm_params):
+ '''Test the basic structure and types of the estimate_effect output.'''
+ # Configure mocks
+ mock_get_llm_params.return_value = {"parameters": {"weight_type": "ATE", "trim_threshold": 0.0}, "validation": {}}
+ mock_determine_weight.return_value = 'ATE'
+ mock_determine_trim.return_value = 0.0 # No trimming
+ mock_select_model.return_value = 'logistic'
+ # Simulate propensity scores
+ mock_estimate_ps.return_value = np.random.uniform(0.1, 0.9, size=len(self.df))
+ # Simulate diagnostics output
+ mock_assess_weights.return_value = {
+ "min_weight": 0.5, "max_weight": 5.0, "mean_weight": 1.0, "std_dev_weight": 0.5,
+ "effective_sample_size": len(self.df) * 0.8, "potential_issues": False
+ }
+
+ # Call the function
+ result = estimate_effect(self.df, self.treatment, self.outcome, self.covariates, query="Test query")
+
+ # Assertions
+ self.assertIsInstance(result, dict)
+ expected_keys = ["effect_estimate", "effect_se", "confidence_interval",
+ "diagnostics", "method_details", "parameters"]
+ for key in expected_keys:
+ self.assertIn(key, result, f"Key '{key}' missing from result")
+
+ self.assertEqual(result["method_details"], "PS.Weighting")
+ self.assertIsInstance(result["effect_estimate"], float)
+ self.assertIsInstance(result["effect_se"], float)
+ self.assertIsInstance(result["confidence_interval"], list)
+ self.assertEqual(len(result["confidence_interval"]), 2)
+ self.assertIsInstance(result["diagnostics"], dict)
+ self.assertIsInstance(result["parameters"], dict)
+ self.assertIn("weight_type", result["parameters"])
+ self.assertIn("trim_threshold", result["parameters"])
+ self.assertIn("propensity_model", result["parameters"])
+ self.assertIn("effective_sample_size", result["diagnostics"])
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/methods/regression_discontinuity/__init__.py b/tests/auto_causal/methods/regression_discontinuity/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/auto_causal/methods/regression_discontinuity/test_rdd_diagnostics.py b/tests/auto_causal/methods/regression_discontinuity/test_rdd_diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..2afb788774bff6b39b2b4e97bdda96300f0f94a4
--- /dev/null
+++ b/tests/auto_causal/methods/regression_discontinuity/test_rdd_diagnostics.py
@@ -0,0 +1,103 @@
+import pytest
+import pandas as pd
+import numpy as np
+from auto_causal.methods.regression_discontinuity.diagnostics import run_rdd_diagnostics
+
+# --- Fixture for RDD data ---
+@pytest.fixture
+def sample_rdd_data():
+ """Generates synthetic data suitable for RDD testing."""
+ np.random.seed(123)
+ n_samples = 200
+ cutoff = 50.0
+ treatment_effect = 10.0
+
+ running_var = np.random.uniform(cutoff - 20, cutoff + 20, n_samples)
+ treatment = (running_var >= cutoff).astype(int)
+ # Covariate correlated with running variable (potential imbalance)
+ covariate1 = 0.5 * running_var + np.random.normal(0, 5, n_samples)
+ # Covariate uncorrelated (should be balanced)
+ covariate2 = np.random.normal(10, 2, n_samples)
+ error = np.random.normal(0, 5, n_samples)
+ outcome = (10 + 0.8 * running_var +
+ treatment_effect * treatment +
+ 1.2 * treatment * (running_var - cutoff) +
+ 2.0 * covariate1 + 1.0 * covariate2 + error)
+
+ df = pd.DataFrame({
+ 'outcome': outcome,
+ 'treatment_indicator': treatment,
+ 'running_var': running_var,
+ 'covariate1': covariate1,
+ 'covariate2': covariate2
+ })
+ return df
+
+# --- Test Cases ---
+
+def test_run_rdd_diagnostics_success(sample_rdd_data):
+ """Test the diagnostics function with covariates."""
+ covariates = ['covariate1', 'covariate2']
+ results = run_rdd_diagnostics(
+ sample_rdd_data,
+ 'outcome',
+ 'running_var',
+ cutoff=50.0,
+ covariates=covariates,
+ bandwidth=10.0 # Use a reasonable bandwidth
+ )
+
+ assert results["status"] == "Success (Partial Implementation)"
+ assert "details" in results
+ details = results["details"]
+
+ assert "covariate_balance" in details
+ balance = details['covariate_balance']
+ assert isinstance(balance, dict)
+ assert 'covariate1' in balance
+ assert 'covariate2' in balance
+
+ # Check structure of balance results
+ assert 't_statistic' in balance['covariate1']
+ assert 'p_value' in balance['covariate1']
+ assert 'balanced' in balance['covariate1']
+ assert 't_statistic' in balance['covariate2']
+ assert 'p_value' in balance['covariate2']
+ assert 'balanced' in balance['covariate2']
+
+ # Check expected balance (covariate1 likely unbalanced, covariate2 likely balanced)
+ # Due to random noise, these might occasionally fail, but should usually hold
+ assert balance['covariate1']['balanced'].startswith("No")
+ assert balance['covariate2']['balanced'] == "Yes"
+
+ # Check placeholders
+ assert details['continuity_density_test'] == "Not Implemented (Requires specialized libraries like rdd)"
+ assert details['visual_inspection'] == "Recommended (Plot outcome vs running variable with fits)"
+
+def test_run_rdd_diagnostics_no_covariates(sample_rdd_data):
+ """Test diagnostics when no covariates are provided."""
+ results = run_rdd_diagnostics(
+ sample_rdd_data, 'outcome', 'running_var', cutoff=50.0, covariates=None, bandwidth=10.0
+ )
+ assert results["status"] == "Success (Partial Implementation)"
+ assert results["details"]['covariate_balance'] == "No covariates provided to check."
+
+def test_run_rdd_diagnostics_small_bandwidth(sample_rdd_data):
+ """Test diagnostics handles cases with insufficient data in bandwidth."""
+ # Bandwidth so small it likely excludes one side
+ results = run_rdd_diagnostics(
+ sample_rdd_data, 'outcome', 'running_var', cutoff=50.0, covariates=['covariate1'], bandwidth=0.1
+ )
+ assert results["status"] == "Skipped"
+ assert "Insufficient data near cutoff" in results["reason"]
+
+def test_run_rdd_diagnostics_missing_covariate(sample_rdd_data):
+ """Test diagnostics handles missing covariate columns gracefully."""
+ results = run_rdd_diagnostics(
+ sample_rdd_data, 'outcome', 'running_var', cutoff=50.0, covariates=['covariate1', 'missing_cov'], bandwidth=10.0
+ )
+ assert results["status"] == "Success (Partial Implementation)"
+ balance = results["details"]['covariate_balance']
+ assert balance['missing_cov']['status'] == "Column Not Found"
+ assert 't_statistic' in balance['covariate1'] # Check other covariate was still processed
+
diff --git a/tests/auto_causal/methods/regression_discontinuity/test_rdd_estimator.py b/tests/auto_causal/methods/regression_discontinuity/test_rdd_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bfe2abdaa465fe1aa58d6f5cd667e273720bf7e
--- /dev/null
+++ b/tests/auto_causal/methods/regression_discontinuity/test_rdd_estimator.py
@@ -0,0 +1,202 @@
+import pytest
+import pandas as pd
+import numpy as np
+from unittest.mock import patch, MagicMock
+from auto_causal.methods.regression_discontinuity.estimator import estimate_effect
+
+# --- Fixtures ---
+
+@pytest.fixture
+def sample_rdd_data():
+ """Generates synthetic data suitable for RDD testing."""
+ np.random.seed(123)
+ n_samples = 200
+ cutoff = 50.0
+ treatment_effect = 10.0
+
+ # Running variable centered around cutoff
+ running_var = np.random.uniform(cutoff - 20, cutoff + 20, n_samples)
+ # Treatment assigned based on cutoff
+ treatment = (running_var >= cutoff).astype(int)
+ # Covariate correlated with running variable
+ covariate1 = 0.5 * running_var + np.random.normal(0, 5, n_samples)
+ # Outcome depends on running var (parallel slopes), treatment, and covariate
+ error = np.random.normal(0, 5, n_samples)
+ outcome = (10 + 0.8 * running_var +
+ treatment_effect * treatment +
+ 2.0 * covariate1 + error)
+
+ df = pd.DataFrame({
+ 'outcome': outcome,
+ 'treatment_indicator': treatment, # Actual treatment status
+ 'running_var': running_var,
+ 'covariate1': covariate1
+ })
+ return df
+
+# --- Mocks for DoWhy ---
+@pytest.fixture
+def mock_causal_model():
+ """Fixture for a mocked DoWhy CausalModel."""
+ mock_model_instance = MagicMock()
+ # Mock the estimate_effect method
+ mock_estimate = MagicMock()
+ mock_estimate.value = 10.5 # Simulate DoWhy estimate
+ mock_estimate.test_significance_pvalue = 0.01
+ mock_estimate.confidence_interval = [8.0, 13.0]
+ mock_estimate.standard_error = 1.25
+ mock_model_instance.estimate_effect.return_value = mock_estimate
+
+ # Patch the CausalModel class in the estimator module
+ with patch('auto_causal.methods.regression_discontinuity.estimator.CausalModel') as MockCM:
+ MockCM.return_value = mock_model_instance
+ yield MockCM, mock_model_instance
+
+# --- Test Cases ---
+
+def test_estimate_effect_missing_args(sample_rdd_data):
+ """Test that RDD estimation fails if required args are missing."""
+ with pytest.raises(ValueError, match="Missing required RDD arguments"):
+ estimate_effect(sample_rdd_data, 'treatment_indicator', 'outcome', running_variable=None, cutoff=50.0)
+ with pytest.raises(ValueError, match="Missing required RDD arguments"):
+ estimate_effect(sample_rdd_data, 'treatment_indicator', 'outcome', running_variable='running_var', cutoff=None)
+
+@patch('auto_causal.methods.regression_discontinuity.estimator.run_rdd_diagnostics')
+@patch('auto_causal.methods.regression_discontinuity.estimator.interpret_rdd_results')
+def test_estimate_effect_dowhy_success(mock_interpret, mock_diagnostics, mock_causal_model, sample_rdd_data):
+ """Test successful estimation using the mocked DoWhy path."""
+ MockCM, mock_model_instance = mock_causal_model
+ mock_diagnostics.return_value = {"status": "Success", "details": {"covariate_balance": "Checked"}}
+ mock_interpret.return_value = "LLM Interpretation"
+
+ results = estimate_effect(
+ sample_rdd_data,
+ 'treatment_indicator',
+ 'outcome',
+ running_variable='running_var',
+ cutoff=50.0,
+ bandwidth=5.0, # Specify bandwidth
+ use_dowhy=True
+ )
+
+ MockCM.assert_called_once()
+ mock_model_instance.estimate_effect.assert_called_once()
+ call_args, call_kwargs = mock_model_instance.estimate_effect.call_args
+ assert call_kwargs['method_name'] == "iv.regression_discontinuity"
+ assert call_kwargs['method_params']['rd_variable_name'] == 'running_var'
+ assert call_kwargs['method_params']['rd_threshold_value'] == 50.0
+ assert call_kwargs['method_params']['rd_bandwidth'] == 5.0
+
+ assert results['method_used'] == 'DoWhy RDD'
+ assert results['effect_estimate'] == 10.5
+ assert results['p_value'] == 0.01
+ assert results['confidence_interval'] == [8.0, 13.0]
+ assert results['standard_error'] == 1.25
+ assert 'DoWhy RDD (Bandwidth: 5.000)' in results['method_details']
+ assert 'diagnostics' in results
+ assert 'interpretation' in results
+ mock_diagnostics.assert_called_once()
+ mock_interpret.assert_called_once()
+
+@patch('auto_causal.methods.regression_discontinuity.estimator.run_rdd_diagnostics')
+@patch('auto_causal.methods.regression_discontinuity.estimator.interpret_rdd_results')
+def test_estimate_effect_fallback_success(mock_interpret, mock_diagnostics, sample_rdd_data):
+ """Test successful estimation using the fallback linear interaction method."""
+ mock_diagnostics.return_value = {"status": "Success", "details": {"covariate_balance": "Checked"}}
+ mock_interpret.return_value = "LLM Interpretation"
+
+ results = estimate_effect(
+ sample_rdd_data,
+ 'treatment_indicator',
+ 'outcome',
+ running_variable='running_var',
+ cutoff=50.0,
+ covariates=['covariate1'],
+ bandwidth=10.0,
+ use_dowhy=False # Force fallback
+ )
+
+ assert results['method_used'] == 'Fallback RDD (Linear Interaction)'
+ assert 'effect_estimate' in results
+ assert 'p_value' in results
+ assert 'confidence_interval' in results
+ assert 'standard_error' in results
+ assert 'model_summary' in results # Fallback provides summary
+ assert 'Fallback Linear Interaction (Bandwidth: 10.000)' in results['method_details']
+ # Check if estimate is reasonable (should be around 10)
+ assert abs(results['effect_estimate'] - 10.0) < 20.0
+ assert 'diagnostics' in results
+ assert 'interpretation' in results
+ mock_diagnostics.assert_called_once()
+ mock_interpret.assert_called_once()
+
+@patch('auto_causal.methods.regression_discontinuity.estimator.estimate_effect_dowhy')
+@patch('auto_causal.methods.regression_discontinuity.estimator.estimate_effect_fallback')
+def test_estimate_effect_dowhy_fails_fallback_succeeds(mock_fallback, mock_dowhy, sample_rdd_data):
+ """Test that fallback is used when DoWhy fails."""
+ mock_dowhy.side_effect = Exception("DoWhy broke")
+ # Simulate successful fallback results
+ mock_fallback.return_value = {
+ 'effect_estimate': 9.8,
+ 'p_value': 0.02,
+ 'confidence_interval': [1.0, 18.6],
+ 'standard_error': 4.0,
+ 'method_used': 'Fallback RDD (Linear Interaction)',
+ 'method_details': "Fallback Linear Interaction (Bandwidth: 10.000)",
+ 'formula': 'formula_str',
+ 'model_summary': 'summary_str'
+ }
+
+ # Need to also patch diagnostics and interpretation as they run after estimation
+ with patch('auto_causal.methods.regression_discontinuity.estimator.run_rdd_diagnostics'), \
+ patch('auto_causal.methods.regression_discontinuity.estimator.interpret_rdd_results'):
+
+ results = estimate_effect(
+ sample_rdd_data,
+ 'treatment_indicator',
+ 'outcome',
+ running_variable='running_var',
+ cutoff=50.0,
+ bandwidth=10.0,
+ use_dowhy=True # Try DoWhy first
+ )
+
+ mock_dowhy.assert_called_once()
+ mock_fallback.assert_called_once()
+ assert results['method_used'] == 'Fallback RDD (Linear Interaction)'
+ assert results['effect_estimate'] == 9.8
+ assert 'dowhy_error_info' in results # Check that DoWhy error was recorded
+ assert "DoWhy broke" in results['dowhy_error_info']
+
+@patch('auto_causal.methods.regression_discontinuity.estimator.estimate_effect_dowhy')
+@patch('auto_causal.methods.regression_discontinuity.estimator.estimate_effect_fallback')
+def test_estimate_effect_both_fail(mock_fallback, mock_dowhy, sample_rdd_data):
+ """Test that an error is raised if both DoWhy and fallback fail."""
+ mock_dowhy.side_effect = Exception("DoWhy broke")
+ mock_fallback.side_effect = ValueError("Fallback broke")
+
+ with pytest.raises(ValueError, match="RDD estimation failed using both DoWhy and fallback methods"):
+ estimate_effect(
+ sample_rdd_data,
+ 'treatment_indicator',
+ 'outcome',
+ running_variable='running_var',
+ cutoff=50.0,
+ use_dowhy=True
+ )
+ mock_dowhy.assert_called_once()
+ mock_fallback.assert_called_once()
+
+def test_estimate_effect_no_data_in_bandwidth(sample_rdd_data):
+ """Test error when bandwidth is too small, leading to no data."""
+ # Use a very small bandwidth that excludes all data
+ with pytest.raises(ValueError, match="No data within the specified bandwidth"):
+ estimate_effect(
+ sample_rdd_data,
+ 'treatment_indicator',
+ 'outcome',
+ running_variable='running_var',
+ cutoff=50.0,
+ bandwidth=0.01, # Extremely small bandwidth
+ use_dowhy=False # Force fallback for this specific error check
+ )
diff --git a/tests/auto_causal/methods/regression_discontinuity/test_rdd_llm_assist.py b/tests/auto_causal/methods/regression_discontinuity/test_rdd_llm_assist.py
new file mode 100644
index 0000000000000000000000000000000000000000..819e5eb9f36d429aa8d719984fe874ed705b27aa
--- /dev/null
+++ b/tests/auto_causal/methods/regression_discontinuity/test_rdd_llm_assist.py
@@ -0,0 +1,115 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from auto_causal.methods.regression_discontinuity.llm_assist import (
+ suggest_rdd_parameters,
+ interpret_rdd_results
+)
+
+# Patch target for the helper function where it's used
+LLM_ASSIST_MODULE = "auto_causal.methods.regression_discontinuity.llm_assist"
+
+@pytest.fixture
+def mock_llm():
+ """Fixture for a basic mock LLM object."""
+ return MagicMock()
+
+@pytest.fixture
+def mock_rdd_results():
+ """Creates a mock RDD results dictionary."""
+ return {
+ 'effect_estimate': 10.5,
+ 'p_value': 0.01,
+ 'confidence_interval': [8.0, 13.0],
+ 'standard_error': 1.25,
+ 'method_used': 'DoWhy RDD' # Or Fallback RDD
+ }
+
+@pytest.fixture
+def mock_rdd_diagnostics_success():
+ """Creates a mock RDD diagnostics dictionary for successful checks."""
+ return {
+ "status": "Success (Partial Implementation)",
+ "details": {
+ 'covariate_balance': {
+ 'cov1': {'p_value': 0.6, 'balanced': 'Yes'},
+ 'cov2': {'p_value': 0.02, 'balanced': 'No (p <= 0.05)'}
+ },
+ 'continuity_density_test': 'Not Implemented',
+ 'visual_inspection': 'Recommended'
+ }
+ }
+
+def test_suggest_rdd_parameters_placeholder(mock_llm):
+ """Test the placeholder RDD parameter suggestion function."""
+ df_cols = ['score', 'age', 'outcome']
+ query = "Effect of passing score (50) on outcome?"
+
+ # Test without LLM
+ suggested_no_llm = suggest_rdd_parameters(df_cols, query, llm=None)
+ assert suggested_no_llm == {}
+
+ # Test with LLM (should still return empty dict for placeholder)
+ suggested_with_llm = suggest_rdd_parameters(df_cols, query, llm=mock_llm)
+ assert suggested_with_llm == {}
+ mock_llm.assert_not_called()
+
+@patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output")
+def test_interpret_rdd_results_implementation(mock_call_json, mock_llm, mock_rdd_results, mock_rdd_diagnostics_success):
+ """Test the implemented RDD results interpretation function."""
+ mock_interpretation_text = "RDD shows a significant positive effect at the cutoff."
+ mock_call_json.return_value = {"interpretation": mock_interpretation_text}
+
+ # --- Test with LLM ---
+ interp_with_llm = interpret_rdd_results(
+ mock_rdd_results,
+ mock_rdd_diagnostics_success,
+ llm=mock_llm
+ )
+
+ assert interp_with_llm == mock_interpretation_text
+ mock_call_json.assert_called_once()
+ # Basic check on the prompt structure passed to the helper
+ call_args, call_kwargs = mock_call_json.call_args
+ prompt = call_args[1] # Second argument is the prompt string
+ assert "Regression Discontinuity Design (RDD) results" in prompt
+ assert "Estimation Results Summary:" in prompt
+ assert "Diagnostics Summary:" in prompt
+ assert "Effect Estimate': '10.500" in prompt # Check formatting
+ assert "Number of Unbalanced Covariates (p<=0.05)': 1" in prompt # Check diagnostics summary
+ assert "visual inspection of the running variable vs outcome is recommended" in prompt
+ assert "Return ONLY a valid JSON" in prompt
+
+ # --- Test LLM Call Failure ---
+ mock_call_json.reset_mock()
+ mock_call_json.return_value = None # Simulate LLM helper failure
+ interp_fail = interpret_rdd_results(mock_rdd_results, mock_rdd_diagnostics_success, llm=mock_llm)
+ assert "LLM interpretation not available for RDD" in interp_fail
+ mock_call_json.assert_called_once() # Ensure it was still called
+
+ # --- Test without LLM ---
+ mock_call_json.reset_mock()
+ interp_no_llm = interpret_rdd_results(mock_rdd_results, mock_rdd_diagnostics_success, llm=None)
+ assert isinstance(interp_no_llm, str)
+ assert "LLM interpretation not available for RDD" in interp_no_llm
+ mock_call_json.assert_not_called() # Ensure helper wasn't called
+
+# Test interpretation with failed diagnostics
+def test_interpret_rdd_results_failed_diagnostics(mock_llm):
+ """Test interpretation when diagnostics failed."""
+ mock_res = {'effect_estimate': 5.0, 'p_value': 0.04}
+ mock_diag = {"status": "Failed", "error": "Something broke"}
+
+ # Patch the call to LLM helper for this specific test case
+ with patch(f"{LLM_ASSIST_MODULE}.call_llm_with_json_output") as mock_call_json_fail:
+ mock_call_json_fail.return_value = {"interpretation": "Interpreted despite failed diagnostics"}
+
+ interp = interpret_rdd_results(mock_res, mock_diag, llm=mock_llm)
+
+ assert interp == "Interpreted despite failed diagnostics"
+ mock_call_json_fail.assert_called_once()
+ call_args, call_kwargs = mock_call_json_fail.call_args
+ prompt = call_args[1]
+ assert "Diagnostics Summary:" in prompt
+ assert "Status': 'Failed" in prompt # Check failed status is in prompt
+ assert "Error': 'Something broke" in prompt
+
diff --git a/tests/auto_causal/methods/test_diff_in_diff.py b/tests/auto_causal/methods/test_diff_in_diff.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d994ef0514b2cdc0da7c80c8533009a997eac66
--- /dev/null
+++ b/tests/auto_causal/methods/test_diff_in_diff.py
@@ -0,0 +1,81 @@
+import unittest
+import pandas as pd
+import numpy as np
+from unittest.mock import patch, MagicMock
+
+# Import the function to test
+from auto_causal.methods.diff_in_diff import estimate_effect
+
+class TestDifferenceInDifferences(unittest.TestCase):
+
+ def setUp(self):
+ '''Set up dummy panel data for testing.'''
+ # Simple 2 groups, 2 periods example
+ self.df = pd.DataFrame({
+ 'unit': [1, 1, 2, 2, 3, 3, 4, 4], # 2 treated (1,2), 2 control (3,4)
+ 'time': [0, 1, 0, 1, 0, 1, 0, 1],
+ 'treatment_group': [1, 1, 1, 1, 0, 0, 0, 0], # Group indicator
+ 'outcome': [10, 12, 11, 14, 9, 9.5, 10, 10.5], # Treated increase more in period 1
+ 'covariate1': [1, 1, 2, 2, 1, 1, 2, 2]
+ })
+ self.treatment = 'treatment_group' # This identifies the group
+ self.outcome = 'outcome'
+ self.covariates = ['covariate1']
+ self.time_var = 'time'
+ self.group_var = 'unit'
+
+ # Mock all helper/validation functions within diff_in_diff.py
+ @patch('auto_causal.methods.diff_in_diff.identify_time_variable')
+ @patch('auto_causal.methods.diff_in_diff.identify_treatment_group')
+ @patch('auto_causal.methods.diff_in_diff.determine_treatment_period')
+ @patch('auto_causal.methods.diff_in_diff.validate_parallel_trends')
+ # Mock estimate_did_model to avoid actual regression, return mock results
+ @patch('auto_causal.methods.diff_in_diff.estimate_did_model')
+ def test_estimate_effect_structure_and_types(self, mock_estimate_model, mock_validate_trends,
+ mock_determine_period, mock_identify_group, mock_identify_time):
+ '''Test the basic structure and types of the DiD estimate_effect output.'''
+ # Configure mocks
+ mock_identify_time.return_value = self.time_var
+ mock_identify_group.return_value = self.group_var
+ mock_determine_period.return_value = 1 # Assume treatment starts at time 1
+ mock_validate_trends.return_value = {"valid": True, "p_value": 0.9}
+
+ # Mock the statsmodels result object
+ mock_model_results = MagicMock()
+ # Define the interaction term based on how construct_did_formula names it
+ # Assuming treatment='treatment_group', post='post'
+ interaction_term = f"{self.treatment}_x_post"
+ mock_model_results.params = {interaction_term: 2.5, 'Intercept': 10.0}
+ mock_model_results.bse = {interaction_term: 0.5, 'Intercept': 0.2}
+ mock_model_results.pvalues = {interaction_term: 0.01, 'Intercept': 0.001}
+ # Mock the summary() method if format_did_results uses it
+ mock_model_results.summary.return_value = "Mocked Model Summary"
+ mock_estimate_model.return_value = mock_model_results
+
+ # Call the function (passing explicit vars to bypass internal identification mocks if desired)
+ result = estimate_effect(self.df, self.treatment, self.outcome, self.covariates,
+ time_var=self.time_var, group_var=self.group_var, query="Test query")
+
+ # Assertions
+ self.assertIsInstance(result, dict)
+ expected_keys = ["effect_estimate", "effect_se", "confidence_interval", "p_value",
+ "diagnostics", "method_details", "parameters", "model_summary"]
+ for key in expected_keys:
+ self.assertIn(key, result, f"Key '{key}' missing from result")
+
+ self.assertEqual(result["method_details"], "DiD.TWFE")
+ self.assertIsInstance(result["effect_estimate"], float)
+ self.assertIsInstance(result["effect_se"], float)
+ self.assertIsInstance(result["confidence_interval"], list)
+ self.assertEqual(len(result["confidence_interval"]), 2)
+ self.assertIsInstance(result["diagnostics"], dict)
+ self.assertIsInstance(result["parameters"], dict)
+ self.assertIn("time_var", result["parameters"])
+ self.assertIn("group_var", result["parameters"])
+ self.assertIn("interaction_term", result["parameters"])
+ self.assertEqual(result["parameters"]["interaction_term"], interaction_term)
+ self.assertIn("valid", result["diagnostics"])
+ self.assertIn("model_summary", result)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/methods/test_generalized_propensity_score.py b/tests/auto_causal/methods/test_generalized_propensity_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..6049c0ffed712d338c90f4e399b9c487e0f266f3
--- /dev/null
+++ b/tests/auto_causal/methods/test_generalized_propensity_score.py
@@ -0,0 +1,264 @@
+import unittest
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+
+# Assuming the causalscientist package is installed or PYTHONPATH is set correctly
+from auto_causal.methods.generalized_propensity_score.estimator import (
+ _estimate_gps_values,
+ _estimate_outcome_model,
+ _generate_dose_response_function,
+ estimate_effect_gps
+)
+from auto_causal.methods.generalized_propensity_score.diagnostics import assess_gps_balance
+
+class TestGeneralizedPropensityScore(unittest.TestCase):
+
+ def _generate_synthetic_data(self, n=1000, seed=42, linear_gps=True, linear_outcome=False):
+ """
+ Generates synthetic data for testing GPS.
+ T = a0 + a1*X1 + a2*X2 + u
+ Y = b0 + b1*T + b2*T^2 + c1*X1 + c2*X2 + v (if non-linear outcome)
+ Y = b0 + b1*T + c1*X1 + c2*X2 + v (if linear outcome)
+ GPS is based on T ~ X.
+ Outcome model is Y ~ T, GPS, and their transformations.
+ """
+ np.random.seed(seed)
+ data = pd.DataFrame()
+ data['X1'] = np.random.normal(0, 1, n)
+ data['X2'] = np.random.binomial(1, 0.5, n)
+
+ # Treatment assignment
+ if linear_gps:
+ # Ensure treatment has reasonable variance and is affected by covariates
+ data['T'] = 0.5 + 1.5 * data['X1'] - 1.0 * data['X2'] + np.random.normal(0, 2, n)
+ else: # Non-linear treatment assignment (not directly used by current _estimate_gps_values OLS)
+ data['T'] = 0.5 + 1.5 * data['X1']**2 - 1.0 * data['X2'] + np.random.normal(0, 2, n)
+
+ # Outcome generation
+ # True Dose-Response: E[Y(t)] = 10 + 5*t - 0.5*t^2 (example)
+ # Confounding effect of X1, X2
+ confounding_effect = 2.0 * data['X1'] + 1.0 * data['X2']
+
+ if linear_outcome:
+ # Y = b0 + b1*T + c1*X1 + c2*X2 + v
+ data['Y'] = 10 + 3 * data['T'] + confounding_effect + np.random.normal(0, 3, n)
+ else:
+ # Y = b0 + b1*T + b2*T^2 + c1*X1 + c2*X2 + v
+ data['Y'] = 10 + 5 * data['T'] - 0.5 * data['T']**2 + confounding_effect + np.random.normal(0, 3, n)
+
+ return data
+
+ def test_generate_synthetic_data_smoke(self):
+ df = self._generate_synthetic_data()
+ self.assertEqual(len(df), 1000)
+ self.assertIn('T', df.columns)
+ self.assertIn('Y', df.columns)
+ self.assertIn('X1', df.columns)
+ self.assertIn('X2', df.columns)
+
+ def test_estimate_gps_values_linear_case(self):
+ df = self._generate_synthetic_data(n=100)
+ treatment_var = 'T'
+ covariate_vars = ['X1', 'X2']
+ gps_model_spec = {"type": "linear"}
+
+ df_with_gps, diagnostics = _estimate_gps_values(df.copy(), treatment_var, covariate_vars, gps_model_spec)
+
+ self.assertIn('gps_score', df_with_gps.columns)
+ self.assertFalse(df_with_gps['gps_score'].isnull().all(), "GPS scores should not be all NaNs")
+ self.assertGreater(df_with_gps['gps_score'].mean(), 0, "Mean GPS score should be positive")
+ self.assertIn("gps_model_type", diagnostics)
+ self.assertEqual(diagnostics["gps_model_type"], "linear_ols")
+ self.assertTrue(0 <= diagnostics["gps_model_rsquared"] <= 1)
+
+ def test_estimate_gps_values_no_covariates(self):
+ df = self._generate_synthetic_data(n=50)
+ treatment_var = 'T'
+ gps_model_spec = {"type": "linear"}
+
+ # Test with empty list of covariates
+ df_with_gps, diagnostics = _estimate_gps_values(df.copy(), treatment_var, [], gps_model_spec)
+ self.assertIn('gps_score', df_with_gps.columns) # Should still add the column
+ self.assertTrue(df_with_gps['gps_score'].isnull().all(), "GPS scores should be all NaN if no covariates")
+ self.assertIn("error", diagnostics)
+ self.assertEqual(diagnostics["error"], "No covariates provided.")
+
+ def test_estimate_gps_values_zero_variance_residual(self):
+ # Create data where T is perfectly predicted by X1 (zero residual variance)
+ df = pd.DataFrame({'X1': np.random.normal(0, 1, 50)})
+ df['T'] = 2 * df['X1'] # Perfect prediction
+ df['X2'] = np.random.binomial(1, 0.5, 50) # Dummy covariate
+ treatment_var = 'T'
+ covariate_vars = ['X1'] # Using only X1 for perfect prediction
+ gps_model_spec = {"type": "linear"}
+
+ df_with_gps, diagnostics = _estimate_gps_values(df.copy(), treatment_var, covariate_vars, gps_model_spec)
+ self.assertIn('gps_score', df_with_gps.columns)
+ self.assertTrue(df_with_gps['gps_score'].isnull().all(), "GPS should be NaN if residual variance is zero")
+ self.assertIn("warning_sigma_sq_hat_near_zero", diagnostics) # Check for the warning when it's very close to zero
+
+ def test_estimate_gps_values_not_enough_dof(self):
+ df = pd.DataFrame({
+ 'T': [1,2,3],
+ 'X1': [10,11,12],
+ 'X2': [1,0,1],
+ 'X3': [5,6,7] # T = X1 - 9 + X3 - 5 (perfectly determined)
+ })
+ # n=3, k_params (const, X1, X2, X3) = 4. n-k = -1
+ df_res, diagnostics = _estimate_gps_values(df.copy(), 'T', ['X1', 'X2', 'X3'], {"type": "linear"})
+ self.assertTrue(df_res['gps_score'].isnull().all())
+ self.assertIn("Not enough degrees of freedom for GPS variance", diagnostics.get("error", ""))
+
+
+ def test_estimate_outcome_model_structure(self):
+ df = self._generate_synthetic_data(n=100)
+ # First, get some GPS scores
+ df_with_gps, _ = _estimate_gps_values(df.copy(), 'T', ['X1', 'X2'], {"type": "linear"})
+ df_with_gps.dropna(subset=['gps_score', 'Y', 'T', 'X1', 'X2'], inplace=True) # Ensure no NaNs for model fitting
+
+ self.assertFalse(df_with_gps.empty, "DataFrame became empty after GPS estimation for outcome model test")
+
+ outcome_var = 'Y'
+ treatment_var = 'T'
+ gps_col_name = 'gps_score'
+ # Standard polynomial spec as used in estimator.py
+ outcome_model_spec = {"type": "polynomial", "degree": 2, "interaction": True}
+
+ fitted_model = _estimate_outcome_model(df_with_gps, outcome_var, treatment_var, gps_col_name, outcome_model_spec)
+
+ self.assertIsNotNone(fitted_model)
+ self.assertIsInstance(fitted_model, sm.regression.linear_model.RegressionResultsWrapper)
+
+ expected_terms = ['intercept', 'T', 'GPS', 'T_sq', 'GPS_sq', 'T_x_GPS']
+ for term in expected_terms:
+ self.assertIn(term, fitted_model.model.exog_names, f"Term {term} missing from outcome model")
+
+ def test_generate_dose_response_function(self):
+ df = self._generate_synthetic_data(n=200)
+ df_with_gps, _ = _estimate_gps_values(df.copy(), 'T', ['X1', 'X2'], {"type": "linear"})
+ df_with_gps.dropna(subset=['gps_score', 'Y', 'T'], inplace=True)
+ self.assertFalse(df_with_gps.empty, "Test setup failed: df_with_gps is empty")
+
+
+ outcome_model_spec = {"type": "polynomial", "degree": 2, "interaction": True}
+ fitted_outcome_model = _estimate_outcome_model(df_with_gps, 'Y', 'T', 'gps_score', outcome_model_spec)
+
+ t_values = np.linspace(df_with_gps['T'].min(), df_with_gps['T'].max(), 5).tolist()
+ adrf_estimates = _generate_dose_response_function(
+ df_with_gps, fitted_outcome_model, 'T', 'gps_score', outcome_model_spec, t_values
+ )
+
+ self.assertEqual(len(adrf_estimates), len(t_values))
+ self.assertFalse(np.isnan(adrf_estimates).any(), "ADRF estimates should not be NaN for valid inputs")
+
+ def test_generate_dose_response_empty_t_values(self):
+ df = self._generate_synthetic_data(n=50) # Dummy data
+ df_with_gps, _ = _estimate_gps_values(df.copy(), 'T', ['X1', 'X2'], {"type": "linear"})
+ df_with_gps.dropna(subset=['gps_score', 'Y', 'T'], inplace=True)
+ outcome_model_spec = {"type": "polynomial", "degree": 2, "interaction": True}
+ fitted_outcome_model = _estimate_outcome_model(df_with_gps, 'Y', 'T', 'gps_score', outcome_model_spec)
+
+ adrf_estimates = _generate_dose_response_function(
+ df_with_gps, fitted_outcome_model, 'T', 'gps_score', outcome_model_spec, [] # Empty t_values
+ )
+ self.assertEqual(len(adrf_estimates), 0)
+
+ def test_estimate_effect_gps_end_to_end_smoke(self):
+ df = self._generate_synthetic_data(n=200, seed=123)
+ results = estimate_effect_gps(
+ df,
+ treatment_var='T',
+ outcome_var='Y',
+ covariate_vars=['X1', 'X2'],
+ t_values_for_adrf=np.linspace(df['T'].min(), df['T'].max(), 7).tolist() # specify t_values
+ )
+
+ self.assertNotIn("error", results, f"estimate_effect_gps returned an error: {results.get('error')}")
+ self.assertIn("adrf_curve", results)
+ self.assertIn("diagnostics", results)
+ self.assertIn("method_details", results)
+ self.assertIn("parameters_used", results)
+
+ self.assertIn("t_levels", results["adrf_curve"])
+ self.assertIn("expected_outcomes", results["adrf_curve"])
+ self.assertEqual(len(results["adrf_curve"]["t_levels"]), 7)
+ self.assertEqual(len(results["adrf_curve"]["expected_outcomes"]), 7)
+ self.assertIsInstance(results["diagnostics"]["gps_estimation_diagnostics"], dict)
+ self.assertIsInstance(results["diagnostics"]["balance_check"], dict) # From assess_gps_balance
+
+ def test_estimate_effect_gps_gps_estimation_failure(self):
+ # Test case where GPS estimation might fail (e.g., no covariates)
+ df = self._generate_synthetic_data(n=50)
+ results = estimate_effect_gps(
+ df,
+ treatment_var='T',
+ outcome_var='Y',
+ covariate_vars=[] # No covariates
+ )
+ self.assertIn("error", results)
+ self.assertEqual(results["error"], "GPS estimation failed.")
+ self.assertIn("no covariates provided", results["diagnostics"]["error"].lower())
+
+
+ # --- Tests for assess_gps_balance (from diagnostics.py) ---
+ def test_assess_gps_balance_smoke(self):
+ df_synth = self._generate_synthetic_data(n=300)
+ df_with_gps, _ = _estimate_gps_values(df_synth.copy(), 'T', ['X1', 'X2'], {"type": "linear"})
+ df_with_gps.dropna(subset=['gps_score', 'T', 'X1', 'X2'], inplace=True)
+
+ self.assertGreater(len(df_with_gps), 100, "Not enough data after NaN drop for balance test setup.")
+
+ balance_results = assess_gps_balance(
+ df_with_gps,
+ treatment_var='T',
+ covariate_vars=['X1', 'X2'],
+ gps_col_name='gps_score',
+ num_strata=3 # Test with fewer strata
+ )
+
+ self.assertNotIn("error", balance_results, f"assess_gps_balance returned an error: {balance_results.get('error')}")
+ self.assertIn("balance_results_per_covariate", balance_results)
+ self.assertIn("summary_stats", balance_results)
+ self.assertIn("X1", balance_results["balance_results_per_covariate"])
+ self.assertIn("X2", balance_results["balance_results_per_covariate"])
+ self.assertIsInstance(balance_results["balance_results_per_covariate"]['X1']["strata_details"], list)
+ self.assertGreater(len(balance_results["balance_results_per_covariate"]['X1']["strata_details"]), 0)
+ self.assertEqual(balance_results["summary_stats"]["num_strata_used"], 3)
+
+
+ def test_assess_gps_balance_all_gps_nan(self):
+ df_synth = self._generate_synthetic_data(n=50)
+ df_synth['gps_score_all_nan'] = np.nan # All GPS scores are NaN
+
+ balance_results = assess_gps_balance(
+ df_synth,
+ treatment_var='T',
+ covariate_vars=['X1'],
+ gps_col_name='gps_score_all_nan'
+ )
+ self.assertIn("error", balance_results)
+ self.assertEqual(balance_results["error"], "All GPS scores are NaN.")
+
+ def test_assess_gps_balance_qcut_failure_fallback(self):
+ # Test qcut failure and fallback (e.g. GPS has very few unique values)
+ df = pd.DataFrame({
+ 'T': np.random.rand(50),
+ 'X1': np.random.rand(50),
+ 'gps_score': np.array([0.1]*20 + [0.2]*20 + [0.3]*10) # Only 3 unique GPS values
+ })
+ balance_results = assess_gps_balance(df, 'T', ['X1'], 'gps_score', num_strata=5)
+ self.assertNotIn("error", balance_results.get("summary_stats", {}).get("error", "")) # Check for critical error
+ self.assertIn("warnings", balance_results["summary_stats"])
+ # Check for the warning about forming fewer strata than requested
+ actual_strata_formed = balance_results["summary_stats"].get('actual_num_strata_formed', 0)
+ expected_warning_part = f"Only {actual_strata_formed} strata formed out of 5 requested"
+ current_warnings = balance_results["summary_stats"]["warnings"]
+ self.assertTrue(any(expected_warning_part in w
+ for w in current_warnings),
+ f"Expected warning '{expected_warning_part}' not found. Warnings: {current_warnings}")
+ self.assertEqual(balance_results["summary_stats"]["actual_num_strata_formed"], 3)
+
+
+if __name__ == '__main__':
+ unittest.main(argv=['first-arg-is-ignored'], exit=False)
\ No newline at end of file
diff --git a/tests/auto_causal/test_adhoc.py b/tests/auto_causal/test_adhoc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f3216ffeac290951f39d93f7ce3e504ad8099bf
--- /dev/null
+++ b/tests/auto_causal/test_adhoc.py
@@ -0,0 +1,77 @@
+import unittest
+import os
+import json
+import re
+import pandas as pd
+# Import load_dotenv
+from dotenv import load_dotenv
+
+# Import the main entry point
+from auto_causal.agent import run_causal_analysis
+
+load_dotenv()
+
+# class TestAdhoc(unittest.TestCase):
+# def test_adhoc(self):
+# query = "Does receiving a voter turnout mailing increase the probability of voting compared to receiving no mailing?"
+# dataset_path = "data/voter_turnout_data.csv"
+# dataset_description = """The ISPS D001 dataset, titled "Social Pressure and Voter Turnout: Evidence from a Large-Scale Field Experiment,"
+# originates from a 2006 field experiment in Michigan. Researchers Gerber, Green, and Larimer investigated how different mail-based interventions
+# influenced voter turnout in a primary election. The study encompassed 180,002 households (344,084 registered voters), randomly assigned to a control
+# group or one of four treatment groups: Civic Duty, Hawthorne Effect, Self, and Neighbors. Each treatment involved a distinct mailing designed to
+# apply social pressure or appeal to civic responsibility. The primary outcome measured was voter turnout in the 2006 local elections.
+# Data were sourced from Michigan's Qualified Voter File (QVF), curated by Practical Political Consulting. The dataset includes individual
+# and household-level information, treatment assignments, and voting outcomes. Comprehensive documentation and replication materials are available
+# to facilitate further research and analysis."""
+
+# result = run_causal_analysis(query, dataset_path, dataset_description)
+# print(result)
+
+class TestAdhoc(unittest.TestCase):
+ def test_adhoc_from_structured_input(self):
+ # Define the input using the new structure
+ test_input_data ={
+ "paper": " What is the effect of home visits on the cognitive test scores of children who actually received the intervention?",
+ "dataset_description": """"The CSV file ihdp_4.csv contains data obtained from the Infant Health and Development Program (IHDP). The study is designed to evaluate the effect of home visit from specialist doctors on the cognitive test scores of premature infants. The confounders x (x1-x25) correspond to collected measurements of the children and their mothers, including measurements on the child (birth weight, head circumference, weeks born preterm, birth order, first born, neonatal health index, sex, twin status), as well as behaviors engaged in during the pregnancy (smoked cigarettes, drank alcohol, took drugs) and measurements on the mother at the time she gave birth (age, marital status, educational attainment, whether she worked during pregnancy, whether she received prenatal care) and the site (8 total) in which the family resided at the start of the intervention. There are 6 continuous covariates and 19 binary covariates.""",
+ "query": "What is the effect of home visits on the cognitive test scores of children who actually received the intervention?",
+ "answer": 0.0,
+ "method": "TWFE",
+ "dataset_path": "benchmark/all_data_1/ihdp_5.csv"
+ }
+
+ # Extract relevant info from the input data
+ query = test_input_data["query"]
+ dataset_path = test_input_data["dataset_path"]
+ dataset_description = test_input_data["dataset_description"]
+ expected_method = test_input_data["method"]
+ expected_answer = test_input_data["answer"]
+
+ # Ensure dataset_path is correct if it's relative to a specific root or needs joining
+ # For example, if your tests run from the root of the project:
+ # script_dir = os.path.dirname(__file__)
+ # project_root = os.path.abspath(os.path.join(script_dir, "../../..")) # Adjust based on test_adhoc.py location
+ # dataset_path = os.path.join(project_root, dataset_path)
+ # For now, assuming dataset_path is directly usable or handled by run_causal_analysis
+
+ print(f"Running adhoc test with query: {query}")
+ print(f"Dataset path: {dataset_path}")
+
+ # Call the main causal analysis function
+ # We need to know what `run_causal_analysis` returns to make assertions.
+ # Assuming it returns a dictionary that includes the method used and the effect estimate.
+ result = run_causal_analysis(query, dataset_path, dataset_description)
+
+ print("Causal analysis result:")
+ #print(json.dumps(result, indent=2)) # Pretty print the result dictionary
+
+ # Assertions (these are examples and depend on the actual structure of `result`)
+ # You'll need to adapt these based on what `run_causal_analysis` returns.
+
+ # Example: Assuming result is a dict and might have a top-level key for the final output summary
+ # and within that, information about method used and effect estimate.
+ # This is highly speculative and needs to be adjusted.
+ final_summary = result # or result.get("summary"), etc.
+
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/test_agent_workflow.py b/tests/auto_causal/test_agent_workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f402393d30e6840b59b30c53b1178d2d7b66063e
--- /dev/null
+++ b/tests/auto_causal/test_agent_workflow.py
@@ -0,0 +1,103 @@
+import unittest
+import os
+from unittest.mock import patch, MagicMock
+
+# Import AIMessage for mocking
+from langchain_core.messages import AIMessage
+# Import ToolCall if needed for more complex mocking
+# from langchain_core.agents import AgentAction, AgentFinish
+# from langchain_core.tools import ToolCall
+
+# Assume run_causal_analysis is the main entry point
+from auto_causal.agent import run_causal_analysis
+
+# Helper to create a dummy dataset file for tests
+def create_dummy_csv(path='dummy_e2e_test_data.csv'):
+ import pandas as pd
+ df = pd.DataFrame({
+ 'treatment': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+ 'outcome': [10, 12, 11, 13, 9, 14, 10, 15, 11, 16],
+ 'covariate1': [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
+ 'covariate2': [5.5, 6.5, 5.8, 6.2, 5.1, 6.8, 5.3, 6.1, 5.9, 6.3]
+ })
+ df.to_csv(path, index=False)
+ return path
+
+class TestAgentWorkflow(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.dummy_data_path = create_dummy_csv()
+ # Set dummy API key for testing if needed by agent setup
+ os.environ["OPENAI_API_KEY"] = "test_key"
+
+ @classmethod
+ def tearDownClass(cls):
+ if os.path.exists(cls.dummy_data_path):
+ os.remove(cls.dummy_data_path)
+ del os.environ["OPENAI_API_KEY"]
+
+ # Patch the LLM call to avoid actual API calls during this basic test
+ @patch('auto_causal.agent.ChatOpenAI')
+ def test_agent_invocation(self, mock_chat_openai):
+ '''Test if the agent runs without critical errors using dummy data.'''
+ # Configure the mock LLM to return an AIMessage
+ mock_llm_instance = mock_chat_openai.return_value
+
+ # Simulate the LLM deciding to call the first tool
+ # We create an AIMessage containing a simulated tool call.
+ # The exact structure might vary slightly based on agent/langchain versions.
+ # For now, just providing a basic AIMessage output to satisfy the prompt format.
+ # A more robust mock would simulate the JSON/ToolCall structure.
+ mock_response = AIMessage(content="Okay, I need to parse the input first.",
+ # Example of adding a tool call if needed:
+ # tool_calls=[ToolCall(name="input_parser_tool",
+ # args={"query": "Test query", "dataset_path": "dummy_path"},
+ # id="call_123")]
+ )
+
+ # We also need to mock the agent's parsing of this AIMessage into an AgentAction
+ # or handle the AgentExecutor's internal calls. This gets complex.
+ # Let's try mocking the return value of the agent executor's chain directly for simplicity.
+
+ # Alternative simpler mock: Mock the final output of the AgentExecutor invoke
+ # Patch the AgentExecutor class itself if possible, or its invoke method.
+ # For now, let's stick to mocking the LLM but returning an AIMessage.
+ mock_llm_instance.invoke.return_value = mock_response
+
+ # Since the agent will try to *parse* the AIMessage and likely fail without
+ # a proper output parser mock or correctly formatted tool call structure,
+ # let's refine the mock to return what the final step *might* return.
+ # This is becoming less of a unit test and more of a placeholder.
+ # Reverting to the previous simple mock, but acknowledging its limitation.
+ mock_llm_instance.invoke.return_value = AIMessage(content="Processed successfully (mocked)")
+
+ query = "What is the effect of treatment on outcome?"
+ dataset_path = self.dummy_data_path
+
+ try:
+ # Run the main analysis function
+ # We expect this to fail later in the chain now, but hopefully not on prompt formatting.
+ # The mock needs to be sophisticated enough to handle the AgentExecutor loop.
+ # For this test, let's assume the mocked AIMessage is enough to prevent the immediate crash.
+
+ # Re-patching the AgentExecutor might be better for a simple invocation test.
+ with patch('auto_causal.agent.AgentExecutor.invoke') as mock_agent_invoke:
+ mock_agent_invoke.return_value = {"output": "Agent invoked successfully (mocked)"}
+
+ result = run_causal_analysis(query, dataset_path)
+
+ # Basic assertion: Check if we get a result dictionary
+ self.assertIsInstance(result, str) # run_causal_analysis returns result["output"] which is str
+ self.assertIn("Agent invoked successfully (mocked)", result) # Check if the mocked output is returned
+ print(f"Agent Result (Mocked): {result}")
+
+ except Exception as e:
+ # Catch the specific ValueError if it still occurs, otherwise fail
+ if isinstance(e, ValueError) and "agent_scratchpad" in str(e):
+ self.fail(f"ValueError related to agent_scratchpad persisted: {e}")
+ else:
+ self.fail(f"Agent invocation failed with unexpected exception: {e}")
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/test_components/test_query_interpreter.py b/tests/auto_causal/test_components/test_query_interpreter.py
new file mode 100644
index 0000000000000000000000000000000000000000..98910f193a79e41171554c98aca1557b543a1b81
--- /dev/null
+++ b/tests/auto_causal/test_components/test_query_interpreter.py
@@ -0,0 +1,110 @@
+import pytest
+from unittest.mock import patch, MagicMock
+
+from auto_causal.components.query_interpreter import interpret_query
+from auto_causal.models import LLMTreatmentReferenceLevel
+
+# Basic mock data setup
+MOCK_QUERY_INFO_REF_LEVEL = {
+ "query_text": "What is the effect of different fertilizers (Nitro, Phos, Control) on crop_yield, using Control as the baseline?",
+ "potential_treatments": ["fertilizer_type"],
+ "outcome_hints": ["crop_yield"],
+ "covariates_hints": ["soil_ph", "rainfall"]
+}
+
+MOCK_DATASET_ANALYSIS_REF_LEVEL = {
+ "columns": ["fertilizer_type", "crop_yield", "soil_ph", "rainfall"],
+ "column_categories": {
+ "fertilizer_type": "categorical_multi", # Assuming a category type for multi-level
+ "crop_yield": "continuous_numeric",
+ "soil_ph": "continuous_numeric",
+ "rainfall": "continuous_numeric"
+ },
+ "potential_treatments": ["fertilizer_type"],
+ "potential_outcomes": ["crop_yield"],
+ "value_counts": { # Added for providing unique values to the prompt
+ "fertilizer_type": {
+ "values": ["Nitro", "Phos", "Control"]
+ }
+ },
+ "columns_data_preview": { # Fallback if value_counts isn't structured as expected
+ "fertilizer_type": ["Nitro", "Phos", "Control", "Nitro", "Control"]
+ }
+ # Add other necessary fields from DatasetAnalysis model if interpret_query uses them
+}
+
+MOCK_DATASET_DESCRIPTION_REF_LEVEL = "A dataset from an agricultural experiment."
+
+def test_interpret_query_identifies_treatment_reference_level():
+ """
+ Test that interpret_query correctly identifies and returns the treatment_reference_level
+ when the LLM simulation provides one.
+ """
+ # Mock the LLM client and its structured output
+ mock_llm_instance = MagicMock()
+ mock_structured_llm = MagicMock()
+
+ # This will be the mock for the call related to TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE
+ # The other LLM calls (for T, O, C, IV, RDD, RCT) also need to be considered
+ # or made to return benign defaults for this specific test.
+
+ # Simulate LLM responses for different calls within interpret_query
+ def mock_llm_call_router(*args, **kwargs):
+ # The first argument to _call_llm_for_var is the llm instance,
+ # The second is the prompt string
+ # The third is the Pydantic model for structured output
+
+ # args[0] is llm, args[1] is prompt, args[2] is pydantic_model
+ pydantic_model_passed = args[2]
+
+ if pydantic_model_passed == LLMTreatmentReferenceLevel:
+ return LLMTreatmentReferenceLevel(reference_level="Control", reasoning="Identified from query text.")
+ # Add mocks for other LLM calls if interpret_query strictly needs them to proceed
+ # For example, for identifying treatment, outcome, covariates, IV, RDD, RCT:
+ elif "most likely treatment variable" in args[1]: # Simplified check for treatment prompt
+ return MagicMock(variable_name="fertilizer_type")
+ elif "most likely outcome variable" in args[1]: # Simplified check for outcome prompt
+ return MagicMock(variable_name="crop_yield")
+ elif "valid covariates" in args[1]: # Simplified check for covariates prompt
+ return MagicMock(covariates=["soil_ph", "rainfall"])
+ elif "Instrumental Variables" in args[1]: # Check for IV prompt
+ return MagicMock(instrument_variable=None)
+ elif "Regression Discontinuity Design" in args[1]: # Check for RDD prompt
+ return MagicMock(running_variable=None, cutoff_value=None)
+ elif "Randomized Controlled Trial" in args[1]: # Check for RCT prompt
+ return MagicMock(is_rct=False, reasoning="No indication of RCT.")
+ return MagicMock() # Default mock for other calls
+
+ # Patch _call_llm_for_var which is used internally by interpret_query's helpers
+ with patch('auto_causal.components.query_interpreter._call_llm_for_var', side_effect=mock_llm_call_router) as mock_llm_call:
+ # Patch get_llm_client to return our mock_llm_instance
+ # This ensures that _call_llm_for_var uses the intended LLM mock when called from within interpret_query
+ with patch('auto_causal.components.query_interpreter.get_llm_client', return_value=mock_llm_instance) as mock_get_llm:
+
+ result = interpret_query(
+ query_info=MOCK_QUERY_INFO_REF_LEVEL,
+ dataset_analysis=MOCK_DATASET_ANALYSIS_REF_LEVEL,
+ dataset_description=MOCK_DATASET_DESCRIPTION_REF_LEVEL
+ )
+
+ assert "treatment_reference_level" in result, "treatment_reference_level should be in the result"
+ assert result["treatment_reference_level"] == "Control", "Incorrect treatment_reference_level identified"
+
+ # Verify that the LLM was called to get the reference level
+ # This requires checking the calls made to the mock_llm_call
+ found_ref_level_call = False
+ for call_args in mock_llm_call.call_args_list:
+ # call_args is a tuple; call_args[0] contains positional args, call_args[1] has kwargs
+ # The third positional argument to _call_llm_for_var is the pydantic_model
+ if len(call_args[0]) >= 3 and call_args[0][2] == LLMTreatmentReferenceLevel:
+ found_ref_level_call = True
+ # Optionally, check the prompt content here too if needed
+ # prompt_content = call_args[0][1]
+ # assert "using Control as the baseline" in prompt_content
+ break
+ assert found_ref_level_call, "LLM call for treatment reference level was not made."
+
+ # Basic checks for other essential variables (assuming they are mocked simply)
+ assert result["treatment_variable"] == "fertilizer_type"
+ assert result["outcome_variable"] == "crop_yield"
+ assert result["is_rct"] is False # Based on mock
\ No newline at end of file
diff --git a/tests/auto_causal/test_e2e_did.py b/tests/auto_causal/test_e2e_did.py
new file mode 100644
index 0000000000000000000000000000000000000000..cecca73f83fcc168174a9c7831fd26c14d79f3c3
--- /dev/null
+++ b/tests/auto_causal/test_e2e_did.py
@@ -0,0 +1,113 @@
+import unittest
+import os
+import json
+import re
+import pandas as pd
+# Import load_dotenv
+from dotenv import load_dotenv
+
+# Import the main entry point
+from auto_causal.agent import run_causal_analysis
+
+# Ensure necessary environment variables are set for LLM calls (e.g., OPENAI_API_KEY)
+# Load from .env file if present
+# from dotenv import load_dotenv
+# load_dotenv()
+
+class TestE2EDID(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # Load environment variables from .env file
+ load_dotenv()
+
+ # Define query details based on the user's request
+ cls.query = "What was the impact of cigarette taxation rules on cigarette sales in California?"
+ cls.dataset_description = "To estimate the effect of cigarette taxation on its consumption, data from cigarette sales were collected and analyzed across 39 states in the United States from the years 1970 to 2000. Proposition 99, a Tobacco Tax and Health Protection Act passed in California in 1988, imposed a 25-cent per pack state excise tax on tobacco cigarettes and implemented additional restrictions, including the ban on cigarette vending machines in public areas accessible by juveniles and a ban on the individual sale of single cigarettes. Revenue generated was allocated for environmental and health care programs along with anti-tobacco advertising. We aim to determine if the imposition of this tax and the subsequent regulations led to a reduction in cigarette sales. The data is in the CSV file smoking2.csv."
+ # Expected effect from query data. Note: A positive value (tax increasing sales) is counter-intuitive for this scenario.
+ # The actual DiD effect for Prop 99 is often cited as negative (reducing sales). We use the provided value for test structure.
+ cls.expected_effect = 24.83
+ # Tolerance might need adjustment based on the specific DiD model implemented by the agent
+ cls.tolerance = 10.0
+
+ # Construct path relative to this test file's directory
+ base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
+ cls.dataset_path = os.path.join(base_dir, "data", "qrdata", "smoking2.csv")
+ print(f"DEBUG: E2E test using dataset path: {cls.dataset_path}")
+
+ # Check if data file exists
+ if not os.path.exists(cls.dataset_path):
+ raise FileNotFoundError(f"E2E test requires dataset at: {cls.dataset_path}")
+
+ # Ensure API key is available (or skip test)
+ if not os.getenv("OPENAI_API_KEY"):
+ raise unittest.SkipTest("Skipping E2E test: OPENAI_API_KEY not set or found in .env file.")
+
+ def extract_results_from_output(self, output_string: str) -> dict:
+ '''Helper to parse relevant info from the agent's final output string.'''
+ results = {
+ 'method': None,
+ 'effect': None
+ }
+ # Look for Method Used:
+ method_match = re.search(r"Method Used:\s*([^\n]+)", output_string, re.IGNORECASE)
+ if method_match:
+ # Strip potential markdown and extra spaces
+ method_name = method_match.group(1).strip().replace('*', '')
+ results['method'] = method_name
+ # Fallback: Look for 'Recommended Method:'
+ elif (method_match := re.search(r"Recommended Method:\s*([^\n]+)", output_string, re.IGNORECASE)):
+ method_name = method_match.group(1).strip().replace('*', '')
+ results['method'] = method_name
+
+
+ # Parse effect - Added more robust patterns
+ effect_patterns = [
+ r"Causal Effect:\s*([-\+]?\d*\.?\d+)",
+ r"estimated causal effect is\s*([-\+]?\d*\.?\d+)",
+ r"effect estimate:\s*([-\+]?\d*\.?\d+)"
+ ]
+ for pattern in effect_patterns:
+ effect_match = re.search(pattern, output_string, re.IGNORECASE)
+ if effect_match:
+ try:
+ results['effect'] = float(effect_match.group(1))
+ break # Stop after first successful match
+ except ValueError:
+ pass
+
+ return results
+
+ def test_did_e2e(self):
+ '''Run the full agent workflow on the smoking dataset.'''
+
+ # run_causal_analysis now returns the final explanation string directly
+ final_output_string = run_causal_analysis(self.query, self.dataset_path, self.dataset_description)
+
+ print("\n--- E2E Test Output (DiD) ---")
+ print(final_output_string)
+ print("-----------------------------\n")
+
+ # Parse the output string directly
+ parsed_results = self.extract_results_from_output(final_output_string)
+
+ # Assertions
+ self.assertIsNotNone(parsed_results['method'], "Could not extract method from final output string.")
+ # Check if the method is DiD (case-insensitive, ignoring spaces)
+ method_lower_no_space = parsed_results['method'].lower().replace(' ', '').replace('-', '')
+ expected_methods = ["differenceindifferences", "did", "diffindiff"]
+ self.assertTrue(
+ any(expected in method_lower_no_space for expected in expected_methods),
+ f"Expected DiD method, but found: {parsed_results['method']}"
+ )
+
+ # Check numerical effect
+ self.assertIsNotNone(parsed_results['effect'], "Could not extract effect estimate from final output string.")
+ # Note: DiD estimates can vary based on model specification (covariates, fixed effects).
+ # The expected value 24.83 might be based on a specific model or potentially incorrect.
+ # Adjust tolerance accordingly.
+ self.assertAlmostEqual(parsed_results['effect'], self.expected_effect, delta=self.tolerance,
+ msg=f"Estimated effect {parsed_results['effect']} not within {self.tolerance} of expected {self.expected_effect}")
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/test_e2e_ihdp.py b/tests/auto_causal/test_e2e_ihdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..88497ed3fe7562b9894d67822f5dfdd4643c23d1
--- /dev/null
+++ b/tests/auto_causal/test_e2e_ihdp.py
@@ -0,0 +1,116 @@
+import unittest
+import os
+import json
+import re
+import pandas as pd
+# Import load_dotenv
+from dotenv import load_dotenv
+
+# Import the main entry point
+from auto_causal.agent import run_causal_analysis
+
+# Ensure necessary environment variables are set for LLM calls (e.g., OPENAI_API_KEY)
+# Load from .env file if present
+# from dotenv import load_dotenv
+# load_dotenv()
+
+class TestE2EIHDP(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # Load environment variables from .env file
+ load_dotenv()
+
+ # Define query details based on queries.json
+ cls.query = "What is the effect of home visits from specialist doctors on the cognitive scores of premature infants?"
+
+ # Construct path relative to this test file's directory
+ # __file__ is the path to the current file
+ # os.path.dirname gets the directory containing the file
+ # os.path.abspath ensures it's an absolute path
+ # Go up 2 levels (tests/auto_causal/ -> tests/ -> causalscientist/) then into data/
+ base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
+ cls.dataset_path = os.path.join(base_dir, "data", "qrdata", "ihdp_1.csv")
+ # cls.dataset_path = "data/qrdata/ihdp_1.csv" # Old relative path
+ print(f"DEBUG: E2E test using dataset path: {cls.dataset_path}") # Add print statement
+
+ cls.expected_effect = 4.05
+ cls.tolerance = 0.5 # Allow some variance from the expected 4.05
+
+ # Check if data file exists
+ if not os.path.exists(cls.dataset_path):
+ raise FileNotFoundError(f"E2E test requires dataset at: {cls.dataset_path}")
+
+ # Ensure API key is available (or skip test)
+ # This check will now happen *after* load_dotenv attempts to load it
+ if not os.getenv("OPENAI_API_KEY"):
+ raise unittest.SkipTest("Skipping E2E test: OPENAI_API_KEY not set or found in .env file.")
+
+ # Add dataset description from queries.json
+ cls.dataset_description = "The CSV file ihdp_1.csv contains data obtained from the Infant Health and Development Program (IHDP). The study is designed to evaluate the effect of home visit from specialist doctors on the cognitive test scores of premature infants. The confounders x (x1-x25) correspond to collected measurements of the children and their mothers, including measurements on the child (birth weight, head circumference, weeks born preterm, birth order, first born, neonatal health index, sex, twin status), as well as behaviors engaged in during the pregnancy (smoked cigarettes, drank alcohol, took drugs) and measurements on the mother at the time she gave birth (age, marital status, educational attainment, whether she worked during pregnancy, whether she received prenatal care) and the site (8 total) in which the family resided at the start of the intervention. There are 6 continuous covariates and 19 binary covariates."
+
+ def extract_results_from_output(self, output_string: str) -> dict:
+ '''Helper to parse relevant info from the agent's final output string.'''
+ results = {
+ 'method': None,
+ 'effect': None
+ }
+ # Try simpler regex pattern: Look for Method Used:, space, capture until newline
+ method_match = re.search(r"Method Used:\s*([^\n]+)", output_string, re.IGNORECASE)
+
+ if method_match:
+ # Strip potential markdown and extra spaces from the captured group
+ method_name = method_match.group(1).strip().replace('*', '')
+ results['method'] = method_name
+ # Keep fallback checks if needed
+ # else:
+ # Fallback if the first pattern fails (e.g., different formatting)
+ # method_match = re.search(r"recommended method is ([\w\s\.-]+)\.", output_string, re.IGNORECASE)
+ # if method_match:
+ # results['method'] = method_match.group(1).strip()
+
+ # Parse effect
+ effect_match = re.search(r"Causal Effect: ([\-\+]?\d*\.?\d+)", output_string, re.IGNORECASE)
+ if not effect_match: # Try summary pattern
+ effect_match = re.search(r"estimated causal effect is ([\-\+]?\d*\.?\d+)", output_string, re.IGNORECASE)
+
+ if effect_match:
+ try:
+ results['effect'] = float(effect_match.group(1))
+ except ValueError:
+ pass
+
+ return results
+
+ def test_ihdp_e2e(self):
+ '''Run the full agent workflow on the IHDP dataset.'''
+
+ # run_causal_analysis now returns the final explanation string directly
+ final_output_string = run_causal_analysis(self.query, self.dataset_path, self.dataset_description)
+
+ print("--- E2E Test Output ---")
+ print(final_output_string)
+ print("-----------------------")
+
+ # Parse the output string directly
+ parsed_results = self.extract_results_from_output(final_output_string)
+
+ # Assertions
+ self.assertIsNotNone(parsed_results['method'], "Could not extract method from final output string.")
+ # Check if the method is one of the PS methods we refactored
+ # Note: Method selection logic might still need debugging
+ self.assertIn(parsed_results['method'].lower(),
+ ["propensity score matching", "propensity score weighting",
+ "ps.matching", "ps.weighting", "regression adjustment"], # Allow RA as decision tree might choose it
+ f"Unexpected method found: {parsed_results['method']}")
+
+ # Check numerical effect
+ self.assertIsNotNone(parsed_results['effect'], "Could not extract effect estimate from final output string.")
+ self.assertAlmostEqual(parsed_results['effect'], self.expected_effect, delta=self.tolerance,
+ msg=f"Estimated effect {parsed_results['effect']} not within {self.tolerance} of expected {self.expected_effect}")
+
+if __name__ == '__main__':
+ # Ensure the working directory allows finding the data file relative path
+ # Might need adjustment depending on how tests are run
+ # Example: os.chdir('../') if running from tests/ directory
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/test_e2e_iv.py b/tests/auto_causal/test_e2e_iv.py
new file mode 100644
index 0000000000000000000000000000000000000000..905391ffddfa2a698aad53e552841b15d322668c
--- /dev/null
+++ b/tests/auto_causal/test_e2e_iv.py
@@ -0,0 +1,56 @@
+import unittest
+import os
+import sys
+import re # For parsing results
+
+from auto_causal.agent import run_causal_analysis
+
+class TestE2EIV(unittest.TestCase):
+
+ def test_iv_wage_education(self):
+ """Run the full agent workflow on the app_engagement_push dataset for IV."""
+
+ query = "Does the marketing push increase app purchases?"
+ # Assuming tests run from the project root directory
+ dataset_path = "data/qrdata/app_engagement_push.csv"
+ dataset_description = "A study is conducted to measure the effect of a marketing push on user engagement, specifically in-app purchases. Some customers who were assigned to receive the push are not receiving it, because they probably have an older phone that doesn’t support the kind of push the marketing team designed.\nThe dataset app_engagement_push.csv contains records for 10,000 random customers. Each record includes whether an in-app purchase was made (in_app_purchase), if a marketing push was assigned to the user (push_assigned), and if the marketing push was successfully delivered (push_delivered)"
+
+ # --- Execute the Agent ---
+ # Note: Ensure any required API keys (e.g., OPENAI_API_KEY) are set
+ # in the environment where the test runs, as get_llm_client() likely needs it.
+ print("--- Running E2E Test Output (IV) ---")
+ final_output_string = run_causal_analysis(
+ query=query,
+ dataset_path=dataset_path,
+ dataset_description=dataset_description
+ )
+ print(final_output_string)
+ print("-------------------------------------")
+
+ # --- Assertions ---
+ self.assertIsNotNone(final_output_string, "Agent returned None output.")
+ self.assertIsInstance(final_output_string, str, "Agent output is not a string.")
+
+ # Check for absence of common error messages
+ self.assertNotIn("Error:", final_output_string, "Output string contains 'Error:'.")
+ self.assertNotIn("Failed:", final_output_string, "Output string contains 'Failed:'.")
+ self.assertNotIn("Traceback", final_output_string, "Output string contains 'Traceback'.")
+
+ # Check if the correct method was likely selected and mentioned
+ self.assertIn("Instrumental Variable", final_output_string, "Method 'Instrumental Variable' not mentioned in output.")
+
+ # Check if key variables are mentioned
+ output_lower = final_output_string.lower()
+ self.assertIn("education", output_lower, "Treatment variable 'education' not mentioned.") # Or 'schooling'
+ self.assertIn("wage", output_lower, "Outcome variable 'wage' not mentioned.") # Or 'log wage'
+ self.assertIn("quarter", output_lower, "Instrument variable 'quarter' not mentioned.") # Or 'qob'
+
+ # Check if an effect estimate section/value exists
+ self.assertIn("Causal Effect", output_lower, "'Causal Effect' section missing.")
+ # Check for a number pattern near the effect estimate
+ # Check for positive effect (0.0853)
+ self.assertTrue(re.search(r"causal effect:?\s*\+?\d*\.?\d+", output_lower),
+ "Numerical effect estimate pattern not found.")
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/test_e2e_rdd.py b/tests/auto_causal/test_e2e_rdd.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef293e62a024dd03871ca957428d75284da59d6
--- /dev/null
+++ b/tests/auto_causal/test_e2e_rdd.py
@@ -0,0 +1,63 @@
+import unittest
+import os
+import sys
+import re # For parsing results
+
+# Ensure the main package is discoverable
+# Adjust path as necessary based on your test execution context
+# SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+# sys.path.append(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
+
+from auto_causal.agent import run_causal_analysis
+
+class TestE2ERDD(unittest.TestCase):
+
+ def test_rdd_drinking_data(self):
+ """Run the full agent workflow on the drinking age dataset for RDD."""
+
+ query = "What is the effect of alcohol consumption on death by all causes at 21 years?"
+ # Assuming tests run from the project root directory
+ dataset_path = "data/qrdata/drinking.csv"
+ dataset_description = "To estimate the impacts of alcohol on death, we could use the fact that legal drinking age imposes a discontinuity on nature. In the US, those just under 21 years don't drink (or drink much less) while those just older than 21 do drink. The csv file drinking.csv contains mortality data aggregated by age. Each row is the average age of a group of people and the average mortality by all causes (all), by moving vehicle accident (mva) and by suicide (suicide)."
+
+ # --- Execute the Agent ---
+ # Note: Ensure any required API keys (e.g., OPENAI_API_KEY) are set
+ # in the environment where the test runs, as get_llm_client() likely needs it.
+ print("--- Running E2E Test Output (RDD) ---")
+ final_output_string = run_causal_analysis(
+ query=query,
+ dataset_path=dataset_path,
+ dataset_description=dataset_description
+ )
+ print(final_output_string)
+ print("-------------------------------------")
+
+ # --- Assertions ---
+ self.assertIsNotNone(final_output_string, "Agent returned None output.")
+ self.assertIsInstance(final_output_string, str, "Agent output is not a string.")
+
+ # Check for absence of common error messages
+ self.assertNotIn("Error:", final_output_string, "Output string contains 'Error:'.")
+ self.assertNotIn("Failed:", final_output_string, "Output string contains 'Failed:'.")
+ self.assertNotIn("Traceback", final_output_string, "Output string contains 'Traceback'.")
+
+ # Check if the correct method was likely selected and mentioned
+ self.assertIn("Regression Discontinuity", final_output_string, "Method 'Regression Discontinuity' not mentioned in output.")
+
+ # Check if key variables are mentioned
+ # (Use lowercase for case-insensitivity)
+ output_lower = final_output_string.lower()
+ self.assertIn("age", output_lower, "Running variable 'age' not mentioned.")
+ self.assertIn("21", output_lower, "Cutoff '21' not mentioned.")
+ # Outcome variable name is 'all' in the dataset
+ self.assertIn("all", output_lower, "Outcome variable 'all' not mentioned.")
+
+ # Check if an effect estimate section/value exists
+ self.assertIn("Causal Effect", output_lower, "'Causal Effect' section missing.")
+ # Check for a number pattern near the effect estimate
+ # This is less brittle than asserting the exact value 7.66
+ self.assertTrue(re.search(r"causal effect:?\s*[-+]?\d*\.?\d+", output_lower),
+ "Numerical effect estimate pattern not found.")
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/auto_causal/tools/__init__.py b/tests/auto_causal/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b15f4c17395b6b08ab701279a08248d595bb9b0c
--- /dev/null
+++ b/tests/auto_causal/tools/__init__.py
@@ -0,0 +1 @@
+# Tests for auto_causal tools
\ No newline at end of file
diff --git a/tests/process_qrdata_benchmark.py b/tests/process_qrdata_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a4f566301506246f62788d7fc3d7061a4e6a84d
--- /dev/null
+++ b/tests/process_qrdata_benchmark.py
@@ -0,0 +1,171 @@
+import csv
+import json
+import os
+import time
+from auto_causal.agent import run_causal_analysis
+import auto_causal.components.output_formatter as cs_output_formatter
+# Remove the direct import of cs_method_executor if it causes issues, we'll use importlib
+# import auto_causal.tools.method_executor_tool as cs_method_executor
+import importlib # Import importlib
+
+# --- Configuration ---
+# Absolute path as specified by user for the output log file
+OUTPUT_LOG_FILE = "Project/fork_/causalscientist/tests/output/qr_data_4o-mini_latest"
+# Relative path to the input CSV file from the workspace root
+INPUT_CSV_PATH = "benchmark/qr_revised.csv"
+# Prefix for constructing dataset paths
+DATA_FILES_BASE_DIR = "benchmark/all_data_1/"
+
+# --- Placeholder for the core analysis function ---
+# This function needs to be implemented or imported from elsewhere.
+# For the purpose of this script, it's a placeholder.
+def benchmark_causal_analysis(natural_language_query: str, dataset_path: str, data_description: str):
+ """
+ Placeholder for the actual causal analysis function.
+ This function would typically perform the analysis based on the inputs.
+ """
+ print(f"[INFO] run_causal_analysis called with:")
+ print(f" Natural Language Query: '{natural_language_query}'")
+ print(f" Dataset Path: '{dataset_path}'")
+ # print(f" Data Description: '{data_description[:100]}...' (truncated)") # Truncate for brevity if needed
+
+ # Simulate some processing time
+ # time.sleep(0.1) # Optional: Simulate work
+
+ run_causal_analysis(natural_language_query, dataset_path, data_description)
+
+ # TODO: Replace this with actual analysis logic.
+ # Example: Simulate failure for demonstration purposes.
+ # import random
+ # # Fail if "example_fail_condition" is in the query or randomly
+ # if "example_fail_condition" in natural_language_query.lower() or random.random() < 0.1: # ~10% chance of failure
+ # print("[WARN] Simulating a failure in run_causal_analysis.")
+ # raise Exception("Simulated analysis error from run_causal_analysis")
+
+ print(f"[INFO] run_causal_analysis for '{dataset_path}' completed successfully.")
+ # Actual implementation might return a result or have side effects.
+
+
+def main():
+ # Set the log file path for the output_formatter module
+ cs_output_formatter.CURRENT_OUTPUT_LOG_FILE = OUTPUT_LOG_FILE
+
+ # Set the log file path for the method_executor_tool module using importlib
+ try:
+ method_executor_module_name = "auto_causal.tools.method_executor_tool"
+ cs_method_executor_module = importlib.import_module(method_executor_module_name)
+ cs_method_executor_module.CURRENT_OUTPUT_LOG_FILE = OUTPUT_LOG_FILE
+ print(f"[INFO] Successfully set CURRENT_OUTPUT_LOG_FILE for {method_executor_module_name} to: {OUTPUT_LOG_FILE}")
+ except Exception as e:
+ print(f"[ERROR] Failed to set CURRENT_OUTPUT_LOG_FILE for method_executor_tool: {e}")
+ # Decide if you want to return or continue if this fails
+ return
+
+ # Ensure the output directory for the log file exists
+ output_log_dir = os.path.dirname(OUTPUT_LOG_FILE)
+ if not os.path.exists(output_log_dir):
+ try:
+ os.makedirs(output_log_dir)
+ print(f"Created directory: {output_log_dir}")
+ except OSError as e:
+ print(f"[ERROR] Failed to create directory '{output_log_dir}': {e}")
+ return # Stop if we can't create the log directory
+
+ current_query_sequence_number = 0
+ processed_csv_rows = 0
+
+ print(f"Starting processing of CSV: {INPUT_CSV_PATH}")
+ print(f"Output log will be written to: {OUTPUT_LOG_FILE}")
+
+ try:
+ with open(INPUT_CSV_PATH, mode='r', newline='', encoding='utf-8') as csv_file:
+ csv_reader = csv.DictReader(csv_file)
+
+ if not csv_reader.fieldnames:
+ print(f"[ERROR] CSV file '{INPUT_CSV_PATH}' is empty or has no header.")
+ return
+
+ required_columns = ['data_description', 'natural_language_query', 'data_files']
+ missing_cols = [col for col in required_columns if col not in csv_reader.fieldnames]
+ if missing_cols:
+ print(f"[ERROR] Missing required columns in CSV file '{INPUT_CSV_PATH}': {', '.join(missing_cols)}")
+ print(f"Available columns: {csv_reader.fieldnames}")
+ return
+
+ for row_number, row in enumerate(csv_reader, 1):
+ processed_csv_rows += 1
+ data_description = row.get('data_description', '').strip()
+ natural_language_query = row.get('natural_language_query', '').strip()
+ data_files_string = row.get('data_files', '').strip()
+ answer = row.get('answer', '').strip()
+
+ if not data_files_string:
+ print(f"[WARN] CSV Row {row_number}: 'data_files' field is empty. Skipping.")
+ continue
+
+ individual_files = [f.strip() for f in data_files_string.split(',') if f.strip()]
+
+ if not individual_files:
+ print(f"[WARN] CSV Row {row_number}: 'data_files' contained only separators or was effectively empty after stripping. Original: '{data_files_string}'. Skipping.")
+ continue
+
+ for file_name in individual_files:
+ current_query_sequence_number += 1
+
+ dataset_path = os.path.join(DATA_FILES_BASE_DIR, file_name)
+
+ log_data = {
+ "query_number": current_query_sequence_number,
+ "natural_language_query": natural_language_query,
+ "dataset_path": dataset_path,
+ "answer": answer
+ }
+
+ try:
+ with open(OUTPUT_LOG_FILE, mode='a', encoding='utf-8') as log_file:
+ log_file.write('\n' + json.dumps(log_data) + '\n')
+ except IOError as e:
+ print(f"[ERROR] Failed to write pre-analysis log for query #{current_query_sequence_number} to '{OUTPUT_LOG_FILE}': {e}")
+ continue # Skip to next file/row if logging fails
+
+ successful_analysis = False
+ for attempt in range(2): # Attempt 0 (first try), Attempt 1 (retry)
+ try:
+ print(f"[INFO] --- Starting Analysis (Attempt {attempt + 1}/2) ---")
+ print(f"[INFO] Query Sequence #: {current_query_sequence_number}")
+ print(f"[INFO] CSV Row: {row_number}, File: '{file_name}'")
+ benchmark_causal_analysis(
+ natural_language_query=natural_language_query,
+ dataset_path=dataset_path,
+ data_description=data_description
+ )
+ successful_analysis = True
+ print(f"[INFO] --- Analysis Successful (Attempt {attempt + 1}/2) ---")
+ break
+ except Exception as e:
+ print(f"[ERROR] run_causal_analysis failed on attempt {attempt + 1}/2 for query #{current_query_sequence_number}: {e}")
+ if attempt == 1: # This was the retry, and it also failed
+ print(f"[INFO] Both attempts failed for query #{current_query_sequence_number}.")
+ try:
+ with open(OUTPUT_LOG_FILE, mode='a', encoding='utf-8') as log_file:
+ log_file.write(f"\n{current_query_sequence_number}:Failed\n")
+ except IOError as ioe_fail:
+ print(f"[ERROR] Failed to write failure status for query #{current_query_sequence_number} to '{OUTPUT_LOG_FILE}': {ioe_fail}")
+ else:
+ print(f"[INFO] Will retry query #{current_query_sequence_number}.")
+ # time.sleep(1) # Optional: wait a bit before retrying
+
+ except FileNotFoundError:
+ print(f"[ERROR] Input CSV file not found: '{INPUT_CSV_PATH}'")
+ except Exception as e:
+ print(f"[ERROR] An unexpected error occurred during script execution: {e}")
+ import traceback
+ traceback.print_exc()
+ finally:
+ print(f"--- Script finished ---")
+ print(f"Total CSV rows processed: {processed_csv_rows}")
+ print(f"Total analysis calls attempted (query_number): {current_query_sequence_number}")
+ print(f"Log file: {OUTPUT_LOG_FILE}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/tests/test_did.py b/tests/test_did.py
new file mode 100644
index 0000000000000000000000000000000000000000..618e7d18353c376b743a9eb529c8f7cd6b307123
--- /dev/null
+++ b/tests/test_did.py
@@ -0,0 +1,33 @@
+import pandas as pd
+import statsmodels.formula.api as smf
+import statsmodels.api as sm
+
+# --- Step 1: Load Data ---
+df = pd.read_csv("benchmark/all_data/billboard_impact.csv")
+
+# --- Step 2: Create Interaction Term ---
+# poa = 1 for treatment group (Porto Alegre), 0 otherwise
+# jul = 1 for post-intervention (July), 0 otherwise
+df['did_interaction'] = df['poa'] * df['jul']
+
+# --- Step 3: Specify the DiD Formula ---
+# Includes fixed effects for group (poa), time (jul), and their interaction
+formula = "deposits ~ did_interaction + C(poa) + C(jul)"
+
+# --- Step 4: Fit the Model ---
+model = smf.ols(formula=formula, data=df)
+results = model.fit()
+
+# --- Step 5: Extract and Print DiD Estimate ---
+coef = results.params['did_interaction']
+conf_int = results.conf_int().loc['did_interaction']
+stderr = results.bse['did_interaction']
+pval = results.pvalues['did_interaction']
+
+print("=== Difference-in-Differences Estimation ===")
+print(f"Treatment effect (DiD estimate): {coef:.2f}")
+print(f"Standard error: {stderr:.2f}")
+print(f"95% CI: ({conf_int[0]:.2f}, {conf_int[1]:.2f})")
+print(f"P-value: {pval:.4f}")
+print("\nModel Summary:")
+print(results.summary())