Spaces:
Running
Running
Commit
·
1721aea
0
Parent(s):
Initial clean commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env.example +0 -0
- .gitattributes +37 -0
- .gitignore +217 -0
- README copy.md +198 -0
- README.md +13 -0
- app.py +339 -0
- auto_causal/__init__.py +50 -0
- auto_causal/agent.py +394 -0
- auto_causal/components/__init__.py +28 -0
- auto_causal/components/dataset_analyzer.py +853 -0
- auto_causal/components/decision_tree.py +366 -0
- auto_causal/components/decision_tree_llm.py +218 -0
- auto_causal/components/explanation_generator.py +404 -0
- auto_causal/components/input_parser.py +456 -0
- auto_causal/components/method_validator.py +327 -0
- auto_causal/components/output_formatter.py +138 -0
- auto_causal/components/query_interpreter.py +580 -0
- auto_causal/components/state_manager.py +40 -0
- auto_causal/config.py +97 -0
- auto_causal/methods/__init__.py +44 -0
- auto_causal/methods/backdoor_adjustment/__init__.py +0 -0
- auto_causal/methods/backdoor_adjustment/diagnostics.py +92 -0
- auto_causal/methods/backdoor_adjustment/estimator.py +105 -0
- auto_causal/methods/backdoor_adjustment/llm_assist.py +176 -0
- auto_causal/methods/causal_method.py +88 -0
- auto_causal/methods/diff_in_means/__init__.py +0 -0
- auto_causal/methods/diff_in_means/diagnostics.py +60 -0
- auto_causal/methods/diff_in_means/estimator.py +107 -0
- auto_causal/methods/diff_in_means/llm_assist.py +95 -0
- auto_causal/methods/difference_in_differences/diagnostics.py +345 -0
- auto_causal/methods/difference_in_differences/estimator.py +463 -0
- auto_causal/methods/difference_in_differences/llm_assist.py +362 -0
- auto_causal/methods/difference_in_differences/utils.py +65 -0
- auto_causal/methods/generalized_propensity_score/__init__.py +3 -0
- auto_causal/methods/generalized_propensity_score/diagnostics.py +196 -0
- auto_causal/methods/generalized_propensity_score/estimator.py +386 -0
- auto_causal/methods/generalized_propensity_score/llm_assist.py +208 -0
- auto_causal/methods/instrumental_variable/__init__.py +1 -0
- auto_causal/methods/instrumental_variable/diagnostics.py +218 -0
- auto_causal/methods/instrumental_variable/estimator.py +370 -0
- auto_causal/methods/instrumental_variable/llm_assist.py +240 -0
- auto_causal/methods/linear_regression/__init__.py +0 -0
- auto_causal/methods/linear_regression/diagnostics.py +76 -0
- auto_causal/methods/linear_regression/estimator.py +355 -0
- auto_causal/methods/linear_regression/llm_assist.py +146 -0
- auto_causal/methods/propensity_score/__init__.py +13 -0
- auto_causal/methods/propensity_score/base.py +80 -0
- auto_causal/methods/propensity_score/diagnostics.py +74 -0
- auto_causal/methods/propensity_score/llm_assist.py +45 -0
- auto_causal/methods/propensity_score/matching.py +341 -0
.env.example
ADDED
File without changes
|
.gitattributes
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be added to the global gitignore or merged into this project gitignore. For a PyCharm
|
158 |
+
# project, it is recommended to include directory-based project settings:
|
159 |
+
.idea/
|
160 |
+
|
161 |
+
# VS Code
|
162 |
+
.vscode/
|
163 |
+
|
164 |
+
# Data files
|
165 |
+
*.csv
|
166 |
+
*.xlsx
|
167 |
+
*.xls
|
168 |
+
*.json
|
169 |
+
*.parquet
|
170 |
+
*.pickle
|
171 |
+
*.pkl
|
172 |
+
*.h5
|
173 |
+
*.hdf5
|
174 |
+
|
175 |
+
# Model files
|
176 |
+
*.model
|
177 |
+
*.joblib
|
178 |
+
*.sav
|
179 |
+
|
180 |
+
# Output directories
|
181 |
+
outputs/
|
182 |
+
results/
|
183 |
+
logs/
|
184 |
+
checkpoints/
|
185 |
+
wandb/
|
186 |
+
|
187 |
+
# Temporary files
|
188 |
+
*.tmp
|
189 |
+
*.temp
|
190 |
+
*~
|
191 |
+
|
192 |
+
# OS generated files
|
193 |
+
.DS_Store
|
194 |
+
.DS_Store?
|
195 |
+
._*
|
196 |
+
.Spotlight-V100
|
197 |
+
.Trashes
|
198 |
+
ehthumbs.db
|
199 |
+
Thumbs.db
|
200 |
+
|
201 |
+
# LLM API keys and secrets
|
202 |
+
.env.local
|
203 |
+
.env.production
|
204 |
+
secrets.json
|
205 |
+
api_keys.txt
|
206 |
+
|
207 |
+
# Experiment tracking
|
208 |
+
mlruns/
|
209 |
+
.mlflow/
|
210 |
+
|
211 |
+
# Large files (adjust sizes as needed)
|
212 |
+
*.zip
|
213 |
+
*.tar.gz
|
214 |
+
*.rar
|
215 |
+
|
216 |
+
# Project specific
|
217 |
+
tests/output/
|
README copy.md
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h1 align="center">
|
2 |
+
<img src="blob/main/asset/cais.png" width="400" alt="CAIS" />
|
3 |
+
<br>
|
4 |
+
Causal AI Scientist: Facilitating Causal Data Science with
|
5 |
+
Large Language Models
|
6 |
+
</h1>
|
7 |
+
<!-- <p align="center">
|
8 |
+
<a href="https://causalcopilot.com/"><b>[Demo]</b></a> •
|
9 |
+
<a href="https://github.com/Lancelot39/Causal-Copilot"><b>[Code]</b></a> •
|
10 |
+
<a href="">"Coming Soon"<b>[Arxiv(coming soon)]</b></a>
|
11 |
+
</p> -->
|
12 |
+
|
13 |
+
**Causal AI Scientist (CAIS)** is an LLM-powered tool for generating data-driven answers to natural language causal queries. It takes a natural language query (for example, "Does participating in a job training program lead to higher income?"), an accompanying dataset, and the corresponding description as inputs. CAIS then frames a suitable causal estimation problem by selecting appropriate treatment and outcome variables. It finds the suitable method for causal effect estimation, implements it, runs diagnostic tests, and finally interprets the numerical results in the context of the original query.
|
14 |
+
|
15 |
+
This repo includes instructions on both using the tool to perform causal analysis on a dataset of interest and reproducing results from our paper.
|
16 |
+
|
17 |
+
**Note** : This repository is a work in progress and will be updated with additional instructions and files.
|
18 |
+
|
19 |
+
<!-- ## 1. Introduction
|
20 |
+
|
21 |
+
Causal effect estimation is central to evidence-based decision-making across domains like social sciences, healthcare, and economics. However, it requires specialized expertise to select the right inference method, identify valid variables, and validate results.
|
22 |
+
|
23 |
+
**CAIS (Causal AI Scientist)** automates this process using Large Language Models (LLMs) to:
|
24 |
+
- Parse a natural language causal query.
|
25 |
+
- Analyze the dataset characteristics.
|
26 |
+
- Select the appropriate causal inference method via a decision tree and prompting strategies.
|
27 |
+
- Execute the method using pre-defined code templates.
|
28 |
+
- Validate and interpret the results.
|
29 |
+
|
30 |
+
<div style="text-align: center;">
|
31 |
+
<img src="blob/main/asset/CAIS-arch.png" width="990" alt="CAIS" />
|
32 |
+
</div>
|
33 |
+
</h1>
|
34 |
+
|
35 |
+
**Key Features:**
|
36 |
+
- End-to-end causal estimation with minimal user input.
|
37 |
+
- Supports a wide range of methods:
|
38 |
+
- **Econometric:** Difference-in-Differences (DiD), Instrumental Variables (IV), Ordinary Least Squares (OLS), Regression Discontinuity Design (RDD).
|
39 |
+
- **Causal Graph-based:** Backdoor adjustment, Frontdoor adjustment.
|
40 |
+
- Combines structured reasoning (decision tree) with LLM-powered interpretation.
|
41 |
+
- Works on clean textbook datasets, messy real-world datasets, and synthetic scenarios.
|
42 |
+
|
43 |
+
|
44 |
+
CAIS consists of three main stages, powered by a **decision-tree-driven reasoning pipeline**:
|
45 |
+
|
46 |
+
### **Stage 1: Variable and Method Selection**
|
47 |
+
1. **Dataset & Query Analysis**
|
48 |
+
- The LLM inspects the dataset description, variable names, and statistical summaries.
|
49 |
+
- Identifies treatment, outcome, and covariates.
|
50 |
+
2. **Property Detection**
|
51 |
+
- Uses targeted prompts to detect dataset properties:
|
52 |
+
- Randomized vs observational
|
53 |
+
- Presence of temporal/running variables
|
54 |
+
- Availability of valid instruments
|
55 |
+
3. **Decision Tree Traversal**
|
56 |
+
- Traverses a predefined causal inference decision tree (Fig. B in paper).
|
57 |
+
- Maps detected properties to the most appropriate estimation method.
|
58 |
+
|
59 |
+
---
|
60 |
+
|
61 |
+
### **Stage 2: Causal Inference Execution**
|
62 |
+
1. **Template-based Code Generation**
|
63 |
+
- Predefined Python templates for each method (e.g., DiD, IV, OLS).
|
64 |
+
- Variables from Stage 1 are substituted into templates.
|
65 |
+
2. **Diagnostics & Validation**
|
66 |
+
- Runs statistical tests and checks assumptions where applicable.
|
67 |
+
- Handles basic data preprocessing (e.g., type conversion for DoWhy).
|
68 |
+
|
69 |
+
---
|
70 |
+
|
71 |
+
### **Stage 3: Result Interpretation**
|
72 |
+
- LLM interprets numerical results and diagnostics in the context of the user’s causal query.
|
73 |
+
- Outputs:
|
74 |
+
- Estimated causal effect (ATE, ATT, or LATE).
|
75 |
+
- Standard errors, confidence intervals.
|
76 |
+
- Plain-language explanation.
|
77 |
+
|
78 |
+
---
|
79 |
+
## 3. Evaluation
|
80 |
+
|
81 |
+
We evaluate **CAIS** across three diverse dataset collections:
|
82 |
+
1. **QRData (Textbook Examples)** – curated, clean datasets with known causal effects.
|
83 |
+
2. **Real-World Studies** – empirical datasets from research papers (economics, health, political science).
|
84 |
+
3. **Synthetic Data** – generated with controlled causal structures to ensure balanced method coverage.
|
85 |
+
|
86 |
+
### **Metrics**
|
87 |
+
We assess CAIS on:
|
88 |
+
- **Method Selection Accuracy (MSA)** – % of cases where CAIS selects the correct inference method as per the reference.
|
89 |
+
- **Mean Relative Error (MRE)** – Average relative error between CAIS’s estimated causal effect and the reference value.
|
90 |
+
|
91 |
+
|
92 |
+
<p align="center">
|
93 |
+
<table>
|
94 |
+
<tr>
|
95 |
+
<td align="center">
|
96 |
+
<img src="blob/main/asset/CAIS-MRE.png" width="450" alt="CAIS MRE"/>
|
97 |
+
</td>
|
98 |
+
<td align="center">
|
99 |
+
<img src="blob/main/asset/CAIS-msa.png" width="450" alt="CAIS MSA"/>
|
100 |
+
</td>
|
101 |
+
</tr>
|
102 |
+
</table>
|
103 |
+
</p>
|
104 |
+
-->
|
105 |
+
|
106 |
+
## Getting Started
|
107 |
+
|
108 |
+
#### 🔧 Environment Installation
|
109 |
+
|
110 |
+
|
111 |
+
**Prerequisites:**
|
112 |
+
- **Python 3.10** (create a new conda environment first)
|
113 |
+
- Required Python libraries (specified in `requirements.txt`)
|
114 |
+
|
115 |
+
|
116 |
+
**Step 1: Copy the example configuration**
|
117 |
+
```bash
|
118 |
+
cp .env.example .env
|
119 |
+
```
|
120 |
+
|
121 |
+
**Step 2: Create Python 3.10 environment**
|
122 |
+
```bash
|
123 |
+
# Create a new conda environment with Python 3.10
|
124 |
+
conda create -n auto_causal python=3.10
|
125 |
+
conda activate auto_causal
|
126 |
+
pip install -r requirement.txt
|
127 |
+
```
|
128 |
+
|
129 |
+
**Step3: Setup auto_causal library**
|
130 |
+
```bash
|
131 |
+
pip install -e .
|
132 |
+
```
|
133 |
+
|
134 |
+
## Dataset Information
|
135 |
+
|
136 |
+
All datasets used to evaluate CAIs and the baseline models are available in the data/ directory. Specifically:
|
137 |
+
|
138 |
+
* `all_data`: Folder containing all CSV files from the QRData and real-world study collections.
|
139 |
+
* `synthetic_data`: Folder containing all CSV files corresponding to synthetic datasets.
|
140 |
+
* `qr_info.csv`: Metadata for QRData files. For each file, this includes the filename, description, causal query, reference causal effect, intended inference method, and additional remarks.
|
141 |
+
* `real_info.csv`: Metadata for the real-world datasets.
|
142 |
+
* `synthetic_info.csv`: Metadata for the synthetic datasets.
|
143 |
+
|
144 |
+
## Run
|
145 |
+
To execute CAIS, run
|
146 |
+
```python
|
147 |
+
python main/run_cais.py \
|
148 |
+
--metadata_path {path_to_metadata} \
|
149 |
+
--data_dir {path_to_data_folder} \
|
150 |
+
--output_dir {output_folder} \
|
151 |
+
--output_name {output_filename} \
|
152 |
+
--llm_name {llm_name}
|
153 |
+
```
|
154 |
+
Args:
|
155 |
+
|
156 |
+
* metadata_path (str): Path to the CSV file containing the queries, dataset descriptions, and data file names
|
157 |
+
* data_dir (str): Path to the folder containing the data in CSV format
|
158 |
+
* output_dir (str): Path to the folder where the output JSON results will be saved
|
159 |
+
* output_name (str): Name of the JSON file where the outputs will be saved
|
160 |
+
* llm_name (str): Name of the LLM to be used (e.g., 'gpt-4', 'claude-3', etc.)
|
161 |
+
|
162 |
+
A specific example,
|
163 |
+
```python
|
164 |
+
python main/run_cais.py \
|
165 |
+
--metadata_path "data/qr_info.csv" \
|
166 |
+
--data_dir "data/all_data" \
|
167 |
+
--output_dir "output" \
|
168 |
+
--output_name "results_qr_4o" \
|
169 |
+
--llm_name "gpt-4o-mini"
|
170 |
+
```
|
171 |
+
|
172 |
+
|
173 |
+
## Reproducing paper results
|
174 |
+
**Will be updated soon**
|
175 |
+
|
176 |
+
**⚠️ Important Notes:**
|
177 |
+
- Keep your `.env` file secure and never commit it to version control
|
178 |
+
|
179 |
+
## License
|
180 |
+
|
181 |
+
Distributed under the MIT License. See `LICENSE` for more information.
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
<!--## Contributors
|
186 |
+
|
187 |
+
|
188 |
+
|
189 |
+
**Core Contributors**: Vishal Verma, Sawal Acharya, Devansh Bhardwaj
|
190 |
+
|
191 |
+
**Other Contributors**: Zhijing Jin, Ana Hagihat, Samuel Simko
|
192 |
+
|
193 |
+
---
|
194 |
+
|
195 |
+
## Contact
|
196 |
+
|
197 |
+
For additional information, questions, or feedback, please contact ours **[Vishal Verma]([email protected])**, **[Sawal Acharya]([email protected])**, **[Devansh Bhardwaj]([email protected])**. We welcome contributions! Come and join us now!
|
198 |
+
-->
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Causal AI Scientist
|
3 |
+
emoji: 🌍
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.41.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
from pathlib import Path
|
5 |
+
import gradio as gr
|
6 |
+
import time
|
7 |
+
|
8 |
+
# Make your repo importable (expecting a folder named causal-agent at repo root)
|
9 |
+
sys.path.append(str(Path(__file__).parent / "causal-agent"))
|
10 |
+
|
11 |
+
from auto_causal.agent import run_causal_analysis # uses env for provider/model
|
12 |
+
|
13 |
+
# -------- LLM config (OpenAI only; key via HF Secrets) --------
|
14 |
+
os.environ.setdefault("LLM_PROVIDER", "openai")
|
15 |
+
os.environ.setdefault("LLM_MODEL", "gpt-4o")
|
16 |
+
|
17 |
+
# Lazy import to avoid import-time errors if key missing
|
18 |
+
def _get_openai_client():
|
19 |
+
if os.getenv("LLM_PROVIDER", "openai") != "openai":
|
20 |
+
raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
|
21 |
+
if not os.getenv("OPENAI_API_KEY"):
|
22 |
+
raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
|
23 |
+
try:
|
24 |
+
# OpenAI SDK v1+
|
25 |
+
from openai import OpenAI
|
26 |
+
return OpenAI()
|
27 |
+
except Exception as e:
|
28 |
+
raise RuntimeError(f"OpenAI SDK not available: {e}")
|
29 |
+
|
30 |
+
# -------- System prompt you asked for (verbatim) --------
|
31 |
+
SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
|
32 |
+
You will be given:
|
33 |
+
1) The original research question.
|
34 |
+
2) The analysis method used.
|
35 |
+
3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
|
36 |
+
4) A brief dataset description.
|
37 |
+
|
38 |
+
Your task is to produce a clear, concise, and non-technical summary that:
|
39 |
+
- Directly answers the research question.
|
40 |
+
- States whether the effect is statistically significant.
|
41 |
+
- Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
|
42 |
+
- Mentions the method used in one sentence.
|
43 |
+
- Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
|
44 |
+
|
45 |
+
Formatting rules:
|
46 |
+
- Use bullet points or short paragraphs.
|
47 |
+
- Report effect sizes to two decimal places.
|
48 |
+
- Clearly state the interpretation in plain English without technical jargon.
|
49 |
+
|
50 |
+
Example Output Structure:
|
51 |
+
- **Method:** [Name of method + 1-line rationale]
|
52 |
+
- **Key Finding:** [Main answer to the research question]
|
53 |
+
- **Details:**
|
54 |
+
- [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 — [Significance comment]
|
55 |
+
- …
|
56 |
+
- **Rank Order of Effects:** [Largest → Smallest]
|
57 |
+
"""
|
58 |
+
|
59 |
+
def _extract_minimal_payload(agent_result: dict) -> dict:
|
60 |
+
"""
|
61 |
+
Extract the minimal, LLM-friendly payload from run_causal_analysis output.
|
62 |
+
Falls back gracefully if any fields are missing.
|
63 |
+
"""
|
64 |
+
# Try both top-level and nested (your JSON showed both patterns)
|
65 |
+
res = agent_result or {}
|
66 |
+
results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
|
67 |
+
inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
|
68 |
+
vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {}
|
69 |
+
dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}
|
70 |
+
|
71 |
+
# Pull best-available fields
|
72 |
+
question = (
|
73 |
+
results.get("original_query")
|
74 |
+
or dataset_analysis.get("original_query")
|
75 |
+
or res.get("query")
|
76 |
+
or "N/A"
|
77 |
+
)
|
78 |
+
method = (
|
79 |
+
inner.get("method_used")
|
80 |
+
or res.get("method_used")
|
81 |
+
or results.get("method_used")
|
82 |
+
or "N/A"
|
83 |
+
)
|
84 |
+
|
85 |
+
effect_estimate = (
|
86 |
+
inner.get("effect_estimate")
|
87 |
+
or res.get("effect_estimate")
|
88 |
+
or {}
|
89 |
+
)
|
90 |
+
confidence_interval = (
|
91 |
+
inner.get("confidence_interval")
|
92 |
+
or res.get("confidence_interval")
|
93 |
+
or {}
|
94 |
+
)
|
95 |
+
standard_error = (
|
96 |
+
inner.get("standard_error")
|
97 |
+
or res.get("standard_error")
|
98 |
+
or {}
|
99 |
+
)
|
100 |
+
p_value = (
|
101 |
+
inner.get("p_value")
|
102 |
+
or res.get("p_value")
|
103 |
+
or {}
|
104 |
+
)
|
105 |
+
|
106 |
+
dataset_desc = (
|
107 |
+
results.get("dataset_description")
|
108 |
+
or res.get("dataset_description")
|
109 |
+
or "N/A"
|
110 |
+
)
|
111 |
+
|
112 |
+
return {
|
113 |
+
"original_question": question,
|
114 |
+
"method_used": method,
|
115 |
+
"estimates": {
|
116 |
+
"effect_estimate": effect_estimate,
|
117 |
+
"confidence_interval": confidence_interval,
|
118 |
+
"standard_error": standard_error,
|
119 |
+
"p_value": p_value,
|
120 |
+
},
|
121 |
+
"dataset_description": dataset_desc,
|
122 |
+
}
|
123 |
+
|
124 |
+
def _format_effects_md(effect_estimate: dict) -> str:
|
125 |
+
"""
|
126 |
+
Minimal human-readable view of effect estimates for display.
|
127 |
+
"""
|
128 |
+
if not effect_estimate or not isinstance(effect_estimate, dict):
|
129 |
+
return "_No effect estimates found._"
|
130 |
+
# Render as bullet list
|
131 |
+
lines = []
|
132 |
+
for k, v in effect_estimate.items():
|
133 |
+
try:
|
134 |
+
lines.append(f"- **{k}**: {float(v):+.4f}")
|
135 |
+
except Exception:
|
136 |
+
lines.append(f"- **{k}**: {v}")
|
137 |
+
return "\n".join(lines)
|
138 |
+
|
139 |
+
def _summarize_with_llm(payload: dict) -> str:
|
140 |
+
"""
|
141 |
+
Calls OpenAI with the provided SYSTEM_PROMPT and the JSON payload as the user message.
|
142 |
+
Returns the model's text, or raises on error.
|
143 |
+
"""
|
144 |
+
client = _get_openai_client()
|
145 |
+
model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")
|
146 |
+
|
147 |
+
user_content = (
|
148 |
+
"Summarize the following causal analysis results:\n\n"
|
149 |
+
+ json.dumps(payload, indent=2, ensure_ascii=False)
|
150 |
+
)
|
151 |
+
|
152 |
+
# Use Chat Completions for broad compatibility
|
153 |
+
resp = client.chat.completions.create(
|
154 |
+
model=model_name,
|
155 |
+
messages=[
|
156 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
157 |
+
{"role": "user", "content": user_content},
|
158 |
+
],
|
159 |
+
temperature=0
|
160 |
+
)
|
161 |
+
text = resp.choices[0].message.content.strip()
|
162 |
+
return text
|
163 |
+
|
164 |
+
def run_agent(query: str, csv_path: str, dataset_description: str):
|
165 |
+
"""
|
166 |
+
Modified to use yield for progressive updates and immediate feedback
|
167 |
+
"""
|
168 |
+
# Immediate feedback - show processing has started
|
169 |
+
processing_html = """
|
170 |
+
<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
|
171 |
+
<div style='font-size: 16px; margin-bottom: 5px;'>🔄 Analysis in Progress...</div>
|
172 |
+
<div style='font-size: 14px; color: #666;'>This may take 1-2 minutes depending on dataset size</div>
|
173 |
+
</div>
|
174 |
+
"""
|
175 |
+
|
176 |
+
yield (
|
177 |
+
processing_html, # method_out
|
178 |
+
processing_html, # effects_out
|
179 |
+
processing_html, # explanation_out
|
180 |
+
{"status": "Processing started..."} # raw_results
|
181 |
+
)
|
182 |
+
|
183 |
+
# Input validation
|
184 |
+
if not os.getenv("OPENAI_API_KEY"):
|
185 |
+
error_html = "<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>⚠️ Set a Space Secret named OPENAI_API_KEY</div>"
|
186 |
+
yield (error_html, "", "", {})
|
187 |
+
return
|
188 |
+
|
189 |
+
if not csv_path:
|
190 |
+
error_html = "<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>Please upload a CSV dataset.</div>"
|
191 |
+
yield (error_html, "", "", {})
|
192 |
+
return
|
193 |
+
|
194 |
+
try:
|
195 |
+
# Update status to show causal analysis is running
|
196 |
+
analysis_html = """
|
197 |
+
<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
|
198 |
+
<div style='font-size: 16px; margin-bottom: 5px;'>📊 Running Causal Analysis...</div>
|
199 |
+
<div style='font-size: 14px; color: #666;'>Analyzing dataset and selecting optimal method</div>
|
200 |
+
</div>
|
201 |
+
"""
|
202 |
+
|
203 |
+
yield (
|
204 |
+
analysis_html,
|
205 |
+
analysis_html,
|
206 |
+
analysis_html,
|
207 |
+
{"status": "Running causal analysis..."}
|
208 |
+
)
|
209 |
+
|
210 |
+
result = run_causal_analysis(
|
211 |
+
query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
|
212 |
+
dataset_path=csv_path,
|
213 |
+
dataset_description=(dataset_description or "").strip(),
|
214 |
+
)
|
215 |
+
|
216 |
+
# Update to show LLM summarization step
|
217 |
+
llm_html = """
|
218 |
+
<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
|
219 |
+
<div style='font-size: 16px; margin-bottom: 5px;'>🤖 Generating Summary...</div>
|
220 |
+
<div style='font-size: 14px; color: #666;'>Creating human-readable interpretation</div>
|
221 |
+
</div>
|
222 |
+
"""
|
223 |
+
|
224 |
+
yield (
|
225 |
+
llm_html,
|
226 |
+
llm_html,
|
227 |
+
llm_html,
|
228 |
+
{"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}
|
229 |
+
)
|
230 |
+
|
231 |
+
except Exception as e:
|
232 |
+
error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Error: {e}</div>"
|
233 |
+
yield (error_html, "", "", {})
|
234 |
+
return
|
235 |
+
|
236 |
+
try:
|
237 |
+
payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
|
238 |
+
method = payload.get("method_used", "N/A")
|
239 |
+
|
240 |
+
# Format method output with simple styling
|
241 |
+
method_html = f"""
|
242 |
+
<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
|
243 |
+
<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Selected Method</h3>
|
244 |
+
<p style='margin: 0; font-size: 16px;'>{method}</p>
|
245 |
+
</div>
|
246 |
+
"""
|
247 |
+
|
248 |
+
# Format effects with simple styling
|
249 |
+
effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
|
250 |
+
if effect_estimate:
|
251 |
+
effects_html = "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
|
252 |
+
effects_html += "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Effect Estimates</h3>"
|
253 |
+
# for k, v in effect_estimate.items():
|
254 |
+
# try:
|
255 |
+
# value = f"{float(v):+.4f}"
|
256 |
+
# effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #ffffff;'><strong>{k}:</strong> <span style='font-size: 16px;'>{value}</span></div>"
|
257 |
+
# except:
|
258 |
+
effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'>{effect_estimate}</div>"
|
259 |
+
effects_html += "</div>"
|
260 |
+
else:
|
261 |
+
effects_html = "<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; color: #666; font-style: italic; background-color: #333333;'>No effect estimates found</div>"
|
262 |
+
|
263 |
+
# Generate explanation and format it
|
264 |
+
try:
|
265 |
+
explanation = _summarize_with_llm(payload)
|
266 |
+
explanation_html = f"""
|
267 |
+
<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
|
268 |
+
<h3 style='margin: 0 0 15px 0; font-size: 18px;'>Detailed Explanation</h3>
|
269 |
+
<div style='line-height: 1.6; white-space: pre-wrap;'>{explanation}</div>
|
270 |
+
</div>
|
271 |
+
"""
|
272 |
+
except Exception as e:
|
273 |
+
explanation_html = f"<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>⚠️ LLM summary failed: {e}</div>"
|
274 |
+
|
275 |
+
except Exception as e:
|
276 |
+
error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Failed to parse results: {e}</div>"
|
277 |
+
yield (error_html, "", "", {})
|
278 |
+
return
|
279 |
+
|
280 |
+
# Final result
|
281 |
+
yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
|
282 |
+
|
283 |
+
with gr.Blocks() as demo:
|
284 |
+
gr.Markdown("# Causal Agent")
|
285 |
+
gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.")
|
286 |
+
|
287 |
+
with gr.Row():
|
288 |
+
query = gr.Textbox(
|
289 |
+
label="Your causal question (natural language)",
|
290 |
+
placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
|
291 |
+
lines=2,
|
292 |
+
)
|
293 |
+
|
294 |
+
with gr.Row():
|
295 |
+
csv_file = gr.File(
|
296 |
+
label="Dataset (CSV)",
|
297 |
+
file_types=[".csv"],
|
298 |
+
type="filepath"
|
299 |
+
)
|
300 |
+
|
301 |
+
dataset_description = gr.Textbox(
|
302 |
+
label="Dataset description (optional)",
|
303 |
+
placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
|
304 |
+
lines=4,
|
305 |
+
)
|
306 |
+
|
307 |
+
run_btn = gr.Button("Run analysis", variant="primary")
|
308 |
+
|
309 |
+
with gr.Row():
|
310 |
+
with gr.Column(scale=1):
|
311 |
+
method_out = gr.HTML(label="Selected Method")
|
312 |
+
with gr.Column(scale=1):
|
313 |
+
effects_out = gr.HTML(label="Effect Estimates")
|
314 |
+
|
315 |
+
with gr.Row():
|
316 |
+
explanation_out = gr.HTML(label="Detailed Explanation")
|
317 |
+
|
318 |
+
# Add the collapsible raw results section
|
319 |
+
with gr.Accordion("Raw Results (Advanced)", open=False):
|
320 |
+
raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
|
321 |
+
|
322 |
+
run_btn.click(
|
323 |
+
fn=run_agent,
|
324 |
+
inputs=[query, csv_file, dataset_description],
|
325 |
+
outputs=[method_out, effects_out, explanation_out, raw_results],
|
326 |
+
show_progress=True
|
327 |
+
)
|
328 |
+
|
329 |
+
gr.Markdown(
|
330 |
+
"""
|
331 |
+
**Tips:**
|
332 |
+
- Be specific about your treatment, outcome, and control variables
|
333 |
+
- Include relevant context in the dataset description
|
334 |
+
- The analysis may take 1-2 minutes for complex datasets
|
335 |
+
"""
|
336 |
+
)
|
337 |
+
|
338 |
+
if __name__ == "__main__":
|
339 |
+
demo.queue().launch()
|
auto_causal/__init__.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Auto Causal module for causal inference.
|
3 |
+
|
4 |
+
This module provides automated causal inference capabilities
|
5 |
+
through a pipeline that selects and applies appropriate causal methods.
|
6 |
+
"""
|
7 |
+
|
8 |
+
__version__ = "0.1.0"
|
9 |
+
|
10 |
+
# Import components
|
11 |
+
from auto_causal.components import (
|
12 |
+
parse_input,
|
13 |
+
analyze_dataset,
|
14 |
+
interpret_query,
|
15 |
+
validate_method,
|
16 |
+
generate_explanation,
|
17 |
+
format_output,
|
18 |
+
create_workflow_state_update
|
19 |
+
)
|
20 |
+
|
21 |
+
# Import tools
|
22 |
+
from auto_causal.tools import (
|
23 |
+
input_parser_tool,
|
24 |
+
dataset_analyzer_tool,
|
25 |
+
query_interpreter_tool,
|
26 |
+
method_selector_tool,
|
27 |
+
method_validator_tool,
|
28 |
+
method_executor_tool,
|
29 |
+
explanation_generator_tool,
|
30 |
+
output_formatter_tool
|
31 |
+
)
|
32 |
+
|
33 |
+
# Import the main agent function
|
34 |
+
from .agent import run_causal_analysis
|
35 |
+
|
36 |
+
# Remove backward compatibility for old pipeline
|
37 |
+
# try:
|
38 |
+
# from .pipeline import CausalInferencePipeline
|
39 |
+
# except ImportError:
|
40 |
+
# # Define a placeholder class if the old pipeline doesn't exist
|
41 |
+
# class CausalInferencePipeline:
|
42 |
+
# """Placeholder for CausalInferencePipeline."""
|
43 |
+
#
|
44 |
+
# def __init__(self, *args, **kwargs):
|
45 |
+
# pass
|
46 |
+
|
47 |
+
# Update __all__ to export the main function
|
48 |
+
__all__ = [
|
49 |
+
'run_causal_analysis'
|
50 |
+
]
|
auto_causal/agent.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LangChain agent for the auto_causal module.
|
3 |
+
|
4 |
+
This module configures a LangChain agent with specialized tools for causal inference,
|
5 |
+
allowing for an interactive approach to analyzing datasets and applying appropriate
|
6 |
+
causal inference methods.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
from typing import Dict, List, Any, Optional
|
11 |
+
from langchain.agents.react.agent import create_react_agent
|
12 |
+
from langchain.agents import AgentExecutor, create_structured_chat_agent, create_tool_calling_agent
|
13 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
14 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
15 |
+
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
16 |
+
from langchain.tools import tool
|
17 |
+
# Import the callback handler
|
18 |
+
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
19 |
+
# Import tool rendering utility
|
20 |
+
from langchain.tools.render import render_text_description
|
21 |
+
# Import LCEL components
|
22 |
+
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
|
23 |
+
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
|
24 |
+
from langchain_core.runnables import RunnablePassthrough
|
25 |
+
from langchain_core.language_models import BaseChatModel
|
26 |
+
from langchain_anthropic.chat_models import convert_to_anthropic_tool
|
27 |
+
import os
|
28 |
+
# Import actual tools from the tools directory
|
29 |
+
from auto_causal.tools.input_parser_tool import input_parser_tool
|
30 |
+
from auto_causal.tools.dataset_analyzer_tool import dataset_analyzer_tool
|
31 |
+
from auto_causal.tools.query_interpreter_tool import query_interpreter_tool
|
32 |
+
from auto_causal.tools.method_selector_tool import method_selector_tool
|
33 |
+
from auto_causal.tools.method_validator_tool import method_validator_tool
|
34 |
+
from auto_causal.tools.method_executor_tool import method_executor_tool
|
35 |
+
from auto_causal.tools.explanation_generator_tool import explanation_generator_tool
|
36 |
+
from auto_causal.tools.output_formatter_tool import output_formatter_tool
|
37 |
+
#from auto_causal.prompts import SYSTEM_PROMPT # Assuming SYSTEM_PROMPT is defined here or imported
|
38 |
+
from langchain_core.output_parsers import StrOutputParser
|
39 |
+
# Import the centralized factory function
|
40 |
+
from .config import get_llm_client
|
41 |
+
#from .prompts import SYSTEM_PROMPT
|
42 |
+
from langchain_core.messages import AIMessage, AIMessageChunk
|
43 |
+
import re
|
44 |
+
import json
|
45 |
+
from typing import Union
|
46 |
+
from langchain_core.output_parsers import BaseOutputParser
|
47 |
+
from langchain.schema import AgentAction, AgentFinish
|
48 |
+
from langchain_anthropic.output_parsers import ToolsOutputParser
|
49 |
+
from langchain.agents.react.output_parser import ReActOutputParser
|
50 |
+
from langchain.agents import AgentOutputParser
|
51 |
+
from langchain.agents.agent import AgentAction, AgentFinish, OutputParserException
|
52 |
+
import re
|
53 |
+
from typing import Union, List
|
54 |
+
from auto_causal.models import *
|
55 |
+
|
56 |
+
from langchain_core.agents import AgentAction, AgentFinish
|
57 |
+
from langchain_core.exceptions import OutputParserException
|
58 |
+
|
59 |
+
from langchain.agents.agent import AgentOutputParser
|
60 |
+
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
61 |
+
|
62 |
+
FINAL_ANSWER_ACTION = "Final Answer:"
|
63 |
+
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
|
64 |
+
"Invalid Format: Missing 'Action:' after 'Thought:'"
|
65 |
+
)
|
66 |
+
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
|
67 |
+
"Invalid Format: Missing 'Action Input:' after 'Action:'"
|
68 |
+
)
|
69 |
+
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
|
70 |
+
"Parsing LLM output produced both a final answer and parse-able actions"
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
class ReActMultiInputOutputParser(AgentOutputParser):
|
75 |
+
"""Parses ReAct-style output that may contain multiple tool calls."""
|
76 |
+
|
77 |
+
def get_format_instructions(self) -> str:
|
78 |
+
# You can reuse the original FORMAT_INSTRUCTIONS,
|
79 |
+
# but let the model know it may emit multiple actions.
|
80 |
+
return FORMAT_INSTRUCTIONS + (
|
81 |
+
"\n\nIf you need to call more than one tool, simply repeat:\n"
|
82 |
+
"Action: <tool_name>\n"
|
83 |
+
"Action Input: <json or text>\n"
|
84 |
+
"…for each tool in sequence."
|
85 |
+
)
|
86 |
+
|
87 |
+
@property
|
88 |
+
def _type(self) -> str:
|
89 |
+
return "react-multi-input"
|
90 |
+
|
91 |
+
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
92 |
+
includes_answer = FINAL_ANSWER_ACTION in text
|
93 |
+
print('-------------------')
|
94 |
+
print(text)
|
95 |
+
print('-------------------')
|
96 |
+
# Grab every Action / Action Input block
|
97 |
+
pattern = (
|
98 |
+
r"Action\s*\d*\s*:[\s]*(.*?)\s*"
|
99 |
+
r"Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*?)(?=(?:Action\s*\d*\s*:|$))"
|
100 |
+
)
|
101 |
+
matches = list(re.finditer(pattern, text, re.DOTALL))
|
102 |
+
|
103 |
+
# If we found tool calls…
|
104 |
+
if matches:
|
105 |
+
if includes_answer:
|
106 |
+
# both a final answer *and* tool calls is ambiguous
|
107 |
+
raise OutputParserException(
|
108 |
+
f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
|
109 |
+
)
|
110 |
+
|
111 |
+
actions: List[AgentAction] = []
|
112 |
+
for m in matches:
|
113 |
+
tool_name = m.group(1).strip()
|
114 |
+
tool_input = m.group(2).strip().strip('"')
|
115 |
+
print('\n--------------------------')
|
116 |
+
print(tool_input)
|
117 |
+
print('--------------------------')
|
118 |
+
actions.append(AgentAction(tool_name, json.loads(tool_input), text))
|
119 |
+
|
120 |
+
return actions
|
121 |
+
|
122 |
+
# Otherwise, if there's a final answer, finish
|
123 |
+
if includes_answer:
|
124 |
+
answer = text.split(FINAL_ANSWER_ACTION, 1)[1].strip()
|
125 |
+
return AgentFinish({"output": answer}, text)
|
126 |
+
|
127 |
+
# No calls and no final answer → figure out which error to throw
|
128 |
+
if not re.search(r"Action\s*\d*\s*Input\s*\d*:", text):
|
129 |
+
raise OutputParserException(
|
130 |
+
f"Could not parse LLM output: `{text}`",
|
131 |
+
observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
132 |
+
llm_output=text,
|
133 |
+
send_to_llm=True,
|
134 |
+
)
|
135 |
+
|
136 |
+
# Fallback
|
137 |
+
raise OutputParserException(f"Could not parse LLM output: `{text}`")
|
138 |
+
|
139 |
+
# Set up basic logging
|
140 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
141 |
+
logger = logging.getLogger(__name__)
|
142 |
+
|
143 |
+
|
144 |
+
def create_agent_prompt(tools: List[tool]) -> ChatPromptTemplate:
|
145 |
+
"""Create the prompt template for the causal inference agent, emphasizing workflow and data handoff.
|
146 |
+
(This is the version required by the LCEL agent structure below)
|
147 |
+
"""
|
148 |
+
# Get the tool descriptions
|
149 |
+
tool_description = render_text_description(tools)
|
150 |
+
tool_names = ", ".join([t.name for t in tools])
|
151 |
+
|
152 |
+
# Define the system prompt template string
|
153 |
+
system_template = """
|
154 |
+
You are a causal inference expert helping users answer causal questions by following a strict workflow using specialized tools.
|
155 |
+
|
156 |
+
Remember you always have to always generate the Thought, Action and Action Input block.
|
157 |
+
TOOLS:
|
158 |
+
------
|
159 |
+
You have access to the following tools:
|
160 |
+
|
161 |
+
{tools}
|
162 |
+
|
163 |
+
To use a tool, please use the following format:
|
164 |
+
|
165 |
+
Thought: Do I need to use a tool? Yes
|
166 |
+
Action: the action to take, should be one of [{tool_names}]
|
167 |
+
Action Input: the input to the action, as a single, valid JSON object string. Check the tool definition for required arguments and structure.
|
168 |
+
Observation: the result of the action, often containing structured data like 'variables', 'dataset_analysis', 'method_info', etc.
|
169 |
+
|
170 |
+
When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
|
171 |
+
|
172 |
+
Thought: Do I need to use a tool? No
|
173 |
+
Final Answer: [your response here]
|
174 |
+
|
175 |
+
DO NOT UNDER ANY CIRCUMSTANCE CALL MORE THAN ONE TOOL IN A STEP
|
176 |
+
|
177 |
+
**IMPORTANT TOOL USAGE:**
|
178 |
+
1. **Action Input Format:** The value for 'Action Input' MUST be a single, valid JSON object string. Do NOT include any other text or formatting around the JSON string.
|
179 |
+
2. **Argument Gathering:** You MUST gather ALL required arguments for the Action Input JSON from the initial Human input AND the 'Observation' outputs of PREVIOUS steps. Look carefully at the required arguments for the tool you are calling.
|
180 |
+
3. **Data Handoff:** The 'Observation' from a previous step often contains structured data needed by the next tool. For example, the 'variables' output from `query_interpreter_tool` contains fields like `treatment_variable`, `outcome_variable`, `covariates`, `time_variable`, `instrument_variable`, `running_variable`, `cutoff_value`, and `is_rct`. When calling `method_selector_tool`, you MUST construct its required `variables` input argument by including **ALL** these relevant fields identified by the `query_interpreter_tool` in the previous Observation. Similarly, pass the full `dataset_analysis`, `dataset_description`, and `original_query` when required by the next tool.
|
181 |
+
|
182 |
+
IMPORTANT WORKFLOW:
|
183 |
+
-------------------
|
184 |
+
You must follow this exact workflow, selecting the appropriate tool for each step:
|
185 |
+
|
186 |
+
1. ALWAYS start with `input_parser_tool` to understand the query
|
187 |
+
2. THEN use `dataset_analyzer_tool` to analyze the dataset
|
188 |
+
3. THEN use `query_interpreter_tool` to identify variables (output includes `variables` and `dataset_analysis`)
|
189 |
+
4. THEN use `method_selector_tool` (input requires `variables` and `dataset_analysis` from previous step)
|
190 |
+
5. THEN use `method_validator_tool` (input requires `method_info` and `variables` from previous step)
|
191 |
+
6. THEN use `method_executor_tool` (input requires `method`, `variables`, `dataset_path`)
|
192 |
+
7. THEN use `explanation_generator_tool` (input requires results, method_info, variables, etc.)
|
193 |
+
8. FINALLY use `output_formatter_tool` to return the results
|
194 |
+
|
195 |
+
REASONING PROCESS:
|
196 |
+
------------------
|
197 |
+
EXPLICITLY REASON about:
|
198 |
+
1. What step you're currently on (based on previous tool's Observation)
|
199 |
+
2. Why you're selecting a particular tool (should follow the workflow)
|
200 |
+
3. How the output of the previous tool (especially structured data like `variables`, `dataset_analysis`, `method_info`) informs the inputs required for the current tool.
|
201 |
+
|
202 |
+
IMPORTANT RULES:
|
203 |
+
1. Do not make more than one tool call in a single step.
|
204 |
+
2. Do not include ``` in your output at all.
|
205 |
+
3. Don't use action names like default_api.dataset_analyzer_tool, instead use tool names like dataset_analyzer_tool.
|
206 |
+
4. Always start, action, and observation with a new line.
|
207 |
+
5. Don't use '\\' before double quotes
|
208 |
+
6. Don't include ```json for Action Input. Also ensure that Action Input is a valid json. DO no add any text after Action Iput.
|
209 |
+
7. You have to always choose one of the tools unless it's the final answer.
|
210 |
+
Begin!
|
211 |
+
"""
|
212 |
+
|
213 |
+
# Create the prompt template
|
214 |
+
prompt = ChatPromptTemplate.from_messages([
|
215 |
+
("system", system_template),
|
216 |
+
MessagesPlaceholder("chat_history", optional=True), # Use MessagesPlaceholder
|
217 |
+
# MessagesPlaceholder("agent_scratchpad"),
|
218 |
+
|
219 |
+
("human", "{input}\n Thought:{agent_scratchpad}"),
|
220 |
+
# ("ai", "{agent_scratchpad}"),
|
221 |
+
# MessagesPlaceholder("agent_scratchpad" ), # Use MessagesPlaceholder
|
222 |
+
# "agent_scratchpad"
|
223 |
+
])
|
224 |
+
return prompt
|
225 |
+
|
226 |
+
def create_causal_agent(llm: BaseChatModel) -> AgentExecutor:
|
227 |
+
"""
|
228 |
+
Create and configure the LangChain agent with causal inference tools.
|
229 |
+
(Using explicit LCEL construction, compatible with shared LLM client)
|
230 |
+
"""
|
231 |
+
# Define tools available to the agent
|
232 |
+
agent_tools = [
|
233 |
+
input_parser_tool,
|
234 |
+
dataset_analyzer_tool,
|
235 |
+
query_interpreter_tool,
|
236 |
+
method_selector_tool,
|
237 |
+
method_validator_tool,
|
238 |
+
method_executor_tool,
|
239 |
+
explanation_generator_tool,
|
240 |
+
output_formatter_tool
|
241 |
+
]
|
242 |
+
# anthropic_agent_tools = [ convert_to_anthropic_tool(anthropic_tool) for anthropic_tool in agent_tools]
|
243 |
+
# Create the prompt using the helper
|
244 |
+
prompt = create_agent_prompt(agent_tools)
|
245 |
+
# Bind tools to the LLM (using the passed shared instance)
|
246 |
+
|
247 |
+
|
248 |
+
# Create memory
|
249 |
+
# Consider if memory needs to be passed in or created here
|
250 |
+
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
251 |
+
|
252 |
+
# Manually construct the agent runnable using LCEL
|
253 |
+
from langchain_anthropic.output_parsers import ToolsOutputParser
|
254 |
+
from langchain.agents.output_parsers.json import JSONAgentOutputParser
|
255 |
+
# from langchain.agents.react.output_parser import MultiActionAgentOutputParsers ReActMultiInputOutputParser
|
256 |
+
provider = os.getenv("LLM_PROVIDER", "openai")
|
257 |
+
if provider == "gemini":
|
258 |
+
base_parser=ReActMultiInputOutputParser()
|
259 |
+
llm_with_tools = llm.bind_tools(agent_tools)
|
260 |
+
else:
|
261 |
+
base_parser=ToolsAgentOutputParser()
|
262 |
+
llm_with_tools = llm.bind_tools(agent_tools, tool_choice="any")
|
263 |
+
agent = create_react_agent(llm_with_tools, agent_tools, prompt, output_parser=base_parser)
|
264 |
+
|
265 |
+
|
266 |
+
# Create executor (should now work with the manually constructed agent)
|
267 |
+
executor = AgentExecutor(
|
268 |
+
agent=agent,
|
269 |
+
tools=agent_tools,
|
270 |
+
memory=memory, # Pass the memory object
|
271 |
+
verbose=True,
|
272 |
+
callbacks=[ConsoleCallbackHandler()], # Optional: for console debugging
|
273 |
+
handle_parsing_errors=True, # Let AE handle parsing errors
|
274 |
+
max_retries = 100
|
275 |
+
)
|
276 |
+
|
277 |
+
return executor
|
278 |
+
|
279 |
+
def run_causal_analysis(query: str, dataset_path: str,
|
280 |
+
dataset_description: Optional[str] = None,
|
281 |
+
api_key: Optional[str] = None) -> Dict[str, Any]:
|
282 |
+
"""
|
283 |
+
Run causal analysis on a dataset based on a user query.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
query: User's causal question
|
287 |
+
dataset_path: Path to the dataset
|
288 |
+
dataset_description: Optional textual description of the dataset
|
289 |
+
api_key: Optional OpenAI API key (DEPRECATED - will be ignored)
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
Dictionary containing the final formatted analysis results from the agent's last step.
|
293 |
+
"""
|
294 |
+
# Log the start of the analysis
|
295 |
+
logger.info("Starting causal analysis run...")
|
296 |
+
|
297 |
+
try:
|
298 |
+
# --- Instantiate the shared LLM client ---
|
299 |
+
model_name = os.getenv("LLM_MODEL", "gpt-4")
|
300 |
+
if model_name in ['o3', 'o4-mini', 'o3-mini']:
|
301 |
+
print('-------------------------')
|
302 |
+
shared_llm = get_llm_client()
|
303 |
+
else:
|
304 |
+
shared_llm = get_llm_client(temperature=0) # Or read provider/model from env
|
305 |
+
|
306 |
+
# --- Dependency Injection Note (REMAINS RELEVANT) ---
|
307 |
+
# If tools need the LLM, they must be adapted. Example using partial:
|
308 |
+
# from functools import partial
|
309 |
+
# from .components import input_parser
|
310 |
+
# # Assume input_parser.parse_input needs llm
|
311 |
+
# input_parser_tool_with_llm = tool(partial(input_parser.parse_input, llm=shared_llm))
|
312 |
+
# Use input_parser_tool_with_llm in the tools list passed to the agent below.
|
313 |
+
# Similar adjustments needed for decision_tree._recommend_ps_method if used.
|
314 |
+
# --- End Note ---
|
315 |
+
|
316 |
+
# --- Create agent using the shared LLM ---
|
317 |
+
# agent_executor = create_causal_agent(shared_llm)
|
318 |
+
|
319 |
+
# Construct input, including description if available
|
320 |
+
# IMPORTANT: Agent now expects 'input' and potentially 'chat_history'
|
321 |
+
# The input needs to contain all initial info the first tool might need.
|
322 |
+
input_text = f"My question is: {query}\n"
|
323 |
+
input_text += f"The dataset is located at: {dataset_path}\n"
|
324 |
+
if dataset_description:
|
325 |
+
input_text += f"Dataset Description: {dataset_description}\n"
|
326 |
+
input_text += "Please perform the causal analysis following the workflow."
|
327 |
+
|
328 |
+
# Log the constructed input text
|
329 |
+
logger.info(f"Constructed input for agent: \n{input_text}")
|
330 |
+
|
331 |
+
input_parsing_result = input_parser_tool(input_text)
|
332 |
+
dataset_analysis_result = dataset_analyzer_tool.func(dataset_path=input_parsing_result["dataset_path"], dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).analysis_results
|
333 |
+
query_info = QueryInfo(
|
334 |
+
query_text=input_parsing_result["original_query"],
|
335 |
+
potential_treatments=input_parsing_result["extracted_variables"].get("treatment"),
|
336 |
+
potential_outcomes=input_parsing_result["extracted_variables"].get("outcome"),
|
337 |
+
covariates_hints=input_parsing_result["extracted_variables"].get("covariates_mentioned"),
|
338 |
+
instrument_hints=input_parsing_result["extracted_variables"].get("instruments_mentioned")
|
339 |
+
)
|
340 |
+
|
341 |
+
query_interpreter_output = query_interpreter_tool.func(query_info=query_info, dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query = input_parsing_result["original_query"]).variables
|
342 |
+
method_selector_output = method_selector_tool.func(variables=query_interpreter_output,
|
343 |
+
dataset_analysis=dataset_analysis_result,
|
344 |
+
dataset_description=input_parsing_result["dataset_description"],
|
345 |
+
original_query = input_parsing_result["original_query"],
|
346 |
+
excluded_methods=None)
|
347 |
+
method_info = MethodInfo(
|
348 |
+
**method_selector_output['method_info']
|
349 |
+
)
|
350 |
+
method_validator_input = MethodValidatorInput(
|
351 |
+
method_info=method_info,
|
352 |
+
variables=query_interpreter_output,
|
353 |
+
dataset_analysis=dataset_analysis_result,
|
354 |
+
dataset_description=input_parsing_result["dataset_description"],
|
355 |
+
original_query = input_parsing_result["original_query"]
|
356 |
+
)
|
357 |
+
method_validator_output = method_validator_tool.func(method_validator_input)
|
358 |
+
method_executor_input = MethodExecutorInput(
|
359 |
+
**method_validator_output
|
360 |
+
)
|
361 |
+
method_executor_output = method_executor_tool.func(method_executor_input, original_query = input_parsing_result["original_query"])
|
362 |
+
|
363 |
+
explainer_output = explanation_generator_tool.func( method_info=method_info,
|
364 |
+
validation_info=method_validator_output,
|
365 |
+
variables=query_interpreter_output,
|
366 |
+
results=method_executor_output,
|
367 |
+
dataset_analysis=dataset_analysis_result,
|
368 |
+
dataset_description=input_parsing_result["dataset_description"],
|
369 |
+
original_query = input_parsing_result["original_query"])
|
370 |
+
result = explainer_output
|
371 |
+
result['results']['results']["method_used"] = method_validator_output['method']
|
372 |
+
logger.info(result)
|
373 |
+
logger.info("Causal analysis run finished.")
|
374 |
+
|
375 |
+
# Ensure result is a dict and extract the 'output' part
|
376 |
+
if isinstance(result, dict):
|
377 |
+
final_output = result
|
378 |
+
if isinstance(final_output, dict):
|
379 |
+
return final_output # Return only the dictionary from the final tool
|
380 |
+
else:
|
381 |
+
logger.error(f"Agent result['output'] was not a dictionary: {type(final_output)}. Returning error dict.")
|
382 |
+
return {"error": "Agent did not produce the expected dictionary output in the 'output' key.", "raw_agent_result": result}
|
383 |
+
else:
|
384 |
+
logger.error(f"Agent returned non-dict type: {type(result)}. Returning error dict.")
|
385 |
+
return {"error": "Agent did not return expected dictionary output.", "raw_output": str(result)}
|
386 |
+
|
387 |
+
except ValueError as e:
|
388 |
+
logger.error(f"Configuration Error: {e}")
|
389 |
+
# Return an error dictionary in case of exception too
|
390 |
+
return {"error": f"Error: Configuration issue - {e}"} # Ensure consistent error return type
|
391 |
+
except Exception as e:
|
392 |
+
logger.error(f"An unexpected error occurred during causal analysis: {e}", exc_info=True)
|
393 |
+
# Return an error dictionary in case of exception too
|
394 |
+
return {"error": f"An unexpected error occurred: {e}"}
|
auto_causal/components/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Auto Causal components package.
|
3 |
+
|
4 |
+
This package contains the core components for the auto_causal module,
|
5 |
+
each handling a specific part of the causal inference workflow.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from auto_causal.components.input_parser import parse_input
|
9 |
+
from auto_causal.components.dataset_analyzer import analyze_dataset
|
10 |
+
from auto_causal.components.query_interpreter import interpret_query
|
11 |
+
from auto_causal.components.decision_tree import select_method
|
12 |
+
from auto_causal.components.method_validator import validate_method
|
13 |
+
from auto_causal.components.explanation_generator import generate_explanation
|
14 |
+
from auto_causal.components.output_formatter import format_output
|
15 |
+
from auto_causal.components.state_manager import create_workflow_state_update
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"parse_input",
|
19 |
+
"analyze_dataset",
|
20 |
+
"interpret_query",
|
21 |
+
"select_method",
|
22 |
+
"validate_method",
|
23 |
+
"generate_explanation",
|
24 |
+
"format_output",
|
25 |
+
"create_workflow_state_update"
|
26 |
+
]
|
27 |
+
|
28 |
+
# This file makes Python treat the directory as a package.
|
auto_causal/components/dataset_analyzer.py
ADDED
@@ -0,0 +1,853 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Dataset analyzer component for causal inference.
|
3 |
+
|
4 |
+
This module provides functionality to analyze datasets to detect characteristics
|
5 |
+
relevant for causal inference methods, including temporal structure, potential
|
6 |
+
instrumental variables, discontinuities, and variable relationships.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import os
|
10 |
+
import pandas as pd
|
11 |
+
import numpy as np
|
12 |
+
from typing import Dict, List, Any, Optional, Tuple
|
13 |
+
from scipy import stats
|
14 |
+
import logging
|
15 |
+
import json
|
16 |
+
from langchain_core.language_models import BaseChatModel
|
17 |
+
from auto_causal.utils.llm_helpers import llm_identify_temporal_and_unit_vars
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
def _calculate_per_group_stats(df: pd.DataFrame, potential_treatments: List[str]) -> Dict[str, Dict]:
|
22 |
+
"""Calculates summary stats for numeric covariates grouped by potential binary treatments."""
|
23 |
+
stats_dict = {}
|
24 |
+
numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
|
25 |
+
|
26 |
+
for treat_var in potential_treatments:
|
27 |
+
if treat_var not in df.columns:
|
28 |
+
logger.warning(f"Potential treatment '{treat_var}' not found in DataFrame columns.")
|
29 |
+
continue
|
30 |
+
|
31 |
+
# Ensure treatment is binary (0/1 or similar)
|
32 |
+
unique_vals = df[treat_var].dropna().unique()
|
33 |
+
if len(unique_vals) != 2:
|
34 |
+
logger.info(f"Skipping stats for potential treatment '{treat_var}' as it is not binary ({len(unique_vals)} unique values).")
|
35 |
+
continue
|
36 |
+
|
37 |
+
# Attempt to map values to 0 and 1 if possible
|
38 |
+
try:
|
39 |
+
# Ensure boolean is converted to int
|
40 |
+
if df[treat_var].dtype == 'bool':
|
41 |
+
df[treat_var] = df[treat_var].astype(int)
|
42 |
+
unique_vals = df[treat_var].dropna().unique()
|
43 |
+
|
44 |
+
# Basic check if values are interpretable as 0/1
|
45 |
+
if not set(unique_vals).issubset({0, 1}):
|
46 |
+
# Attempt conversion if possible (e.g., True/False strings?)
|
47 |
+
logger.warning(f"Potential treatment '{treat_var}' has values {unique_vals}, not {0, 1}. Cannot calculate group stats reliably.")
|
48 |
+
continue
|
49 |
+
except Exception as e:
|
50 |
+
logger.warning(f"Could not process potential treatment '{treat_var}' values ({unique_vals}): {e}")
|
51 |
+
continue
|
52 |
+
|
53 |
+
logger.info(f"Calculating group stats for treatment: '{treat_var}'")
|
54 |
+
treat_stats = {'group_sizes': {}, 'covariate_stats': {}}
|
55 |
+
|
56 |
+
try:
|
57 |
+
grouped = df.groupby(treat_var)
|
58 |
+
sizes = grouped.size()
|
59 |
+
treat_stats['group_sizes']['treated'] = int(sizes.get(1, 0))
|
60 |
+
treat_stats['group_sizes']['control'] = int(sizes.get(0, 0))
|
61 |
+
|
62 |
+
if treat_stats['group_sizes']['treated'] == 0 or treat_stats['group_sizes']['control'] == 0:
|
63 |
+
logger.warning(f"Treatment '{treat_var}' has zero samples in one group. Skipping covariate stats.")
|
64 |
+
stats_dict[treat_var] = treat_stats
|
65 |
+
continue
|
66 |
+
|
67 |
+
# Calculate mean and std for numeric covariates
|
68 |
+
cov_stats = grouped[numeric_cols].agg(['mean', 'std']).unstack()
|
69 |
+
|
70 |
+
for cov in numeric_cols:
|
71 |
+
if cov == treat_var: continue # Skip treatment variable itself
|
72 |
+
|
73 |
+
mean_control = cov_stats.get(('mean', 0, cov), np.nan)
|
74 |
+
std_control = cov_stats.get(('std', 0, cov), np.nan)
|
75 |
+
mean_treated = cov_stats.get(('mean', 1, cov), np.nan)
|
76 |
+
std_treated = cov_stats.get(('std', 1, cov), np.nan)
|
77 |
+
|
78 |
+
treat_stats['covariate_stats'][cov] = {
|
79 |
+
'mean_control': float(mean_control) if pd.notna(mean_control) else None,
|
80 |
+
'std_control': float(std_control) if pd.notna(std_control) else None,
|
81 |
+
'mean_treat': float(mean_treated) if pd.notna(mean_treated) else None,
|
82 |
+
'std_treat': float(std_treated) if pd.notna(std_treated) else None,
|
83 |
+
}
|
84 |
+
stats_dict[treat_var] = treat_stats
|
85 |
+
except Exception as e:
|
86 |
+
logger.error(f"Error calculating stats for treatment '{treat_var}': {e}", exc_info=True)
|
87 |
+
# Store partial info if possible
|
88 |
+
if treat_var not in stats_dict:
|
89 |
+
stats_dict[treat_var] = {'error': str(e)}
|
90 |
+
elif 'error' not in stats_dict[treat_var]:
|
91 |
+
stats_dict[treat_var]['error'] = str(e)
|
92 |
+
|
93 |
+
return stats_dict
|
94 |
+
|
95 |
+
def analyze_dataset(
|
96 |
+
dataset_path: str,
|
97 |
+
llm_client: Optional[BaseChatModel] = None,
|
98 |
+
dataset_description: Optional[str] = None,
|
99 |
+
original_query: Optional[str] = None
|
100 |
+
) -> Dict[str, Any]:
|
101 |
+
"""
|
102 |
+
Analyze a dataset to identify important characteristics for causal inference.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
dataset_path: Path to the dataset file
|
106 |
+
llm_client: Optional LLM client for enhanced analysis
|
107 |
+
dataset_description: Optional description of the dataset for context
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
Dict containing dataset analysis results:
|
111 |
+
- dataset_info: Basic information about the dataset
|
112 |
+
- columns: List of column names
|
113 |
+
- potential_treatments: List of potential treatment variables (possibly LLM augmented)
|
114 |
+
- potential_outcomes: List of potential outcome variables (possibly LLM augmented)
|
115 |
+
- temporal_structure_detected: Whether temporal structure was detected
|
116 |
+
- panel_data_detected: Whether panel data structure was detected
|
117 |
+
- potential_instruments_detected: Whether potential instruments were detected
|
118 |
+
- discontinuities_detected: Whether discontinuities were detected
|
119 |
+
- llm_augmentation: Status of LLM augmentation if used
|
120 |
+
"""
|
121 |
+
llm_augmentation = "Not used" if not llm_client else "Initialized"
|
122 |
+
|
123 |
+
# Check if file exists
|
124 |
+
if not os.path.exists(dataset_path):
|
125 |
+
logger.error(f"Dataset file not found at {dataset_path}")
|
126 |
+
return {"error": f"Dataset file not found at {dataset_path}"}
|
127 |
+
|
128 |
+
try:
|
129 |
+
# Load the dataset
|
130 |
+
df = pd.read_csv(dataset_path)
|
131 |
+
|
132 |
+
# Basic dataset information
|
133 |
+
sample_size = len(df)
|
134 |
+
columns_list = df.columns.tolist()
|
135 |
+
num_covariates = len(columns_list) - 2 # Rough estimate (total - T - Y)
|
136 |
+
dataset_info = {
|
137 |
+
"num_rows": sample_size,
|
138 |
+
"num_columns": len(columns_list),
|
139 |
+
"file_path": dataset_path,
|
140 |
+
"file_name": os.path.basename(dataset_path)
|
141 |
+
}
|
142 |
+
|
143 |
+
# --- Detailed Analysis (Keep internal) ---
|
144 |
+
column_types_detailed = {col: str(df[col].dtype) for col in df.columns}
|
145 |
+
missing_values_detailed = df.isnull().sum().to_dict()
|
146 |
+
column_categories_detailed = _categorize_columns(df)
|
147 |
+
column_nunique_counts_detailed = {col: df[col].nunique() for col in df.columns} # Calculate nunique
|
148 |
+
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
149 |
+
correlations_detailed = df[numeric_cols].corr() if numeric_cols else pd.DataFrame()
|
150 |
+
temporal_structure_detailed = detect_temporal_structure(df, llm_client, dataset_description, original_query)
|
151 |
+
|
152 |
+
# First, identify potential treatment and outcome variables
|
153 |
+
potential_variables = _identify_potential_variables(
|
154 |
+
df,
|
155 |
+
column_categories_detailed,
|
156 |
+
llm_client=llm_client,
|
157 |
+
dataset_description=dataset_description
|
158 |
+
)
|
159 |
+
|
160 |
+
if llm_client:
|
161 |
+
llm_augmentation = "Used for variable identification"
|
162 |
+
|
163 |
+
# Then use that info to help find potential instrumental variables
|
164 |
+
potential_instruments_detailed = find_potential_instruments(
|
165 |
+
df,
|
166 |
+
llm_client=llm_client,
|
167 |
+
potential_treatments=potential_variables.get("potential_treatments", []),
|
168 |
+
potential_outcomes=potential_variables.get("potential_outcomes", []),
|
169 |
+
dataset_description=dataset_description
|
170 |
+
)
|
171 |
+
|
172 |
+
# Other analyses
|
173 |
+
discontinuities_detailed = detect_discontinuities(df)
|
174 |
+
variable_relationships_detailed = assess_variable_relationships(df, correlations_detailed)
|
175 |
+
|
176 |
+
# Calculate per-group stats for potential binary treatments
|
177 |
+
potential_binary_treatments = [
|
178 |
+
t for t in potential_variables["potential_treatments"]
|
179 |
+
if column_categories_detailed.get(t) == 'binary'
|
180 |
+
or column_categories_detailed.get(t) == 'binary_categorical'
|
181 |
+
]
|
182 |
+
per_group_stats = _calculate_per_group_stats(df.copy(), potential_binary_treatments)
|
183 |
+
|
184 |
+
# --- Summarized Analysis (For Output) ---
|
185 |
+
|
186 |
+
# Get boolean flags and essential lists
|
187 |
+
has_temporal = temporal_structure_detailed.get("has_temporal_structure", False)
|
188 |
+
is_panel = temporal_structure_detailed.get("is_panel_data", False)
|
189 |
+
logger.info(f"iv is {potential_instruments_detailed}")
|
190 |
+
has_instruments = len(potential_instruments_detailed) > 0
|
191 |
+
has_discontinuities = discontinuities_detailed.get("has_discontinuities", False)
|
192 |
+
|
193 |
+
# --- Extract only instrument names for the final output ---
|
194 |
+
potential_instrument_names = [
|
195 |
+
inst_dict.get('variable')
|
196 |
+
for inst_dict in potential_instruments_detailed
|
197 |
+
if isinstance(inst_dict, dict) and 'variable' in inst_dict
|
198 |
+
]
|
199 |
+
logger.info(f"iv is {potential_instrument_names}")
|
200 |
+
# --- Final Output Dictionary (Highly Summarized) ---
|
201 |
+
return {
|
202 |
+
"dataset_info": dataset_info, # Keep basic info
|
203 |
+
"columns": columns_list,
|
204 |
+
"potential_treatments": potential_variables["potential_treatments"],
|
205 |
+
"potential_outcomes": potential_variables["potential_outcomes"],
|
206 |
+
# Return concise flags instead of detailed dicts/lists
|
207 |
+
"temporal_structure_detected": has_temporal,
|
208 |
+
"panel_data_detected": is_panel,
|
209 |
+
"potential_instruments_detected": has_instruments,
|
210 |
+
"discontinuities_detected": has_discontinuities,
|
211 |
+
# Use the extracted list of names here
|
212 |
+
"potential_instruments": potential_instrument_names,
|
213 |
+
"discontinuities": discontinuities_detailed,
|
214 |
+
"temporal_structure": temporal_structure_detailed,
|
215 |
+
"column_categories": column_categories_detailed,
|
216 |
+
"column_nunique_counts": column_nunique_counts_detailed, # Add nunique counts to output
|
217 |
+
"sample_size": sample_size,
|
218 |
+
"num_covariates_estimate": num_covariates,
|
219 |
+
"llm_augmentation": llm_augmentation
|
220 |
+
}
|
221 |
+
|
222 |
+
except Exception as e:
|
223 |
+
logger.error(f"Error analyzing dataset '{dataset_path}': {e}", exc_info=True)
|
224 |
+
return {
|
225 |
+
"error": f"Error analyzing dataset: {str(e)}",
|
226 |
+
"llm_augmentation": llm_augmentation
|
227 |
+
}
|
228 |
+
|
229 |
+
|
230 |
+
def _categorize_columns(df: pd.DataFrame) -> Dict[str, str]:
|
231 |
+
"""
|
232 |
+
Categorize columns into types relevant for causal inference.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
df: DataFrame to analyze
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
Dict mapping column names to their types
|
239 |
+
"""
|
240 |
+
result = {}
|
241 |
+
|
242 |
+
for col in df.columns:
|
243 |
+
# Check if column is numeric
|
244 |
+
if pd.api.types.is_numeric_dtype(df[col]):
|
245 |
+
# Count number of unique values
|
246 |
+
n_unique = df[col].nunique()
|
247 |
+
|
248 |
+
# Binary numeric variable
|
249 |
+
if n_unique == 2:
|
250 |
+
result[col] = "binary"
|
251 |
+
# Likely categorical represented as numeric
|
252 |
+
elif n_unique < 10:
|
253 |
+
result[col] = "categorical_numeric"
|
254 |
+
# Discrete numeric (integers)
|
255 |
+
elif pd.api.types.is_integer_dtype(df[col]):
|
256 |
+
result[col] = "discrete_numeric"
|
257 |
+
# Continuous numeric
|
258 |
+
else:
|
259 |
+
result[col] = "continuous_numeric"
|
260 |
+
|
261 |
+
# Check for datetime
|
262 |
+
elif pd.api.types.is_datetime64_any_dtype(df[col]) or _is_date_string(df, col):
|
263 |
+
result[col] = "datetime"
|
264 |
+
|
265 |
+
# Check for categorical
|
266 |
+
elif pd.api.types.is_categorical_dtype(df[col]) or df[col].nunique() < 20:
|
267 |
+
if df[col].nunique() == 2:
|
268 |
+
result[col] = "binary_categorical"
|
269 |
+
else:
|
270 |
+
result[col] = "categorical"
|
271 |
+
|
272 |
+
# Must be text or other
|
273 |
+
else:
|
274 |
+
result[col] = "text_or_other"
|
275 |
+
|
276 |
+
return result
|
277 |
+
|
278 |
+
|
279 |
+
def _is_date_string(df: pd.DataFrame, col: str) -> bool:
|
280 |
+
"""
|
281 |
+
Check if a column contains date strings.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
df: DataFrame to check
|
285 |
+
col: Column name to check
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
True if the column appears to contain date strings
|
289 |
+
"""
|
290 |
+
# Try to convert to datetime
|
291 |
+
if not pd.api.types.is_string_dtype(df[col]):
|
292 |
+
return False
|
293 |
+
|
294 |
+
# Check sample of values
|
295 |
+
sample = df[col].dropna().sample(min(10, len(df[col].dropna()))).tolist()
|
296 |
+
|
297 |
+
try:
|
298 |
+
for val in sample:
|
299 |
+
pd.to_datetime(val)
|
300 |
+
return True
|
301 |
+
except:
|
302 |
+
return False
|
303 |
+
|
304 |
+
|
305 |
+
def _identify_potential_variables(
|
306 |
+
df: pd.DataFrame,
|
307 |
+
column_categories: Dict[str, str],
|
308 |
+
llm_client: Optional[BaseChatModel] = None,
|
309 |
+
dataset_description: Optional[str] = None
|
310 |
+
) -> Dict[str, List[str]]:
|
311 |
+
"""
|
312 |
+
Identify potential treatment and outcome variables in the dataset, using LLM if available.
|
313 |
+
Falls back to heuristic method if LLM fails or is not available.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
df: DataFrame to analyze
|
317 |
+
column_categories: Dictionary mapping column names to their types
|
318 |
+
llm_client: Optional LLM client for enhanced identification
|
319 |
+
dataset_description: Optional description of the dataset for context
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
Dict with potential treatment and outcome variables
|
323 |
+
"""
|
324 |
+
# Try LLM approach if client is provided
|
325 |
+
if llm_client:
|
326 |
+
try:
|
327 |
+
logger.info("Using LLM to identify potential treatment and outcome variables")
|
328 |
+
|
329 |
+
# Create a concise prompt with just column information
|
330 |
+
columns_list = df.columns.tolist()
|
331 |
+
column_types = {col: str(df[col].dtype) for col in columns_list}
|
332 |
+
|
333 |
+
# Get binary columns for extra context
|
334 |
+
binary_cols = [col for col in columns_list
|
335 |
+
if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
|
336 |
+
|
337 |
+
# Add dataset description if available
|
338 |
+
description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
|
339 |
+
|
340 |
+
prompt = f"""
|
341 |
+
You are an expert causal inference data scientist. Identify potential treatment and outcome variables from this dataset.{description_text}
|
342 |
+
|
343 |
+
Dataset columns:
|
344 |
+
{columns_list}
|
345 |
+
|
346 |
+
Column types:
|
347 |
+
{column_types}
|
348 |
+
|
349 |
+
Binary columns (good treatment candidates):
|
350 |
+
{binary_cols}
|
351 |
+
|
352 |
+
Instructions:
|
353 |
+
1. Identify TREATMENT variables: interventions, treatments, programs, policies, or binary state changes.
|
354 |
+
Look for binary variables or names with 'treatment', 'intervention', 'program', 'policy', etc.
|
355 |
+
|
356 |
+
2. Identify OUTCOME variables: results, effects, or responses to treatments.
|
357 |
+
Look for numeric variables (especially non-binary) or names with 'outcome', 'result', 'effect', 'score', etc.
|
358 |
+
|
359 |
+
Return ONLY a valid JSON object with two lists: "potential_treatments" and "potential_outcomes".
|
360 |
+
Example: {{"potential_treatments": ["treatment_a", "program_b"], "potential_outcomes": ["result_score", "outcome_measure"]}}
|
361 |
+
"""
|
362 |
+
|
363 |
+
# Call the LLM and parse the response
|
364 |
+
response = llm_client.invoke(prompt)
|
365 |
+
response_text = response.content if hasattr(response, 'content') else str(response)
|
366 |
+
|
367 |
+
# Extract JSON from the response text
|
368 |
+
import re
|
369 |
+
json_match = re.search(r'{.*}', response_text, re.DOTALL)
|
370 |
+
|
371 |
+
if json_match:
|
372 |
+
result = json.loads(json_match.group(0))
|
373 |
+
|
374 |
+
# Validate the response
|
375 |
+
if (isinstance(result, dict) and
|
376 |
+
"potential_treatments" in result and
|
377 |
+
"potential_outcomes" in result and
|
378 |
+
isinstance(result["potential_treatments"], list) and
|
379 |
+
isinstance(result["potential_outcomes"], list)):
|
380 |
+
|
381 |
+
# Ensure all suggestions are valid columns
|
382 |
+
valid_treatments = [col for col in result["potential_treatments"] if col in df.columns]
|
383 |
+
valid_outcomes = [col for col in result["potential_outcomes"] if col in df.columns]
|
384 |
+
|
385 |
+
if valid_treatments and valid_outcomes:
|
386 |
+
logger.info(f"LLM identified {len(valid_treatments)} treatments and {len(valid_outcomes)} outcomes")
|
387 |
+
return {
|
388 |
+
"potential_treatments": valid_treatments,
|
389 |
+
"potential_outcomes": valid_outcomes
|
390 |
+
}
|
391 |
+
else:
|
392 |
+
logger.warning("LLM suggested invalid columns, falling back to heuristic method")
|
393 |
+
else:
|
394 |
+
logger.warning("Invalid LLM response format, falling back to heuristic method")
|
395 |
+
else:
|
396 |
+
logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
|
397 |
+
|
398 |
+
except Exception as e:
|
399 |
+
logger.error(f"Error in LLM identification: {e}", exc_info=True)
|
400 |
+
logger.info("Falling back to heuristic method")
|
401 |
+
|
402 |
+
# Fallback to heuristic method
|
403 |
+
logger.info("Using heuristic method to identify potential treatment and outcome variables")
|
404 |
+
|
405 |
+
# Identify potential treatment variables
|
406 |
+
potential_treatments = []
|
407 |
+
|
408 |
+
# Look for binary variables (good treatment candidates)
|
409 |
+
binary_cols = [col for col in df.columns
|
410 |
+
if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
|
411 |
+
|
412 |
+
# Look for variables with names suggesting treatment
|
413 |
+
treatment_keywords = ['treatment', 'treat', 'intervention', 'program', 'policy',
|
414 |
+
'exposed', 'assigned', 'received', 'participated']
|
415 |
+
|
416 |
+
for col in df.columns:
|
417 |
+
col_lower = col.lower()
|
418 |
+
if any(keyword in col_lower for keyword in treatment_keywords):
|
419 |
+
potential_treatments.append(col)
|
420 |
+
|
421 |
+
# Add binary variables if we don't have enough candidates
|
422 |
+
if len(potential_treatments) < 3:
|
423 |
+
for col in binary_cols:
|
424 |
+
if col not in potential_treatments:
|
425 |
+
potential_treatments.append(col)
|
426 |
+
if len(potential_treatments) >= 3:
|
427 |
+
break
|
428 |
+
|
429 |
+
# Identify potential outcome variables
|
430 |
+
potential_outcomes = []
|
431 |
+
|
432 |
+
# Look for numeric variables that aren't binary
|
433 |
+
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
434 |
+
non_binary_numeric = [col for col in numeric_cols if col not in binary_cols]
|
435 |
+
|
436 |
+
# Look for variables with names suggesting outcomes
|
437 |
+
outcome_keywords = ['outcome', 'result', 'effect', 'response', 'score', 'performance',
|
438 |
+
'achievement', 'success', 'failure', 'improvement']
|
439 |
+
|
440 |
+
for col in df.columns:
|
441 |
+
col_lower = col.lower()
|
442 |
+
if any(keyword in col_lower for keyword in outcome_keywords):
|
443 |
+
potential_outcomes.append(col)
|
444 |
+
|
445 |
+
# Add numeric non-binary variables if we don't have enough candidates
|
446 |
+
if len(potential_outcomes) < 3:
|
447 |
+
for col in non_binary_numeric:
|
448 |
+
if col not in potential_outcomes and col not in potential_treatments:
|
449 |
+
potential_outcomes.append(col)
|
450 |
+
if len(potential_outcomes) >= 3:
|
451 |
+
break
|
452 |
+
|
453 |
+
return {
|
454 |
+
"potential_treatments": potential_treatments,
|
455 |
+
"potential_outcomes": potential_outcomes
|
456 |
+
}
|
457 |
+
|
458 |
+
|
459 |
+
def detect_temporal_structure(
|
460 |
+
df: pd.DataFrame,
|
461 |
+
llm_client: Optional[BaseChatModel] = None,
|
462 |
+
dataset_description: Optional[str] = None,
|
463 |
+
original_query: Optional[str] = None
|
464 |
+
) -> Dict[str, Any]:
|
465 |
+
"""
|
466 |
+
Detect temporal structure in the dataset, using LLM for enhanced identification.
|
467 |
+
|
468 |
+
Args:
|
469 |
+
df: DataFrame to analyze
|
470 |
+
llm_client: Optional LLM client for enhanced identification
|
471 |
+
dataset_description: Optional description of the dataset for context
|
472 |
+
|
473 |
+
Returns:
|
474 |
+
Dict with information about temporal structure:
|
475 |
+
- has_temporal_structure: Whether temporal structure exists
|
476 |
+
- temporal_columns: Primary time column identified (or list if multiple from heuristic)
|
477 |
+
- is_panel_data: Whether data is in panel format
|
478 |
+
- time_column: Primary time column identified for panel data
|
479 |
+
- id_column: Primary unit ID column identified for panel data
|
480 |
+
- time_periods: Number of time periods (if panel data)
|
481 |
+
- units: Number of unique units (if panel data)
|
482 |
+
- identification_method: How time/unit vars were identified ('LLM', 'Heuristic', 'None')
|
483 |
+
"""
|
484 |
+
result = {
|
485 |
+
"has_temporal_structure": False,
|
486 |
+
"temporal_columns": [], # Will store primary time column or heuristic list
|
487 |
+
"is_panel_data": False,
|
488 |
+
"time_column": None,
|
489 |
+
"id_column": None,
|
490 |
+
"time_periods": None,
|
491 |
+
"units": None,
|
492 |
+
"identification_method": "None"
|
493 |
+
}
|
494 |
+
|
495 |
+
# --- Step 1: Heuristic identification (as before) ---
|
496 |
+
#heuristic_datetime_cols = []
|
497 |
+
#for col in df.columns:
|
498 |
+
# if pd.api.types.is_datetime64_any_dtype(df[col]):
|
499 |
+
# heuristic_datetime_cols.append(col)
|
500 |
+
# elif pd.api.types.is_string_dtype(df[col]):
|
501 |
+
# try:
|
502 |
+
# if pd.to_datetime(df[col], errors='coerce').notna().any():
|
503 |
+
# heuristic_datetime_cols.append(col)
|
504 |
+
# except:
|
505 |
+
# pass # Ignore conversion errors
|
506 |
+
|
507 |
+
#time_keywords = ['year', 'month', 'day', 'date', 'time', 'period', 'quarter', 'week']
|
508 |
+
#for col in df.columns:
|
509 |
+
# col_lower = col.lower()
|
510 |
+
# if any(keyword in col_lower for keyword in time_keywords) and col not in heuristic_datetime_cols:
|
511 |
+
# heuristic_datetime_cols.append(col)
|
512 |
+
|
513 |
+
#id_keywords = ['id', 'individual', 'person', 'unit', 'entity', 'firm', 'company', 'state', 'country']
|
514 |
+
#heuristic_potential_id_cols = []
|
515 |
+
#for col in df.columns:
|
516 |
+
# col_lower = col.lower()
|
517 |
+
# # Exclude columns already identified as time-related by heuristics
|
518 |
+
# if any(keyword in col_lower for keyword in id_keywords) and col not in heuristic_datetime_cols:
|
519 |
+
# heuristic_potential_id_cols.append(col)
|
520 |
+
|
521 |
+
# --- Step 2: LLM-assisted identification ---
|
522 |
+
llm_identified_time_var = None
|
523 |
+
llm_identified_unit_var = None
|
524 |
+
heuristic_datetime_cols = []
|
525 |
+
heuristic_potential_id_cols = []
|
526 |
+
dataset_summary = df.describe(include='all')
|
527 |
+
|
528 |
+
if llm_client:
|
529 |
+
logger.info("Attempting LLM-assisted identification of temporal/unit variables.")
|
530 |
+
column_names = df.columns.tolist()
|
531 |
+
column_dtypes_dict = {col: str(df[col].dtype) for col in column_names}
|
532 |
+
|
533 |
+
try:
|
534 |
+
llm_suggestions = llm_identify_temporal_and_unit_vars(
|
535 |
+
column_names=column_names,
|
536 |
+
column_dtypes=column_dtypes_dict,
|
537 |
+
dataset_description=dataset_description if dataset_description else "No dataset description provided.",
|
538 |
+
dataset_summary=dataset_summary,
|
539 |
+
heuristic_time_candidates=heuristic_datetime_cols,
|
540 |
+
heuristic_id_candidates=heuristic_potential_id_cols,
|
541 |
+
query=original_query if original_query else "No query provided.",
|
542 |
+
llm=llm_client
|
543 |
+
)
|
544 |
+
llm_identified_time_var = llm_suggestions.get("time_variable")
|
545 |
+
llm_identified_unit_var = llm_suggestions.get("unit_variable")
|
546 |
+
result["identification_method"] = "LLM"
|
547 |
+
|
548 |
+
if not llm_identified_time_var and not llm_identified_unit_var:
|
549 |
+
result["identification_method"] = "LLM_NoIdentification"
|
550 |
+
except Exception as e:
|
551 |
+
logger.warning(f"LLM call for temporal/unit vars failed: {e}. Falling back to heuristics.")
|
552 |
+
result["identification_method"] = "Heuristic_LLM_Error"
|
553 |
+
else:
|
554 |
+
result["identification_method"] = "Heuristic_NoLLM"
|
555 |
+
|
556 |
+
# --- Step 3: Combine LLM and Heuristic Results ---
|
557 |
+
final_time_var = None
|
558 |
+
final_unit_var = None
|
559 |
+
|
560 |
+
if llm_identified_time_var:
|
561 |
+
final_time_var = llm_identified_time_var
|
562 |
+
logger.info(f"Prioritizing LLM identified time variable: {final_time_var}")
|
563 |
+
elif heuristic_datetime_cols:
|
564 |
+
final_time_var = heuristic_datetime_cols[0] # Fallback to first heuristic time col
|
565 |
+
logger.info(f"Using heuristic time variable: {final_time_var}")
|
566 |
+
|
567 |
+
if llm_identified_unit_var:
|
568 |
+
final_unit_var = llm_identified_unit_var
|
569 |
+
logger.info(f"Prioritizing LLM identified unit variable: {final_unit_var}")
|
570 |
+
elif heuristic_potential_id_cols:
|
571 |
+
final_unit_var = heuristic_potential_id_cols[0] # Fallback to first heuristic ID col
|
572 |
+
logger.info(f"Using heuristic unit variable: {final_unit_var}")
|
573 |
+
|
574 |
+
# Update results based on final selections
|
575 |
+
if final_time_var:
|
576 |
+
result["has_temporal_structure"] = True
|
577 |
+
result["temporal_columns"] = [final_time_var] # Store as a list with the primary time var
|
578 |
+
result["time_column"] = final_time_var
|
579 |
+
else: # If no time var found by LLM or heuristic, use original heuristic list for temporal_columns
|
580 |
+
if heuristic_datetime_cols:
|
581 |
+
result["has_temporal_structure"] = True
|
582 |
+
result["temporal_columns"] = heuristic_datetime_cols
|
583 |
+
# time_column remains None
|
584 |
+
|
585 |
+
if final_unit_var:
|
586 |
+
result["id_column"] = final_unit_var
|
587 |
+
|
588 |
+
# --- Step 4: Update Panel Data Logic (based on final_time_var and final_unit_var) ---
|
589 |
+
if final_time_var and final_unit_var:
|
590 |
+
# Check if there are multiple time periods per unit using the identified variables
|
591 |
+
try:
|
592 |
+
# Ensure columns exist before groupby
|
593 |
+
if final_time_var in df.columns and final_unit_var in df.columns:
|
594 |
+
if df.groupby(final_unit_var)[final_time_var].nunique().mean() > 1.0:
|
595 |
+
result["is_panel_data"] = True
|
596 |
+
result["time_periods"] = df[final_time_var].nunique()
|
597 |
+
result["units"] = df[final_unit_var].nunique()
|
598 |
+
logger.info(f"Panel data detected: Time='{final_time_var}', Unit='{final_unit_var}', Periods={result['time_periods']}, Units={result['units']}")
|
599 |
+
else:
|
600 |
+
logger.info("Not panel data: Each unit does not have multiple time periods.")
|
601 |
+
else:
|
602 |
+
logger.warning(f"Final time ('{final_time_var}') or unit ('{final_unit_var}') var not in DataFrame. Cannot confirm panel structure.")
|
603 |
+
except Exception as e:
|
604 |
+
logger.error(f"Error checking panel data structure with time='{final_time_var}', unit='{final_unit_var}': {e}")
|
605 |
+
result["is_panel_data"] = False # Default to false on error
|
606 |
+
else:
|
607 |
+
logger.info("Not panel data: Missing either time or unit variable for panel structure.")
|
608 |
+
|
609 |
+
logger.debug(f"Final temporal structure detection result: {result}")
|
610 |
+
return result
|
611 |
+
|
612 |
+
|
613 |
+
def find_potential_instruments(
|
614 |
+
df: pd.DataFrame,
|
615 |
+
llm_client: Optional[BaseChatModel] = None,
|
616 |
+
potential_treatments: List[str] = None,
|
617 |
+
potential_outcomes: List[str] = None,
|
618 |
+
dataset_description: Optional[str] = None
|
619 |
+
) -> List[Dict[str, Any]]:
|
620 |
+
"""
|
621 |
+
Find potential instrumental variables in the dataset, using LLM if available.
|
622 |
+
Falls back to heuristic method if LLM fails or is not available.
|
623 |
+
|
624 |
+
Args:
|
625 |
+
df: DataFrame to analyze
|
626 |
+
llm_client: Optional LLM client for enhanced identification
|
627 |
+
potential_treatments: Optional list of potential treatment variables
|
628 |
+
potential_outcomes: Optional list of potential outcome variables
|
629 |
+
dataset_description: Optional description of the dataset for context
|
630 |
+
|
631 |
+
Returns:
|
632 |
+
List of potential instrumental variables with their properties
|
633 |
+
"""
|
634 |
+
# Try LLM approach if client is provided
|
635 |
+
if llm_client:
|
636 |
+
try:
|
637 |
+
logger.info("Using LLM to identify potential instrumental variables")
|
638 |
+
|
639 |
+
# Create a concise prompt with just column information
|
640 |
+
columns_list = df.columns.tolist()
|
641 |
+
|
642 |
+
# Exclude known treatment and outcome variables from consideration
|
643 |
+
excluded_columns = []
|
644 |
+
if potential_treatments:
|
645 |
+
excluded_columns.extend(potential_treatments)
|
646 |
+
if potential_outcomes:
|
647 |
+
excluded_columns.extend(potential_outcomes)
|
648 |
+
|
649 |
+
# Filter columns to exclude treatments and outcomes
|
650 |
+
candidate_columns = [col for col in columns_list if col not in excluded_columns]
|
651 |
+
|
652 |
+
if not candidate_columns:
|
653 |
+
logger.warning("No eligible columns for instrumental variables after filtering treatments and outcomes")
|
654 |
+
return []
|
655 |
+
|
656 |
+
# Get column types for context
|
657 |
+
column_types = {col: str(df[col].dtype) for col in candidate_columns}
|
658 |
+
|
659 |
+
# Add dataset description if available
|
660 |
+
description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
|
661 |
+
|
662 |
+
prompt = f"""
|
663 |
+
You are an expert causal inference data scientist. Identify potential instrumental variables from this dataset.{description_text}
|
664 |
+
|
665 |
+
DEFINITION: Instrumental variables must:
|
666 |
+
1. Be correlated with the treatment variable (relevance)
|
667 |
+
2. Only affect the outcome through the treatment (exclusion restriction)
|
668 |
+
3. Not be correlated with unmeasured confounders (exogeneity)
|
669 |
+
|
670 |
+
Treatment variables: {potential_treatments if potential_treatments else "Unknown"}
|
671 |
+
Outcome variables: {potential_outcomes if potential_outcomes else "Unknown"}
|
672 |
+
|
673 |
+
Available columns (excluding treatments and outcomes):
|
674 |
+
{candidate_columns}
|
675 |
+
|
676 |
+
Column types:
|
677 |
+
{column_types}
|
678 |
+
|
679 |
+
Look for variables likely to be:
|
680 |
+
- Random assignments
|
681 |
+
- Policy changes
|
682 |
+
- Geographic or temporal variations
|
683 |
+
- Variables with names containing: 'instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous'
|
684 |
+
|
685 |
+
Return ONLY a JSON array of objects, each with "variable", "reason", and "data_type" fields.
|
686 |
+
Example:
|
687 |
+
[
|
688 |
+
{{"variable": "random_assignment", "reason": "Random assignment variable", "data_type": "int64"}},
|
689 |
+
{{"variable": "distance_to_facility", "reason": "Geographic variation", "data_type": "float64"}}
|
690 |
+
]
|
691 |
+
"""
|
692 |
+
|
693 |
+
# Call the LLM and parse the response
|
694 |
+
response = llm_client.invoke(prompt)
|
695 |
+
response_text = response.content if hasattr(response, 'content') else str(response)
|
696 |
+
|
697 |
+
# Extract JSON from the response text
|
698 |
+
import re
|
699 |
+
json_match = re.search(r'\[\s*{.*}\s*\]', response_text, re.DOTALL)
|
700 |
+
|
701 |
+
if json_match:
|
702 |
+
result = json.loads(json_match.group(0))
|
703 |
+
|
704 |
+
# Validate the response
|
705 |
+
if isinstance(result, list) and len(result) > 0:
|
706 |
+
# Filter for valid entries
|
707 |
+
valid_instruments = []
|
708 |
+
for item in result:
|
709 |
+
if not isinstance(item, dict) or "variable" not in item:
|
710 |
+
continue
|
711 |
+
|
712 |
+
if item["variable"] not in df.columns:
|
713 |
+
continue
|
714 |
+
|
715 |
+
# Ensure all required fields are present
|
716 |
+
if "reason" not in item:
|
717 |
+
item["reason"] = "Identified by LLM"
|
718 |
+
if "data_type" not in item:
|
719 |
+
item["data_type"] = str(df[item["variable"]].dtype)
|
720 |
+
|
721 |
+
valid_instruments.append(item)
|
722 |
+
|
723 |
+
if valid_instruments:
|
724 |
+
logger.info(f"LLM identified {len(valid_instruments)} potential instrumental variables {valid_instruments}")
|
725 |
+
return valid_instruments
|
726 |
+
else:
|
727 |
+
logger.warning("No valid instruments found by LLM, falling back to heuristic method")
|
728 |
+
else:
|
729 |
+
logger.warning("Invalid LLM response format, falling back to heuristic method")
|
730 |
+
else:
|
731 |
+
logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
|
732 |
+
|
733 |
+
except Exception as e:
|
734 |
+
logger.error(f"Error in LLM identification of instruments: {e}", exc_info=True)
|
735 |
+
logger.info("Falling back to heuristic method")
|
736 |
+
|
737 |
+
# Fallback to heuristic method
|
738 |
+
logger.info("Using heuristic method to identify potential instrumental variables")
|
739 |
+
potential_instruments = []
|
740 |
+
|
741 |
+
# Look for variables with instrumental-related names
|
742 |
+
instrument_keywords = ['instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous']
|
743 |
+
|
744 |
+
for col in df.columns:
|
745 |
+
# Skip treatment and outcome variables
|
746 |
+
if potential_treatments and col in potential_treatments:
|
747 |
+
continue
|
748 |
+
if potential_outcomes and col in potential_outcomes:
|
749 |
+
continue
|
750 |
+
|
751 |
+
col_lower = col.lower()
|
752 |
+
if any(keyword in col_lower for keyword in instrument_keywords):
|
753 |
+
instrument_info = {
|
754 |
+
"variable": col,
|
755 |
+
"reason": f"Name contains instrument-related keyword",
|
756 |
+
"data_type": str(df[col].dtype)
|
757 |
+
}
|
758 |
+
potential_instruments.append(instrument_info)
|
759 |
+
|
760 |
+
return potential_instruments
|
761 |
+
|
762 |
+
|
763 |
+
def detect_discontinuities(df: pd.DataFrame) -> Dict[str, Any]:
|
764 |
+
"""
|
765 |
+
Identify discontinuities in continuous variables (for RDD).
|
766 |
+
|
767 |
+
Args:
|
768 |
+
df: DataFrame to analyze
|
769 |
+
|
770 |
+
Returns:
|
771 |
+
Dict with information about detected discontinuities
|
772 |
+
"""
|
773 |
+
discontinuities = []
|
774 |
+
|
775 |
+
# For each numeric column, check for potential discontinuities
|
776 |
+
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
777 |
+
|
778 |
+
for col in numeric_cols:
|
779 |
+
# Skip columns with too many unique values
|
780 |
+
if df[col].nunique() > 100:
|
781 |
+
continue
|
782 |
+
|
783 |
+
values = df[col].dropna().sort_values().values
|
784 |
+
|
785 |
+
# Calculate gaps between consecutive values
|
786 |
+
if len(values) > 10:
|
787 |
+
gaps = np.diff(values)
|
788 |
+
mean_gap = np.mean(gaps)
|
789 |
+
std_gap = np.std(gaps)
|
790 |
+
|
791 |
+
# Look for unusually large gaps (potential discontinuities)
|
792 |
+
large_gaps = np.where(gaps > mean_gap + 2*std_gap)[0]
|
793 |
+
|
794 |
+
if len(large_gaps) > 0:
|
795 |
+
for idx in large_gaps:
|
796 |
+
cutpoint = (values[idx] + values[idx+1]) / 2
|
797 |
+
discontinuities.append({
|
798 |
+
"variable": col,
|
799 |
+
"cutpoint": float(cutpoint),
|
800 |
+
"gap_size": float(gaps[idx]),
|
801 |
+
"mean_gap": float(mean_gap)
|
802 |
+
})
|
803 |
+
|
804 |
+
return {
|
805 |
+
"has_discontinuities": len(discontinuities) > 0,
|
806 |
+
"discontinuities": discontinuities
|
807 |
+
}
|
808 |
+
|
809 |
+
|
810 |
+
def assess_variable_relationships(df: pd.DataFrame, corr_matrix: pd.DataFrame) -> Dict[str, Any]:
|
811 |
+
"""
|
812 |
+
Assess relationships between variables in the dataset.
|
813 |
+
|
814 |
+
Args:
|
815 |
+
df: DataFrame to analyze
|
816 |
+
corr_matrix: Precomputed correlation matrix for numeric columns
|
817 |
+
|
818 |
+
Returns:
|
819 |
+
Dict with information about variable relationships:
|
820 |
+
- strongly_correlated_pairs: Pairs of strongly correlated variables
|
821 |
+
- potential_confounders: Variables that might be confounders
|
822 |
+
"""
|
823 |
+
result = {"strongly_correlated_pairs": [], "potential_confounders": []}
|
824 |
+
|
825 |
+
numeric_cols = corr_matrix.columns.tolist()
|
826 |
+
if len(numeric_cols) < 2:
|
827 |
+
return result
|
828 |
+
|
829 |
+
# Use the precomputed correlation matrix
|
830 |
+
corr_matrix_abs = corr_matrix.abs()
|
831 |
+
|
832 |
+
# Find strongly correlated variable pairs
|
833 |
+
for i in range(len(numeric_cols)):
|
834 |
+
for j in range(i+1, len(numeric_cols)):
|
835 |
+
if abs(corr_matrix_abs.iloc[i, j]) > 0.7: # Correlation threshold
|
836 |
+
result["strongly_correlated_pairs"].append({
|
837 |
+
"variables": [numeric_cols[i], numeric_cols[j]],
|
838 |
+
"correlation": float(corr_matrix.iloc[i, j])
|
839 |
+
})
|
840 |
+
|
841 |
+
# Identify potential confounders (variables correlated with multiple others)
|
842 |
+
confounder_counts = {col: 0 for col in numeric_cols}
|
843 |
+
|
844 |
+
for pair in result["strongly_correlated_pairs"]:
|
845 |
+
confounder_counts[pair["variables"][0]] += 1
|
846 |
+
confounder_counts[pair["variables"][1]] += 1
|
847 |
+
|
848 |
+
# Variables correlated with multiple others are potential confounders
|
849 |
+
for col, count in confounder_counts.items():
|
850 |
+
if count >= 2:
|
851 |
+
result["potential_confounders"].append({"variable": col, "num_correlations": count})
|
852 |
+
|
853 |
+
return result
|
auto_causal/components/decision_tree.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
decision tree component for selecting causal inference methods
|
3 |
+
|
4 |
+
this module implements the decision tree logic to select the most appropriate
|
5 |
+
causal inference method based on dataset characteristics and available variables
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
from typing import Dict, List, Any, Optional
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
# define method names
|
13 |
+
BACKDOOR_ADJUSTMENT = "backdoor_adjustment"
|
14 |
+
LINEAR_REGRESSION = "linear_regression"
|
15 |
+
DIFF_IN_MEANS = "diff_in_means"
|
16 |
+
DIFF_IN_DIFF = "difference_in_differences"
|
17 |
+
REGRESSION_DISCONTINUITY = "regression_discontinuity_design"
|
18 |
+
PROPENSITY_SCORE_MATCHING = "propensity_score_matching"
|
19 |
+
INSTRUMENTAL_VARIABLE = "instrumental_variable"
|
20 |
+
CORRELATION_ANALYSIS = "correlation_analysis"
|
21 |
+
PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting"
|
22 |
+
GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score"
|
23 |
+
FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment"
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
# method assumptions mapping
|
29 |
+
METHOD_ASSUMPTIONS = {
|
30 |
+
BACKDOOR_ADJUSTMENT: [
|
31 |
+
"no unmeasured confounders (conditional ignorability given covariates)",
|
32 |
+
"correct model specification for outcome conditional on treatment and covariates",
|
33 |
+
"positivity/overlap (for all covariate values, units could potentially receive either treatment level)"
|
34 |
+
],
|
35 |
+
LINEAR_REGRESSION: [
|
36 |
+
"linear relationship between treatment, covariates, and outcome",
|
37 |
+
"no unmeasured confounders (if observational)",
|
38 |
+
"correct model specification",
|
39 |
+
"homoscedasticity of errors",
|
40 |
+
"normally distributed errors (for inference)"
|
41 |
+
],
|
42 |
+
DIFF_IN_MEANS: [
|
43 |
+
"treatment is randomly assigned (or as-if random)",
|
44 |
+
"no spillover effects",
|
45 |
+
"stable unit treatment value assumption (SUTVA)"
|
46 |
+
],
|
47 |
+
DIFF_IN_DIFF: [
|
48 |
+
"parallel trends between treatment and control groups before treatment",
|
49 |
+
"no spillover effects between groups",
|
50 |
+
"no anticipation effects before treatment",
|
51 |
+
"stable composition of treatment and control groups",
|
52 |
+
"treatment timing is exogenous"
|
53 |
+
],
|
54 |
+
REGRESSION_DISCONTINUITY: [
|
55 |
+
"units cannot precisely manipulate the running variable around the cutoff",
|
56 |
+
"continuity of conditional expectation functions of potential outcomes at the cutoff",
|
57 |
+
"no other changes occurring precisely at the cutoff"
|
58 |
+
],
|
59 |
+
PROPENSITY_SCORE_MATCHING: [
|
60 |
+
"no unmeasured confounders (conditional ignorability)",
|
61 |
+
"sufficient overlap (common support) between treatment and control groups",
|
62 |
+
"correct propensity score model specification"
|
63 |
+
],
|
64 |
+
INSTRUMENTAL_VARIABLE: [
|
65 |
+
"instrument is correlated with treatment (relevance)",
|
66 |
+
"instrument affects outcome only through treatment (exclusion restriction)",
|
67 |
+
"instrument is independent of unmeasured confounders (exogeneity/independence)"
|
68 |
+
],
|
69 |
+
CORRELATION_ANALYSIS: [
|
70 |
+
"data represents a sample from the population of interest",
|
71 |
+
"variables are measured appropriately"
|
72 |
+
],
|
73 |
+
PROPENSITY_SCORE_WEIGHTING: [
|
74 |
+
"no unmeasured confounders (conditional ignorability)",
|
75 |
+
"sufficient overlap (common support) between treatment and control groups",
|
76 |
+
"correct propensity score model specification",
|
77 |
+
"weights correctly specified (e.g., ATE, ATT)"
|
78 |
+
],
|
79 |
+
GENERALIZED_PROPENSITY_SCORE: [
|
80 |
+
"conditional mean independence",
|
81 |
+
"positivity/common support for GPS",
|
82 |
+
"correct specification of the GPS model",
|
83 |
+
"correct specification of the outcome model",
|
84 |
+
"no unmeasured confounders affecting both treatment and outcome, given X",
|
85 |
+
"treatment variable is continuous"
|
86 |
+
],
|
87 |
+
FRONTDOOR_ADJUSTMENT: [
|
88 |
+
"mediator is affected by treatment and affects outcome",
|
89 |
+
"mediator is not affected by any confounders of the treatment-outcome relationship"
|
90 |
+
]
|
91 |
+
}
|
92 |
+
|
93 |
+
|
94 |
+
def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
|
95 |
+
excluded_methods = set(excluded_methods or [])
|
96 |
+
logger.info(f"Excluded methods: {sorted(excluded_methods)}")
|
97 |
+
|
98 |
+
treatment = dataset_properties.get("treatment_variable")
|
99 |
+
outcome = dataset_properties.get("outcome_variable")
|
100 |
+
if not treatment or not outcome:
|
101 |
+
raise ValueError("Both treatment and outcome variables must be specified")
|
102 |
+
|
103 |
+
instrument_var = dataset_properties.get("instrument_variable")
|
104 |
+
running_var = dataset_properties.get("running_variable")
|
105 |
+
cutoff_val = dataset_properties.get("cutoff_value")
|
106 |
+
time_var = dataset_properties.get("time_variable")
|
107 |
+
is_rct = dataset_properties.get("is_rct", False)
|
108 |
+
has_temporal = dataset_properties.get("has_temporal_structure", False)
|
109 |
+
frontdoor = dataset_properties.get("frontdoor_criterion", False)
|
110 |
+
covariate_overlap_result = dataset_properties.get("covariate_overlap_score")
|
111 |
+
covariates = dataset_properties.get("covariates", [])
|
112 |
+
treatment_variable_type = dataset_properties.get("treatment_variable_type", "binary")
|
113 |
+
|
114 |
+
# Helpers to collect candidates
|
115 |
+
candidates = [] # list of (method, priority_index)
|
116 |
+
justifications: Dict[str, str] = {}
|
117 |
+
assumptions: Dict[str, List[str]] = {}
|
118 |
+
|
119 |
+
def add(method: str, justification: str, prio_order: List[str]):
|
120 |
+
if method in justifications: # already added
|
121 |
+
return
|
122 |
+
justifications[method] = justification
|
123 |
+
assumptions[method] = METHOD_ASSUMPTIONS[method]
|
124 |
+
# priority index from provided order (fallback large if not present)
|
125 |
+
try:
|
126 |
+
idx = prio_order.index(method)
|
127 |
+
except ValueError:
|
128 |
+
idx = 10**6
|
129 |
+
candidates.append((method, idx))
|
130 |
+
|
131 |
+
# ----- Build candidate set (no returns here) -----
|
132 |
+
|
133 |
+
# RCT branch
|
134 |
+
if is_rct:
|
135 |
+
logger.info("Dataset is from a randomized controlled trial (RCT)")
|
136 |
+
rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS]
|
137 |
+
|
138 |
+
if instrument_var and instrument_var != treatment:
|
139 |
+
add(INSTRUMENTAL_VARIABLE,
|
140 |
+
f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.",
|
141 |
+
rct_priority)
|
142 |
+
|
143 |
+
if covariates:
|
144 |
+
add(LINEAR_REGRESSION,
|
145 |
+
"RCT with covariates—use OLS for precision.",
|
146 |
+
rct_priority)
|
147 |
+
else:
|
148 |
+
add(DIFF_IN_MEANS,
|
149 |
+
"Pure RCT without covariates—difference-in-means.",
|
150 |
+
rct_priority)
|
151 |
+
|
152 |
+
# Observational branch
|
153 |
+
obs_priority_binary = [
|
154 |
+
INSTRUMENTAL_VARIABLE,
|
155 |
+
PROPENSITY_SCORE_MATCHING,
|
156 |
+
PROPENSITY_SCORE_WEIGHTING,
|
157 |
+
FRONTDOOR_ADJUSTMENT,
|
158 |
+
LINEAR_REGRESSION,
|
159 |
+
]
|
160 |
+
obs_priority_nonbinary = [
|
161 |
+
INSTRUMENTAL_VARIABLE,
|
162 |
+
FRONTDOOR_ADJUSTMENT,
|
163 |
+
LINEAR_REGRESSION,
|
164 |
+
]
|
165 |
+
|
166 |
+
# Common early structural signals first (still only add as candidates)
|
167 |
+
if has_temporal and time_var:
|
168 |
+
add(DIFF_IN_DIFF,
|
169 |
+
f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).",
|
170 |
+
[DIFF_IN_DIFF]) # highest among itself
|
171 |
+
|
172 |
+
if running_var and cutoff_val is not None:
|
173 |
+
add(REGRESSION_DISCONTINUITY,
|
174 |
+
f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.",
|
175 |
+
[REGRESSION_DISCONTINUITY])
|
176 |
+
|
177 |
+
# Binary vs non-binary pathways
|
178 |
+
if treatment_variable_type == "binary":
|
179 |
+
if instrument_var:
|
180 |
+
add(INSTRUMENTAL_VARIABLE,
|
181 |
+
f"Instrumental variable '{instrument_var}' available.",
|
182 |
+
obs_priority_binary)
|
183 |
+
|
184 |
+
# Propensity score methods only if covariates exist
|
185 |
+
if covariates:
|
186 |
+
if covariate_overlap_result is not None:
|
187 |
+
ps_method = (PROPENSITY_SCORE_WEIGHTING
|
188 |
+
if covariate_overlap_result < 0.1
|
189 |
+
else PROPENSITY_SCORE_MATCHING)
|
190 |
+
else:
|
191 |
+
ps_method = PROPENSITY_SCORE_MATCHING
|
192 |
+
add(ps_method,
|
193 |
+
"Covariates observed; PS method chosen based on overlap.",
|
194 |
+
obs_priority_binary)
|
195 |
+
|
196 |
+
if frontdoor:
|
197 |
+
add(FRONTDOOR_ADJUSTMENT,
|
198 |
+
"Front-door criterion satisfied.",
|
199 |
+
obs_priority_binary)
|
200 |
+
|
201 |
+
add(LINEAR_REGRESSION,
|
202 |
+
"OLS as a fallback specification.",
|
203 |
+
obs_priority_binary)
|
204 |
+
|
205 |
+
else:
|
206 |
+
logger.info(f"Non-binary treatment variable detected: {treatment_variable_type}")
|
207 |
+
if instrument_var:
|
208 |
+
add(INSTRUMENTAL_VARIABLE,
|
209 |
+
f"Instrument '{instrument_var}' candidate for non-binary treatment.",
|
210 |
+
obs_priority_nonbinary)
|
211 |
+
if frontdoor:
|
212 |
+
add(FRONTDOOR_ADJUSTMENT,
|
213 |
+
"Front-door criterion satisfied.",
|
214 |
+
obs_priority_nonbinary)
|
215 |
+
add(LINEAR_REGRESSION,
|
216 |
+
"Fallback for non-binary treatment without stronger identification.",
|
217 |
+
obs_priority_nonbinary)
|
218 |
+
|
219 |
+
# ----- Centralized exclusion handling -----
|
220 |
+
# Remove excluded
|
221 |
+
filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods]
|
222 |
+
|
223 |
+
# If nothing survives, attempt a safe fallback not excluded
|
224 |
+
if not filtered:
|
225 |
+
logger.warning(f"All candidates excluded. Candidates were: {[m for m,_ in candidates]}. Excluded: {sorted(excluded_methods)}")
|
226 |
+
fallback_order = [
|
227 |
+
LINEAR_REGRESSION,
|
228 |
+
DIFF_IN_MEANS,
|
229 |
+
PROPENSITY_SCORE_MATCHING,
|
230 |
+
PROPENSITY_SCORE_WEIGHTING,
|
231 |
+
DIFF_IN_DIFF,
|
232 |
+
REGRESSION_DISCONTINUITY,
|
233 |
+
INSTRUMENTAL_VARIABLE,
|
234 |
+
FRONTDOOR_ADJUSTMENT,
|
235 |
+
]
|
236 |
+
fallback = next((m for m in fallback_order if m in justifications and m not in excluded_methods), None)
|
237 |
+
if not fallback:
|
238 |
+
# truly nothing left; raise with context
|
239 |
+
raise RuntimeError("No viable method remains after exclusions.")
|
240 |
+
selected_method = fallback
|
241 |
+
alternatives = []
|
242 |
+
justifications[selected_method] = justifications.get(selected_method, "Fallback after exclusions.")
|
243 |
+
else:
|
244 |
+
# Pick by smallest priority index, then stable by insertion
|
245 |
+
filtered.sort(key=lambda x: x[1])
|
246 |
+
selected_method = filtered[0][0]
|
247 |
+
alternatives = [m for (m, _) in filtered[1:] if m != selected_method]
|
248 |
+
|
249 |
+
logger.info(f"Selected method: {selected_method}; alternatives: {alternatives}")
|
250 |
+
|
251 |
+
return {
|
252 |
+
"selected_method": selected_method,
|
253 |
+
"method_justification": justifications[selected_method],
|
254 |
+
"method_assumptions": assumptions[selected_method],
|
255 |
+
"alternatives": alternatives,
|
256 |
+
"excluded_methods": sorted(excluded_methods),
|
257 |
+
}
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None):
|
262 |
+
"""
|
263 |
+
Wrapped function to select causal method based on dataset properties and query
|
264 |
+
|
265 |
+
Args:
|
266 |
+
dataset_analysis (Dict): results of dataset analysis
|
267 |
+
variables (Dict): dictionary of variable names and types
|
268 |
+
is_rct (bool): whether the dataset is from a randomized controlled trial
|
269 |
+
llm (BaseChatModel): language model instance for generating prompts
|
270 |
+
dataset_description (str): description of the dataset
|
271 |
+
original_query (str): the original user query
|
272 |
+
excluded_methods (List[str], optional): list of methods to exclude from selection
|
273 |
+
"""
|
274 |
+
|
275 |
+
logger.info("Running rule-based method selection")
|
276 |
+
|
277 |
+
|
278 |
+
properties = {"treatment_variable": variables.get("treatment_variable"), "instrument_variable":variables.get("instrument_variable"),
|
279 |
+
"covariates": variables.get("covariates", []), "outcome_variable": variables.get("outcome_variable"),
|
280 |
+
"time_variable": variables.get("time_variable"), "running_variable": variables.get("running_variable"),
|
281 |
+
"treatment_variable_type": variables.get("treatment_variable_type", "binary"),
|
282 |
+
"has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
|
283 |
+
"frontdoor_criterion": variables.get("frontdoor_criterion", False),
|
284 |
+
"cutoff_value": variables.get("cutoff_value"),
|
285 |
+
"covariate_overlap_score": variables.get("covariate_overlap_result", 0)}
|
286 |
+
|
287 |
+
properties["is_rct"] = is_rct
|
288 |
+
logger.info(f"Dataset properties for method selection: {properties}")
|
289 |
+
|
290 |
+
return select_method(properties, excluded_methods)
|
291 |
+
|
292 |
+
|
293 |
+
|
294 |
+
class DecisionTreeEngine:
|
295 |
+
"""
|
296 |
+
Engine for applying decision trees to select appropriate causal methods.
|
297 |
+
|
298 |
+
This class wraps the functional decision tree implementation to provide
|
299 |
+
an object-oriented interface for method selection.
|
300 |
+
"""
|
301 |
+
|
302 |
+
def __init__(self, verbose=False):
|
303 |
+
self.verbose = verbose
|
304 |
+
|
305 |
+
def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str],
|
306 |
+
dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]:
|
307 |
+
"""
|
308 |
+
Apply decision tree to select appropriate causal method.
|
309 |
+
"""
|
310 |
+
|
311 |
+
if self.verbose:
|
312 |
+
print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}")
|
313 |
+
print(f"Available covariates: {covariates}")
|
314 |
+
|
315 |
+
treatment_variable_type = query_details.get("treatment_variable_type")
|
316 |
+
covariate_overlap_result = query_details.get("covariate_overlap_result")
|
317 |
+
info = {"treatment_variable": treatment, "outcome_variable": outcome,
|
318 |
+
"covariates": covariates, "time_variable": query_details.get("time_variable"),
|
319 |
+
"group_variable": query_details.get("group_variable"),
|
320 |
+
"instrument_variable": query_details.get("instrument_variable"),
|
321 |
+
"running_variable": query_details.get("running_variable"),
|
322 |
+
"cutoff_value": query_details.get("cutoff_value"),
|
323 |
+
"is_rct": query_details.get("is_rct", False),
|
324 |
+
"has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
|
325 |
+
"frontdoor_criterion": query_details.get("frontdoor_criterion", False),
|
326 |
+
"covariate_overlap_score": covariate_overlap_result,
|
327 |
+
"treatment_variable_type": treatment_variable_type}
|
328 |
+
|
329 |
+
result = select_method(info)
|
330 |
+
|
331 |
+
if self.verbose:
|
332 |
+
print(f"Selected method: {result['selected_method']}")
|
333 |
+
print(f"Justification: {result['method_justification']}")
|
334 |
+
|
335 |
+
result["decision_path"] = self._get_decision_path(result["selected_method"])
|
336 |
+
return result
|
337 |
+
|
338 |
+
|
339 |
+
def _get_decision_path(self, method):
|
340 |
+
if method == "linear_regression":
|
341 |
+
return ["Check if randomized experiment", "Data appears to be from a randomized experiment with covariates"]
|
342 |
+
elif method == "propensity_score_matching":
|
343 |
+
return ["Check if randomized experiment", "Data is observational",
|
344 |
+
"Check for sufficient covariate overlap", "Sufficient overlap exists"]
|
345 |
+
elif method == "propensity_score_weighting":
|
346 |
+
return ["Check if randomized experiment", "Data is observational",
|
347 |
+
"Check for sufficient covariate overlap", "Low overlap—weighting preferred"]
|
348 |
+
elif method == "backdoor_adjustment":
|
349 |
+
return ["Check if randomized experiment", "Data is observational",
|
350 |
+
"Check for sufficient covariate overlap", "Adjusting for covariates"]
|
351 |
+
elif method == "instrumental_variable":
|
352 |
+
return ["Check if randomized experiment", "Data is observational",
|
353 |
+
"Check for instrumental variables", "Instrument is available"]
|
354 |
+
elif method == "regression_discontinuity_design":
|
355 |
+
return ["Check if randomized experiment", "Data is observational",
|
356 |
+
"Check for discontinuity", "Discontinuity exists"]
|
357 |
+
elif method == "difference_in_differences":
|
358 |
+
return ["Check if randomized experiment", "Data is observational",
|
359 |
+
"Check for temporal structure", "Panel data structure exists"]
|
360 |
+
elif method == "frontdoor_adjustment":
|
361 |
+
return ["Check if randomized experiment", "Data is observational",
|
362 |
+
"Check front-door criterion", "Front-door path identified"]
|
363 |
+
elif method == "diff_in_means":
|
364 |
+
return ["Check if randomized experiment", "Pure RCT without covariates"]
|
365 |
+
else:
|
366 |
+
return ["Default method selection"]
|
auto_causal/components/decision_tree_llm.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM-based Decision tree component for selecting causal inference methods.
|
3 |
+
|
4 |
+
This module implements the decision tree logic via an LLM prompt
|
5 |
+
to select the most appropriate causal inference method based on
|
6 |
+
dataset characteristics and available variables.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import json
|
11 |
+
from typing import Dict, Any, Optional, List
|
12 |
+
|
13 |
+
from langchain_core.messages import HumanMessage
|
14 |
+
from langchain_core.language_models import BaseChatModel
|
15 |
+
|
16 |
+
# Import constants and assumptions from the original decision_tree module
|
17 |
+
from .decision_tree import (
|
18 |
+
METHOD_ASSUMPTIONS,
|
19 |
+
BACKDOOR_ADJUSTMENT,
|
20 |
+
LINEAR_REGRESSION,
|
21 |
+
DIFF_IN_MEANS,
|
22 |
+
DIFF_IN_DIFF,
|
23 |
+
REGRESSION_DISCONTINUITY,
|
24 |
+
PROPENSITY_SCORE_MATCHING,
|
25 |
+
INSTRUMENTAL_VARIABLE,
|
26 |
+
CORRELATION_ANALYSIS,
|
27 |
+
PROPENSITY_SCORE_WEIGHTING,
|
28 |
+
GENERALIZED_PROPENSITY_SCORE
|
29 |
+
)
|
30 |
+
|
31 |
+
# Configure logging
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
# Define a list of all known methods for the LLM prompt
|
35 |
+
ALL_METHODS = [
|
36 |
+
DIFF_IN_MEANS,
|
37 |
+
LINEAR_REGRESSION,
|
38 |
+
DIFF_IN_DIFF,
|
39 |
+
REGRESSION_DISCONTINUITY,
|
40 |
+
INSTRUMENTAL_VARIABLE,
|
41 |
+
PROPENSITY_SCORE_MATCHING,
|
42 |
+
PROPENSITY_SCORE_WEIGHTING,
|
43 |
+
GENERALIZED_PROPENSITY_SCORE,
|
44 |
+
BACKDOOR_ADJUSTMENT, # Often a general approach rather than a specific model.
|
45 |
+
CORRELATION_ANALYSIS,
|
46 |
+
]
|
47 |
+
|
48 |
+
METHOD_DESCRIPTIONS_FOR_LLM = {
|
49 |
+
DIFF_IN_MEANS: "Appropriate for Randomized Controlled Trials (RCTs) with no covariates. Compares the average outcome between treated and control groups.",
|
50 |
+
LINEAR_REGRESSION: "Can be used for RCTs with covariates to increase precision, or for observational data assuming linear relationships and no unmeasured confounders. Models the outcome as a linear function of treatment and covariates.",
|
51 |
+
DIFF_IN_DIFF: "Suitable for observational data with a temporal structure (e.g., panel data with pre/post treatment periods). Requires the 'parallel trends' assumption: treatment and control groups would have followed similar trends in the outcome in the absence of treatment.",
|
52 |
+
REGRESSION_DISCONTINUITY: "Applicable when treatment assignment is determined by whether an observed 'running variable' crosses a specific cutoff point. Assumes individuals cannot precisely manipulate the running variable.",
|
53 |
+
INSTRUMENTAL_VARIABLE: "Used when there's an 'instrument' variable that is correlated with the treatment, affects the outcome only through the treatment, and is not confounded with the outcome. Useful for handling unobserved confounding.",
|
54 |
+
PROPENSITY_SCORE_MATCHING: "For observational data with covariates. Estimates the probability of receiving treatment (propensity score) for each unit and then matches treated and control units with similar scores. Aims to create balanced groups.",
|
55 |
+
PROPENSITY_SCORE_WEIGHTING: "Similar to PSM, for observational data with covariates. Uses propensity scores to weight units to create a pseudo-population where confounders are balanced. Can estimate ATE, ATT, or ATC.",
|
56 |
+
GENERALIZED_PROPENSITY_SCORE: "An extension of propensity scores for continuous treatment variables. Aims to estimate the dose-response function, assuming unconfoundedness given covariates.",
|
57 |
+
BACKDOOR_ADJUSTMENT: "A general strategy for causal inference in observational studies that involves statistically controlling for all common causes (confounders) of the treatment and outcome. Specific methods like regression or matching implement this.",
|
58 |
+
CORRELATION_ANALYSIS: "A fallback method when causal inference is not feasible due to data limitations (e.g., no clear design, no covariates for adjustment). Measures the statistical association between variables, but does not imply causation."
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
class DecisionTreeLLMEngine:
|
63 |
+
"""
|
64 |
+
Engine for applying an LLM-based decision tree to select appropriate causal methods.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, verbose: bool = False):
|
68 |
+
"""
|
69 |
+
Initialize the LLM decision tree engine.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
verbose: Whether to print verbose information.
|
73 |
+
"""
|
74 |
+
self.verbose = verbose
|
75 |
+
|
76 |
+
def _construct_prompt(self, dataset_analysis: Dict[str, Any], variables: Dict[str, Any], is_rct: bool, excluded_methods: Optional[List[str]] = None) -> str:
|
77 |
+
"""
|
78 |
+
Constructs the detailed prompt for the LLM.
|
79 |
+
"""
|
80 |
+
# Filter out excluded methods
|
81 |
+
excluded_methods = excluded_methods or []
|
82 |
+
available_methods = [method for method in ALL_METHODS if method not in excluded_methods]
|
83 |
+
methods_list_str = "\n".join([f"- {method}: {METHOD_DESCRIPTIONS_FOR_LLM[method]}" for method in available_methods if method in METHOD_DESCRIPTIONS_FOR_LLM])
|
84 |
+
|
85 |
+
excluded_info = ""
|
86 |
+
if excluded_methods:
|
87 |
+
excluded_info = f"\nEXCLUDED METHODS (do not select these): {', '.join(excluded_methods)}\nReason: These methods failed validation in previous attempts.\n"
|
88 |
+
|
89 |
+
prompt = f"""You are an expert in causal inference. Your task is to select the most appropriate causal inference method based on the provided dataset analysis and variable information.
|
90 |
+
|
91 |
+
Dataset Analysis:
|
92 |
+
{json.dumps(dataset_analysis, indent=2)}
|
93 |
+
|
94 |
+
Identified Variables:
|
95 |
+
{json.dumps(variables, indent=2)}
|
96 |
+
|
97 |
+
Is the data from a Randomized Controlled Trial (RCT)? {'Yes' if is_rct else 'No'}{excluded_info}
|
98 |
+
|
99 |
+
Available Causal Inference Methods and their descriptions:
|
100 |
+
{methods_list_str}
|
101 |
+
|
102 |
+
Instructions:
|
103 |
+
1. Carefully review all the provided information: dataset analysis, variables, and RCT status.
|
104 |
+
2. Reason step-by-step to determine the most suitable method. Consider the hierarchy of methods (e.g., specific designs like DiD, RDD, IV before general adjustment methods).
|
105 |
+
3. Explain your reasoning for selecting a particular method.
|
106 |
+
4. Identify any potential alternative methods if applicable.
|
107 |
+
5. State the key assumptions for your *selected* method by referring to the general list of assumptions for all methods that will be provided to you separately (you don't need to list them here, just be aware that you need to select a method for which assumptions are known).
|
108 |
+
|
109 |
+
Output your final decision as a JSON object with the following exact keys:
|
110 |
+
- "selected_method": string (must be one of {', '.join(available_methods)})
|
111 |
+
- "method_justification": string (your detailed reasoning)
|
112 |
+
- "alternative_methods": list of strings (alternative method names, can be empty)
|
113 |
+
|
114 |
+
Example JSON output format:
|
115 |
+
{{
|
116 |
+
"selected_method": "difference_in_differences",
|
117 |
+
"method_justification": "The dataset has a clear time variable and group variable, indicating a panel structure suitable for DiD. The parallel trends assumption will need to be checked.",
|
118 |
+
"alternative_methods": ["instrumental_variable"]
|
119 |
+
}}
|
120 |
+
|
121 |
+
Please provide only the JSON object in your response.
|
122 |
+
"""
|
123 |
+
return prompt
|
124 |
+
|
125 |
+
def select_method_llm(self, dataset_analysis: Dict[str, Any], variables: Dict[str, Any], is_rct: bool = False, llm: Optional[BaseChatModel] = None, excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
|
126 |
+
"""
|
127 |
+
Apply LLM-based decision tree to select appropriate causal method.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
dataset_analysis: Dataset analysis results.
|
131 |
+
variables: Identified variables from query_interpreter.
|
132 |
+
is_rct: Boolean indicating if the data comes from an RCT.
|
133 |
+
llm: Langchain BaseChatModel instance for making the call.
|
134 |
+
excluded_methods: Optional list of method names to exclude from selection.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
Dict with selected method, justification, and assumptions.
|
138 |
+
Example:
|
139 |
+
{{
|
140 |
+
"selected_method": "difference_in_differences",
|
141 |
+
"method_justification": "Reasoning...",
|
142 |
+
"method_assumptions": ["Assumption 1", ...],
|
143 |
+
"alternative_methods": ["instrumental_variable"]
|
144 |
+
}}
|
145 |
+
"""
|
146 |
+
if not llm:
|
147 |
+
logger.error("LLM client not provided to DecisionTreeLLMEngine. Cannot select method.")
|
148 |
+
return {
|
149 |
+
"selected_method": CORRELATION_ANALYSIS,
|
150 |
+
"method_justification": "LLM client not provided. Defaulting to Correlation Analysis as causal inference method selection is not possible. This indicates association, not causation.",
|
151 |
+
"method_assumptions": METHOD_ASSUMPTIONS.get(CORRELATION_ANALYSIS, []),
|
152 |
+
"alternative_methods": []
|
153 |
+
}
|
154 |
+
|
155 |
+
prompt = self._construct_prompt(dataset_analysis, variables, is_rct, excluded_methods)
|
156 |
+
if self.verbose:
|
157 |
+
logger.info("LLM Prompt for method selection:")
|
158 |
+
logger.info(prompt)
|
159 |
+
|
160 |
+
messages = [HumanMessage(content=prompt)]
|
161 |
+
|
162 |
+
llm_output_str = "" # Initialize llm_output_str here
|
163 |
+
try:
|
164 |
+
response = llm.invoke(messages)
|
165 |
+
llm_output_str = response.content.strip()
|
166 |
+
|
167 |
+
if self.verbose:
|
168 |
+
logger.info(f"LLM Raw Output: {llm_output_str}")
|
169 |
+
|
170 |
+
# Attempt to parse the JSON output
|
171 |
+
# The LLM might sometimes include explanations outside the JSON block.
|
172 |
+
# Try to extract JSON from within ```json ... ``` if present.
|
173 |
+
if "```json" in llm_output_str:
|
174 |
+
json_str = llm_output_str.split("```json")[1].split("```")[0].strip()
|
175 |
+
elif "```" in llm_output_str and llm_output_str.startswith("{") == False : # if it doesn't start with { then likely ```{}```
|
176 |
+
json_str = llm_output_str.split("```")[1].strip()
|
177 |
+
else: # Assume the entire string is the JSON if no triple backticks
|
178 |
+
json_str = llm_output_str
|
179 |
+
|
180 |
+
parsed_response = json.loads(json_str)
|
181 |
+
|
182 |
+
selected_method = parsed_response.get("selected_method")
|
183 |
+
justification = parsed_response.get("method_justification", "No justification provided by LLM.")
|
184 |
+
alternatives = parsed_response.get("alternative_methods", [])
|
185 |
+
|
186 |
+
if selected_method and selected_method in METHOD_ASSUMPTIONS:
|
187 |
+
logger.info(f"LLM selected method: {selected_method}")
|
188 |
+
return {
|
189 |
+
"selected_method": selected_method,
|
190 |
+
"method_justification": justification,
|
191 |
+
"method_assumptions": METHOD_ASSUMPTIONS[selected_method],
|
192 |
+
"alternative_methods": alternatives
|
193 |
+
}
|
194 |
+
else:
|
195 |
+
logger.warning(f"LLM selected an invalid or unknown method: '{selected_method}'. Or method not in METHOD_ASSUMPTIONS. Raw response: {llm_output_str}")
|
196 |
+
fallback_justification = f"LLM output was problematic (selected: {selected_method}). Defaulting to Correlation Analysis. LLM Raw Response: {llm_output_str}"
|
197 |
+
selected_method = CORRELATION_ANALYSIS
|
198 |
+
justification = fallback_justification
|
199 |
+
|
200 |
+
except json.JSONDecodeError as e:
|
201 |
+
logger.error(f"Failed to parse JSON response from LLM: {e}. Raw response: {llm_output_str}", exc_info=True)
|
202 |
+
fallback_justification = f"LLM response was not valid JSON. Defaulting to Correlation Analysis. Error: {e}. LLM Raw Response: {llm_output_str}"
|
203 |
+
selected_method = CORRELATION_ANALYSIS
|
204 |
+
justification = fallback_justification
|
205 |
+
alternatives = []
|
206 |
+
except Exception as e:
|
207 |
+
logger.error(f"Error during LLM call for method selection: {e}. Raw response: {llm_output_str}", exc_info=True)
|
208 |
+
fallback_justification = f"An unexpected error occurred during LLM method selection. Defaulting to Correlation Analysis. Error: {e}. LLM Raw Response: {llm_output_str}"
|
209 |
+
selected_method = CORRELATION_ANALYSIS
|
210 |
+
justification = fallback_justification
|
211 |
+
alternatives = []
|
212 |
+
|
213 |
+
return {
|
214 |
+
"selected_method": selected_method,
|
215 |
+
"method_justification": justification,
|
216 |
+
"method_assumptions": METHOD_ASSUMPTIONS.get(selected_method, []),
|
217 |
+
"alternative_methods": alternatives
|
218 |
+
}
|
auto_causal/components/explanation_generator.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Explanation generator component for causal inference methods.
|
3 |
+
|
4 |
+
This module generates explanations for causal inference methods, including
|
5 |
+
what the method does, its assumptions, and how it will be applied to the dataset.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Dict, Any, List, Optional
|
9 |
+
from langchain_core.language_models import BaseChatModel # For LLM type hint
|
10 |
+
|
11 |
+
|
12 |
+
def generate_explanation(
|
13 |
+
method_info: Dict[str, Any],
|
14 |
+
validation_result: Dict[str, Any],
|
15 |
+
variables: Dict[str, Any],
|
16 |
+
results: Dict[str, Any],
|
17 |
+
dataset_analysis: Optional[Dict[str, Any]] = None,
|
18 |
+
dataset_description: Optional[str] = None,
|
19 |
+
llm: Optional[BaseChatModel] = None
|
20 |
+
) -> Dict[str, str]:
|
21 |
+
"""
|
22 |
+
Generates a comprehensive explanation text for the causal analysis.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
method_info: Dictionary containing selected method details.
|
26 |
+
validation_result: Dictionary containing method validation results.
|
27 |
+
variables: Dictionary containing identified variables.
|
28 |
+
results: Dictionary containing numerical results from the method execution.
|
29 |
+
dataset_analysis: Optional dictionary with dataset analysis details.
|
30 |
+
dataset_description: Optional string describing the dataset.
|
31 |
+
llm: Optional language model instance (for potential future use in generation).
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
Dictionary containing the final explanation text.
|
35 |
+
"""
|
36 |
+
method = method_info.get("method_name")
|
37 |
+
|
38 |
+
# Handle potential None for validation_result
|
39 |
+
if validation_result and validation_result.get("valid") is False:
|
40 |
+
method = validation_result.get("recommended_method", method)
|
41 |
+
|
42 |
+
# Get components
|
43 |
+
method_explanation = get_method_explanation(method)
|
44 |
+
assumption_explanations = explain_assumptions(method_info.get("assumptions", []))
|
45 |
+
application_explanation = explain_application(method, variables.get("treatment_variable"),
|
46 |
+
variables.get("outcome_variable"),
|
47 |
+
variables.get("covariates", []), variables)
|
48 |
+
limitations_explanation = explain_limitations(method, validation_result.get("concerns", []) if validation_result else [])
|
49 |
+
interpretation_guide = generate_interpretation_guide(method, variables.get("treatment_variable"),
|
50 |
+
variables.get("outcome_variable"))
|
51 |
+
|
52 |
+
# --- Extract Numerical Results ---
|
53 |
+
effect_estimate = results.get("effect_estimate")
|
54 |
+
effect_se = results.get("effect_se")
|
55 |
+
ci = results.get("confidence_interval")
|
56 |
+
p_value = results.get("p_value") # Assuming method executor returns p_value
|
57 |
+
|
58 |
+
# --- Assemble Final Text ---
|
59 |
+
final_text = f"**Method Used:** {method_info.get('method_name', method)}\n\n"
|
60 |
+
final_text += f"**Method Explanation:**\n{method_explanation}\n\n"
|
61 |
+
|
62 |
+
# Add Results Section
|
63 |
+
final_text += "**Results:**\n"
|
64 |
+
if effect_estimate is not None:
|
65 |
+
final_text += f"- Estimated Causal Effect: {effect_estimate:.4f}\n"
|
66 |
+
if effect_se is not None:
|
67 |
+
final_text += f"- Standard Error: {effect_se:.4f}\n"
|
68 |
+
if ci and ci[0] is not None and ci[1] is not None:
|
69 |
+
final_text += f"- 95% Confidence Interval: [{ci[0]:.4f}, {ci[1]:.4f}]\n"
|
70 |
+
if p_value is not None:
|
71 |
+
final_text += f"- P-value: {p_value:.4f}\n"
|
72 |
+
final_text += "\n"
|
73 |
+
|
74 |
+
final_text += f"**Interpretation Guide:**\n{interpretation_guide}\n\n"
|
75 |
+
final_text += f"**Assumptions:**\n"
|
76 |
+
for item in assumption_explanations:
|
77 |
+
final_text += f"- {item['assumption']}: {item['explanation']}\n"
|
78 |
+
final_text += "\n"
|
79 |
+
final_text += f"**Limitations:**\n{limitations_explanation}\n\n"
|
80 |
+
|
81 |
+
return {
|
82 |
+
"final_explanation_text": final_text
|
83 |
+
# Return only the final text, the tool wrapper adds workflow state
|
84 |
+
}
|
85 |
+
|
86 |
+
|
87 |
+
def get_method_explanation(method: str) -> str:
|
88 |
+
"""
|
89 |
+
Get explanation for what the method does.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
method: Causal inference method name
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
String explaining what the method does
|
96 |
+
"""
|
97 |
+
explanations = {
|
98 |
+
"propensity_score_matching": (
|
99 |
+
"Propensity Score Matching is a statistical technique that attempts to estimate the effect "
|
100 |
+
"of a treatment by accounting for covariates that predict receiving the treatment. "
|
101 |
+
"It creates matched sets of treated and untreated subjects who share similar characteristics, "
|
102 |
+
"allowing for a more fair comparison between groups."
|
103 |
+
),
|
104 |
+
"regression_adjustment": (
|
105 |
+
"Regression Adjustment is a method that uses regression models to estimate causal effects "
|
106 |
+
"by controlling for covariates. It models the outcome as a function of the treatment and "
|
107 |
+
"other potential confounding variables, allowing the isolation of the treatment effect."
|
108 |
+
),
|
109 |
+
"instrumental_variable": (
|
110 |
+
"The Instrumental Variable method addresses issues of endogeneity or unmeasured confounding "
|
111 |
+
"by using an 'instrument' - a variable that affects the treatment but not the outcome directly. "
|
112 |
+
"It effectively finds the natural experiment hidden in your data to estimate causal effects."
|
113 |
+
),
|
114 |
+
"difference_in_differences": (
|
115 |
+
"Difference-in-Differences compares the changes in outcomes over time between a group that "
|
116 |
+
"receives a treatment and a group that does not. It controls for time-invariant unobserved "
|
117 |
+
"confounders by looking at differences in trends rather than absolute values."
|
118 |
+
),
|
119 |
+
"regression_discontinuity": (
|
120 |
+
"Regression Discontinuity Design exploits a threshold or cutoff rule that determines treatment "
|
121 |
+
"assignment. By comparing observations just above and below this threshold, where treatment "
|
122 |
+
"status changes but other characteristics remain similar, it estimates the local causal effect."
|
123 |
+
),
|
124 |
+
"backdoor_adjustment": (
|
125 |
+
"Backdoor Adjustment controls for confounding variables that create 'backdoor paths' between "
|
126 |
+
"treatment and outcome variables in a causal graph. By conditioning on these variables, "
|
127 |
+
"it blocks the non-causal associations, allowing for identification of the causal effect."
|
128 |
+
),
|
129 |
+
}
|
130 |
+
|
131 |
+
return explanations.get(method,
|
132 |
+
f"The {method} method is a causal inference technique used to estimate "
|
133 |
+
f"causal effects from observational data.")
|
134 |
+
|
135 |
+
|
136 |
+
def explain_assumptions(assumptions: List[str]) -> List[Dict[str, str]]:
|
137 |
+
"""
|
138 |
+
Explain each assumption of the method.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
assumptions: List of assumption names
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
List of dictionaries with assumption name and explanation
|
145 |
+
"""
|
146 |
+
assumption_details = {
|
147 |
+
"Treatment is randomly assigned": (
|
148 |
+
"This assumes that treatment assignment is not influenced by any factors "
|
149 |
+
"related to the outcome, similar to a randomized controlled trial. "
|
150 |
+
"In observational data, this assumption rarely holds without conditioning on confounders."
|
151 |
+
),
|
152 |
+
"No systematic differences between treatment and control groups": (
|
153 |
+
"Treatment and control groups should be balanced on all relevant characteristics "
|
154 |
+
"except for the treatment itself. Any systematic differences could bias the estimate."
|
155 |
+
),
|
156 |
+
"No unmeasured confounders (conditional ignorability)": (
|
157 |
+
"All variables that simultaneously affect the treatment and outcome are measured and "
|
158 |
+
"included in the analysis. If important confounders are missing, the estimated causal "
|
159 |
+
"effect will be biased."
|
160 |
+
),
|
161 |
+
"Sufficient overlap between treatment and control groups": (
|
162 |
+
"For each combination of covariate values, there should be both treated and untreated "
|
163 |
+
"units. Without overlap, the model must extrapolate, which can lead to biased estimates."
|
164 |
+
),
|
165 |
+
"Treatment assignment is not deterministic given covariates": (
|
166 |
+
"No combination of covariates should perfectly predict treatment assignment. "
|
167 |
+
"If treatment is deterministic for some units, causal comparisons become impossible."
|
168 |
+
),
|
169 |
+
"Instrument is correlated with treatment (relevance)": (
|
170 |
+
"The instrumental variable must have a clear and preferably strong effect on the "
|
171 |
+
"treatment variable. Weak instruments lead to imprecise and potentially biased estimates."
|
172 |
+
),
|
173 |
+
"Instrument affects outcome only through treatment (exclusion restriction)": (
|
174 |
+
"The instrumental variable must not directly affect the outcome except through its "
|
175 |
+
"effect on the treatment. If this assumption fails, the causal estimate will be biased."
|
176 |
+
),
|
177 |
+
"Instrument is as good as randomly assigned (exogeneity)": (
|
178 |
+
"The instrumental variable must not be correlated with any confounders of the "
|
179 |
+
"treatment-outcome relationship. It should be as good as randomly assigned."
|
180 |
+
),
|
181 |
+
"Parallel trends between treatment and control groups": (
|
182 |
+
"In the absence of treatment, the difference between treatment and control groups "
|
183 |
+
"would have remained constant over time. This is the key identifying assumption for "
|
184 |
+
"difference-in-differences and cannot be directly tested for the post-treatment period."
|
185 |
+
),
|
186 |
+
"No spillover effects between groups": (
|
187 |
+
"The treatment of one unit should not affect the outcomes of other units. "
|
188 |
+
"If spillovers exist, they can bias the estimated treatment effect."
|
189 |
+
),
|
190 |
+
"No anticipation effects before treatment": (
|
191 |
+
"Units should not change their behavior in anticipation of future treatment. "
|
192 |
+
"If anticipation effects exist, the pre-treatment trends may already reflect treatment effects."
|
193 |
+
),
|
194 |
+
"Stable composition of treatment and control groups": (
|
195 |
+
"The composition of treatment and control groups should remain stable over time. "
|
196 |
+
"If units move between groups based on outcomes, this can bias the estimates."
|
197 |
+
),
|
198 |
+
"Units cannot precisely manipulate their position around the cutoff": (
|
199 |
+
"In regression discontinuity, units must not be able to precisely control their position "
|
200 |
+
"relative to the cutoff. If they can, the randomization-like property of the design fails."
|
201 |
+
),
|
202 |
+
"No other variables change discontinuously at the cutoff": (
|
203 |
+
"Any discontinuity in outcomes at the cutoff should be attributable only to the change "
|
204 |
+
"in treatment status. If other relevant variables also change at the cutoff, the causal "
|
205 |
+
"interpretation is compromised."
|
206 |
+
),
|
207 |
+
"The relationship between running variable and outcome is continuous at the cutoff": (
|
208 |
+
"In the absence of treatment, the relationship between the running variable and the "
|
209 |
+
"outcome would be continuous at the cutoff. This allows attributing any observed "
|
210 |
+
"discontinuity to the treatment effect."
|
211 |
+
),
|
212 |
+
"The model correctly specifies the relationship between variables": (
|
213 |
+
"The functional form of the relationship between variables in the model should correctly "
|
214 |
+
"capture the true relationship in the data. Misspecification can lead to biased estimates."
|
215 |
+
),
|
216 |
+
"No reverse causality": (
|
217 |
+
"The treatment must cause the outcome, not the other way around. If the outcome affects "
|
218 |
+
"the treatment, the estimated relationship will not have a causal interpretation."
|
219 |
+
),
|
220 |
+
}
|
221 |
+
|
222 |
+
return [
|
223 |
+
{"assumption": assumption, "explanation": assumption_details.get(assumption,
|
224 |
+
"This is a key assumption for the selected causal inference method.")}
|
225 |
+
for assumption in assumptions
|
226 |
+
]
|
227 |
+
|
228 |
+
|
229 |
+
def explain_application(method: str, treatment: str, outcome: str,
|
230 |
+
covariates: List[str], variables: Dict[str, Any]) -> str:
|
231 |
+
"""
|
232 |
+
Explain how the method will be applied to the dataset.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
method: Causal inference method name
|
236 |
+
treatment: Treatment variable name
|
237 |
+
outcome: Outcome variable name
|
238 |
+
covariates: List of covariate names
|
239 |
+
variables: Dictionary of identified variables
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
String explaining the application
|
243 |
+
"""
|
244 |
+
covariate_str = ", ".join(covariates[:3])
|
245 |
+
if len(covariates) > 3:
|
246 |
+
covariate_str += f", and {len(covariates) - 3} other variables"
|
247 |
+
|
248 |
+
applications = {
|
249 |
+
"propensity_score_matching": (
|
250 |
+
f"I will estimate the propensity scores (probability of receiving treatment) for each "
|
251 |
+
f"observation based on the covariates ({covariate_str}). Then, I'll match treated and "
|
252 |
+
f"untreated units with similar propensity scores to create balanced comparison groups. "
|
253 |
+
f"Finally, I'll calculate the difference in {outcome} between these matched groups to "
|
254 |
+
f"estimate the causal effect of {treatment}."
|
255 |
+
),
|
256 |
+
"regression_adjustment": (
|
257 |
+
f"I will build a regression model with {outcome} as the dependent variable and "
|
258 |
+
f"{treatment} as the independent variable of interest, while controlling for "
|
259 |
+
f"potential confounders ({covariate_str}). The coefficient of {treatment} will "
|
260 |
+
f"represent the estimated causal effect after adjusting for these covariates."
|
261 |
+
),
|
262 |
+
"instrumental_variable": (
|
263 |
+
f"I will use {variables.get('instrument_variable')} as an instrumental variable for "
|
264 |
+
f"{treatment}. First, I'll estimate how the instrument affects {treatment} (first stage). "
|
265 |
+
f"Then, I'll use these predictions to estimate how changes in {treatment} that are induced "
|
266 |
+
f"by the instrument affect {outcome} (second stage). This two-stage approach helps "
|
267 |
+
f"address potential unmeasured confounding."
|
268 |
+
),
|
269 |
+
"difference_in_differences": (
|
270 |
+
f"I will compare the change in {outcome} before and after the intervention for the "
|
271 |
+
f"group receiving {treatment}, relative to the change in a control group that didn't "
|
272 |
+
f"receive the treatment. This approach controls for time-invariant confounders and "
|
273 |
+
f"common time trends that affect both groups."
|
274 |
+
),
|
275 |
+
"regression_discontinuity": (
|
276 |
+
f"I will focus on observations close to the cutoff value "
|
277 |
+
f"({variables.get('cutoff_value')}) of the running variable "
|
278 |
+
f"({variables.get('running_variable')}), where treatment assignment changes. "
|
279 |
+
f"By comparing outcomes just above and below this threshold, I can estimate "
|
280 |
+
f"the local causal effect of {treatment} on {outcome}."
|
281 |
+
),
|
282 |
+
"backdoor_adjustment": (
|
283 |
+
f"I will control for the identified confounding variables ({covariate_str}) to "
|
284 |
+
f"block all backdoor paths between {treatment} and {outcome}. This may involve "
|
285 |
+
f"stratification, regression adjustment, or inverse probability weighting, depending "
|
286 |
+
f"on the data characteristics."
|
287 |
+
),
|
288 |
+
}
|
289 |
+
|
290 |
+
return applications.get(method,
|
291 |
+
f"I will apply the {method} method to estimate the causal effect of "
|
292 |
+
f"{treatment} on {outcome}, controlling for relevant confounding factors "
|
293 |
+
f"where appropriate.")
|
294 |
+
|
295 |
+
|
296 |
+
def explain_limitations(method: str, concerns: List[str]) -> str:
|
297 |
+
"""
|
298 |
+
Explain the limitations of the method based on validation concerns.
|
299 |
+
|
300 |
+
Args:
|
301 |
+
method: Causal inference method name
|
302 |
+
concerns: List of concerns from validation
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
String explaining the limitations
|
306 |
+
"""
|
307 |
+
method_limitations = {
|
308 |
+
"propensity_score_matching": (
|
309 |
+
"Propensity Score Matching can only account for observed confounders, and its "
|
310 |
+
"effectiveness depends on having good overlap between treatment and control groups. "
|
311 |
+
"It may also be sensitive to model specification for the propensity score estimation."
|
312 |
+
),
|
313 |
+
"regression_adjustment": (
|
314 |
+
"Regression Adjustment relies heavily on correct model specification and can only "
|
315 |
+
"control for observed confounders. Extrapolation to regions with limited data can lead "
|
316 |
+
"to unreliable estimates, and the method may be sensitive to outliers."
|
317 |
+
),
|
318 |
+
"instrumental_variable": (
|
319 |
+
"Instrumental Variable estimation can be imprecise with weak instruments and is "
|
320 |
+
"sensitive to violations of the exclusion restriction. The estimated effect is a local "
|
321 |
+
"average treatment effect for 'compliers', which may not generalize to the entire population."
|
322 |
+
),
|
323 |
+
"difference_in_differences": (
|
324 |
+
"Difference-in-Differences relies on the parallel trends assumption, which cannot be fully "
|
325 |
+
"tested for the post-treatment period. It may be sensitive to the choice of comparison group "
|
326 |
+
"and can be biased if there are time-varying confounders or anticipation effects."
|
327 |
+
),
|
328 |
+
"regression_discontinuity": (
|
329 |
+
"Regression Discontinuity provides estimates that are local to the cutoff point and may not "
|
330 |
+
"generalize to units far from this threshold. It also requires sufficient data around the "
|
331 |
+
"cutoff and is sensitive to the choice of bandwidth and functional form."
|
332 |
+
),
|
333 |
+
"backdoor_adjustment": (
|
334 |
+
"Backdoor Adjustment requires correctly identifying all confounding variables and their "
|
335 |
+
"relationships. It depends on the assumption of no unmeasured confounders and may be "
|
336 |
+
"sensitive to model misspecification in complex settings."
|
337 |
+
),
|
338 |
+
}
|
339 |
+
|
340 |
+
base_limitation = method_limitations.get(method,
|
341 |
+
f"The {method} method has general limitations in terms of its assumptions and applicability.")
|
342 |
+
|
343 |
+
# Add specific concerns if any
|
344 |
+
if concerns:
|
345 |
+
concern_text = " Additionally, specific concerns for this analysis include: " + \
|
346 |
+
"; ".join(concerns) + "."
|
347 |
+
return base_limitation + concern_text
|
348 |
+
|
349 |
+
return base_limitation
|
350 |
+
|
351 |
+
|
352 |
+
def generate_interpretation_guide(method: str, treatment: str, outcome: str) -> str:
|
353 |
+
"""
|
354 |
+
Generate guide for interpreting the results.
|
355 |
+
|
356 |
+
Args:
|
357 |
+
method: Causal inference method name
|
358 |
+
treatment: Treatment variable name
|
359 |
+
outcome: Outcome variable name
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
String with interpretation guide
|
363 |
+
"""
|
364 |
+
interpretation_guides = {
|
365 |
+
"propensity_score_matching": (
|
366 |
+
f"The estimated effect represents the Average Treatment Effect (ATE) or the Average "
|
367 |
+
f"Treatment Effect on the Treated (ATT), depending on the specific matching approach. "
|
368 |
+
f"It can be interpreted as the expected change in {outcome} if a unit were to receive "
|
369 |
+
f"{treatment}, compared to not receiving it, for units with similar covariate values."
|
370 |
+
),
|
371 |
+
"regression_adjustment": (
|
372 |
+
f"The coefficient of {treatment} in the regression model represents the estimated "
|
373 |
+
f"average causal effect on {outcome}, holding all included covariates constant. "
|
374 |
+
f"For binary treatments, it's the expected difference in outcomes between treated "
|
375 |
+
f"and untreated units with the same covariate values."
|
376 |
+
),
|
377 |
+
"instrumental_variable": (
|
378 |
+
f"The estimated effect represents the Local Average Treatment Effect (LATE) for 'compliers' "
|
379 |
+
f"- units whose treatment status is influenced by the instrument. It can be interpreted as "
|
380 |
+
f"the average effect of {treatment} on {outcome} for this specific subpopulation."
|
381 |
+
),
|
382 |
+
"difference_in_differences": (
|
383 |
+
f"The estimated effect represents the average causal impact of {treatment} on {outcome}, "
|
384 |
+
f"under the assumption that treatment and control groups would have followed parallel "
|
385 |
+
f"trends in the absence of treatment. It accounts for both time-invariant differences "
|
386 |
+
f"between groups and common time trends."
|
387 |
+
),
|
388 |
+
"regression_discontinuity": (
|
389 |
+
f"The estimated effect represents the local causal impact of {treatment} on {outcome} "
|
390 |
+
f"at the cutoff point. It can be interpreted as the expected difference in outcomes "
|
391 |
+
f"for units just above versus just below the threshold, where treatment status changes."
|
392 |
+
),
|
393 |
+
"backdoor_adjustment": (
|
394 |
+
f"The estimated effect represents the average causal effect of {treatment} on {outcome} "
|
395 |
+
f"after controlling for all identified confounding variables. It can be interpreted as "
|
396 |
+
f"the expected difference in outcomes if a unit were to receive versus not receive the "
|
397 |
+
f"treatment, holding all confounding factors constant."
|
398 |
+
),
|
399 |
+
}
|
400 |
+
|
401 |
+
return interpretation_guides.get(method,
|
402 |
+
f"The estimated effect represents the causal impact of {treatment} on {outcome}, "
|
403 |
+
f"given the assumptions of the method are met. Careful consideration of these "
|
404 |
+
f"assumptions is needed for valid causal interpretation.")
|
auto_causal/components/input_parser.py
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Input parser component for extracting information from causal queries.
|
3 |
+
|
4 |
+
This module provides functionality to parse user queries and extract key
|
5 |
+
elements such as the causal question, relevant variables, and constraints.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import re
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
import logging # Added for better logging
|
12 |
+
from typing import Dict, List, Any, Optional, Union
|
13 |
+
import pandas as pd
|
14 |
+
from pydantic import BaseModel, Field, ValidationError
|
15 |
+
from functools import partial # Import partial
|
16 |
+
|
17 |
+
# Add dotenv import
|
18 |
+
from dotenv import load_dotenv
|
19 |
+
|
20 |
+
# LangChain Imports
|
21 |
+
from langchain_openai import ChatOpenAI # Example, replace if using another provider
|
22 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
23 |
+
from langchain_core.exceptions import OutputParserException # Correct path
|
24 |
+
from langchain_core.language_models import BaseChatModel # Import BaseChatModel
|
25 |
+
|
26 |
+
# --- Load .env file ---
|
27 |
+
load_dotenv() # Load environment variables from .env file
|
28 |
+
|
29 |
+
# --- Configure Logging ---
|
30 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
# --- Instantiate LLM Client ---
|
34 |
+
# Ensure OPENAI_API_KEY environment variable is set
|
35 |
+
# Consider making model name configurable
|
36 |
+
try:
|
37 |
+
# Using with_structured_output later, so instantiate base model here
|
38 |
+
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
39 |
+
# Add a check or allow configuration for different providers if needed
|
40 |
+
except ImportError:
|
41 |
+
logger.error("langchain_openai not installed. Please install it to use OpenAI models.")
|
42 |
+
llm = None
|
43 |
+
except Exception as e:
|
44 |
+
logger.error(f"Error initializing LLM: {e}. Input parsing will rely on fallbacks.")
|
45 |
+
llm = None
|
46 |
+
|
47 |
+
# --- Pydantic Models for Structured Output ---
|
48 |
+
class ParsedVariables(BaseModel):
|
49 |
+
treatment: List[str] = Field(default_factory=list, description="Variable(s) representing the treatment/intervention.")
|
50 |
+
outcome: List[str] = Field(default_factory=list, description="Variable(s) representing the outcome/result.")
|
51 |
+
covariates_mentioned: Optional[List[str]] = Field(default_factory=list, description="Covariate/control variable(s) explicitly mentioned in the query.")
|
52 |
+
grouping_vars: Optional[List[str]] = Field(default_factory=list, description="Variable(s) identifying groups or units for analysis.")
|
53 |
+
instruments_mentioned: Optional[List[str]] = Field(default_factory=list, description="Potential instrumental variable(s) mentioned.")
|
54 |
+
|
55 |
+
class ParsedQueryInfo(BaseModel):
|
56 |
+
query_type: str = Field(..., description="Type of query (e.g., EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER). Required.")
|
57 |
+
variables: ParsedVariables = Field(..., description="Variables identified in the query.")
|
58 |
+
constraints: Optional[List[str]] = Field(default_factory=list, description="Constraints or conditions mentioned (e.g., 'X > 10', 'country = USA').")
|
59 |
+
dataset_path_mentioned: Optional[str] = Field(None, description="Dataset path explicitly mentioned in the query, if any.")
|
60 |
+
|
61 |
+
# Add Pydantic model for path extraction
|
62 |
+
class ExtractedPath(BaseModel):
|
63 |
+
dataset_path: Optional[str] = Field(None, description="File path or URL for the dataset mentioned in the query.")
|
64 |
+
|
65 |
+
# --- End Pydantic Models ---
|
66 |
+
|
67 |
+
def _build_llm_prompt(query: str, dataset_info: Optional[Dict] = None) -> str:
|
68 |
+
"""Builds the prompt for the LLM to extract query information."""
|
69 |
+
dataset_context = "No dataset context provided."
|
70 |
+
if dataset_info:
|
71 |
+
columns = dataset_info.get('columns', [])
|
72 |
+
column_details = "\n".join([f"- {col} (Type: {dataset_info.get('column_types', {}).get(col, 'Unknown')})" for col in columns])
|
73 |
+
sample_rows = dataset_info.get('sample_rows', 'Not available')
|
74 |
+
# Ensure sample rows are formatted reasonably
|
75 |
+
if isinstance(sample_rows, list):
|
76 |
+
sample_rows_str = json.dumps(sample_rows[:3], indent=2) # Show first 3 sample rows
|
77 |
+
elif isinstance(sample_rows, str):
|
78 |
+
sample_rows_str = sample_rows
|
79 |
+
else:
|
80 |
+
sample_rows_str = 'Not available'
|
81 |
+
|
82 |
+
dataset_context = f"""
|
83 |
+
Dataset Context:
|
84 |
+
Columns:
|
85 |
+
{column_details}
|
86 |
+
Sample Rows (first few):
|
87 |
+
{sample_rows_str}
|
88 |
+
"""
|
89 |
+
|
90 |
+
prompt = f"""
|
91 |
+
Analyze the following causal query **strictly in the context of the provided dataset information (if available)**. Identify the query type, key variables (mapping query terms to actual column names when possible), constraints, and any explicitly mentioned dataset path.
|
92 |
+
|
93 |
+
User Query: "{query}"
|
94 |
+
|
95 |
+
{dataset_context}
|
96 |
+
|
97 |
+
# Add specific guidance for query types
|
98 |
+
Guidance for Identifying Query Type:
|
99 |
+
- EFFECT_ESTIMATION: Look for keywords like 'effect', 'impact', 'influence', 'cause', 'affect', 'consequence'. Also consider questions asking "how does X affect Y?" or comparing outcomes between groups based on an intervention.
|
100 |
+
- COUNTERFACTUAL: Look for hypothetical scenarios, often using phrases like 'what if', 'if X had been', 'would Y have changed', 'imagine if', 'counterfactual'.
|
101 |
+
- CORRELATION: Look for keywords like 'correlation', 'association', 'relationship', 'linked to', 'related to'. These queries ask about statistical relationships without necessarily implying causality.
|
102 |
+
- DESCRIPTIVE: These queries ask for summaries, descriptions, trends, or statistics about the data without investigating causal links or relationships (e.g., "Show sales over time", "What is the average age?").
|
103 |
+
- OTHER: Use this if the query does not fit any of the above categories.
|
104 |
+
|
105 |
+
Choose the most appropriate type from: EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER.
|
106 |
+
|
107 |
+
Variable Roles to Identify:
|
108 |
+
- treatment: The intervention or variable whose effect is being studied.
|
109 |
+
- outcome: The result or variable being measured.
|
110 |
+
- covariates_mentioned: Variables explicitly mentioned to control for or adjust for.
|
111 |
+
- grouping_vars: Variables identifying specific subgroups for analysis (e.g., 'for men', 'in the sales department').
|
112 |
+
- instruments_mentioned: Variables explicitly mentioned as potential instruments.
|
113 |
+
|
114 |
+
Constraints: Conditions applied to the analysis (e.g., filters on columns, specific time periods).
|
115 |
+
|
116 |
+
Dataset Path Mentioned: Extract the file path or URL if explicitly stated in the query.
|
117 |
+
|
118 |
+
**Output ONLY a valid JSON object** matching this exact schema (no explanations, notes, or surrounding text):
|
119 |
+
```json
|
120 |
+
{{
|
121 |
+
"query_type": "<Identified Query Type>",
|
122 |
+
"variables": {{
|
123 |
+
"treatment": ["<Treatment Variable(s) Mentioned>"],
|
124 |
+
"outcome": ["<Outcome Variable(s) Mentioned>"],
|
125 |
+
"covariates_mentioned": ["<Covariate(s) Mentioned>"],
|
126 |
+
"grouping_vars": ["<Grouping Variable(s) Mentioned>"],
|
127 |
+
"instruments_mentioned": ["<Instrument(s) Mentioned>"]
|
128 |
+
}},
|
129 |
+
"constraints": ["<Constraint 1>", "<Constraint 2>"],
|
130 |
+
"dataset_path_mentioned": "<Path Mentioned or null>"
|
131 |
+
}}
|
132 |
+
```
|
133 |
+
If Dataset Context is provided, ensure variable names in the output JSON correspond to actual column names where possible. If no context is provided, or if a mentioned variable doesn't map directly, use the phrasing from the query.
|
134 |
+
Respond with only the JSON object.
|
135 |
+
"""
|
136 |
+
return prompt
|
137 |
+
|
138 |
+
def _validate_llm_output(parsed_info: ParsedQueryInfo, dataset_info: Optional[Dict] = None) -> bool:
|
139 |
+
"""Perform basic assertions on the parsed LLM output."""
|
140 |
+
# 1. Check required fields exist (Pydantic handles this on parsing)
|
141 |
+
# 2. Check query type is one of the allowed types (can add enum to Pydantic later)
|
142 |
+
allowed_types = {"EFFECT_ESTIMATION", "COUNTERFACTUAL", "CORRELATION", "DESCRIPTIVE", "OTHER"}
|
143 |
+
print(parsed_info)
|
144 |
+
assert parsed_info.query_type in allowed_types, f"Invalid query_type: {parsed_info.query_type}"
|
145 |
+
|
146 |
+
# 3. Check that if it's an effect query, treatment and outcome are likely present
|
147 |
+
if parsed_info.query_type == "EFFECT_ESTIMATION":
|
148 |
+
# Check that the lists are not empty
|
149 |
+
assert parsed_info.variables.treatment, "Treatment variable list is empty for effect query."
|
150 |
+
assert parsed_info.variables.outcome, "Outcome variable list is empty for effect query."
|
151 |
+
|
152 |
+
# 4. If dataset_info provided, check if extracted variables exist in columns
|
153 |
+
if dataset_info and (columns := dataset_info.get('columns')):
|
154 |
+
all_extracted_vars = set()
|
155 |
+
for var_list in parsed_info.variables.model_dump().values(): # Iterate through variable lists
|
156 |
+
if var_list: # Ensure var_list is not None or empty
|
157 |
+
all_extracted_vars.update(var_list)
|
158 |
+
|
159 |
+
unknown_vars = all_extracted_vars - set(columns)
|
160 |
+
# Allow for non-column variables if context is missing? Maybe relax this.
|
161 |
+
# For now, strict check if columns are provided.
|
162 |
+
if unknown_vars:
|
163 |
+
logger.warning(f"LLM mentioned variables potentially not in dataset columns: {unknown_vars}")
|
164 |
+
# Decide if this should be a hard failure (AssertionError) or just a warning.
|
165 |
+
# Let's make it a hard failure for now to enforce mapping.
|
166 |
+
raise AssertionError(f"LLM hallucinated variables not in dataset columns: {unknown_vars}")
|
167 |
+
|
168 |
+
logger.info("LLM output validation passed.")
|
169 |
+
return True
|
170 |
+
|
171 |
+
def _extract_query_information_with_llm(query: str, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None, max_retries: int = 3) -> Optional[ParsedQueryInfo]:
|
172 |
+
"""Extracts query type, variables, and constraints using LLM with retries and validation."""
|
173 |
+
if not llm:
|
174 |
+
logger.error("LLM client not provided. Cannot perform LLM extraction.")
|
175 |
+
return None
|
176 |
+
|
177 |
+
last_error = None
|
178 |
+
# Bind the Pydantic model to the LLM for structured output
|
179 |
+
structured_llm = llm.with_structured_output(ParsedQueryInfo)
|
180 |
+
|
181 |
+
# Initial prompt construction
|
182 |
+
system_prompt_content = _build_llm_prompt(query, dataset_info)
|
183 |
+
messages = [HumanMessage(content=system_prompt_content)] # Start with just the detailed prompt as Human message
|
184 |
+
|
185 |
+
for attempt in range(max_retries):
|
186 |
+
logger.info(f"LLM Extraction Attempt {attempt + 1}/{max_retries}...")
|
187 |
+
try:
|
188 |
+
# --- Invoke LangChain LLM with structured output (using passed llm) ---
|
189 |
+
parsed_info = structured_llm.invoke(messages)
|
190 |
+
# ---------------------------------------------------
|
191 |
+
print(messages)
|
192 |
+
print('---------------------------------------------------')
|
193 |
+
print(parsed_info)
|
194 |
+
# Perform custom assertions/validation
|
195 |
+
if _validate_llm_output(parsed_info, dataset_info):
|
196 |
+
return parsed_info # Success!
|
197 |
+
|
198 |
+
# Catch errors specific to structured output parsing or Pydantic validation
|
199 |
+
except (OutputParserException, ValidationError, AssertionError) as e:
|
200 |
+
logger.warning(f"Validation/Parsing Error (Attempt {attempt + 1}): {e}")
|
201 |
+
last_error = e
|
202 |
+
# Add feedback message for retry
|
203 |
+
messages.append(SystemMessage(content=f"Your previous response failed validation: {str(e)}. Please revise your response to be valid JSON conforming strictly to the schema and ensure variable names exist in the dataset context."))
|
204 |
+
continue # Go to next retry
|
205 |
+
except Exception as e: # Catch other potential LLM API errors
|
206 |
+
logger.error(f"Unexpected LLM Error (Attempt {attempt + 1}): {e}", exc_info=True)
|
207 |
+
last_error = e
|
208 |
+
break # Stop retrying on unexpected API errors
|
209 |
+
|
210 |
+
logger.error(f"LLM extraction failed after {max_retries} attempts.")
|
211 |
+
if last_error:
|
212 |
+
logger.error(f"Last error: {last_error}")
|
213 |
+
return None # Indicate failure
|
214 |
+
|
215 |
+
# Add helper function to call LLM for path - needs llm argument
|
216 |
+
def _call_llm_for_path(query: str, llm: Optional[BaseChatModel] = None, max_retries: int = 2) -> Optional[str]:
|
217 |
+
"""Uses LLM as a fallback to extract just the dataset path."""
|
218 |
+
if not llm:
|
219 |
+
logger.warning("LLM client not provided. Cannot perform LLM path fallback.")
|
220 |
+
return None
|
221 |
+
|
222 |
+
logger.info("Attempting LLM fallback for dataset path extraction...")
|
223 |
+
path_extractor_llm = llm.with_structured_output(ExtractedPath)
|
224 |
+
prompt = f"Extract the dataset file path (e.g., /path/to/file.csv or https://...) mentioned in the following query. Respond ONLY with the JSON object.\nQuery: \"{query}\""
|
225 |
+
messages = [HumanMessage(content=prompt)]
|
226 |
+
last_error = None
|
227 |
+
|
228 |
+
for attempt in range(max_retries):
|
229 |
+
try:
|
230 |
+
parsed_info = path_extractor_llm.invoke(messages)
|
231 |
+
if parsed_info.dataset_path:
|
232 |
+
logger.info(f"LLM fallback extracted path: {parsed_info.dataset_path}")
|
233 |
+
return parsed_info.dataset_path
|
234 |
+
else:
|
235 |
+
logger.info("LLM fallback did not find a path.")
|
236 |
+
return None # LLM explicitly found no path
|
237 |
+
except (OutputParserException, ValidationError) as e:
|
238 |
+
logger.warning(f"LLM path extraction parsing/validation error (Attempt {attempt+1}): {e}")
|
239 |
+
last_error = e
|
240 |
+
messages.append(SystemMessage(content=f"Parsing Error: {e}. Please ensure you provide valid JSON with only the 'dataset_path' key."))
|
241 |
+
continue
|
242 |
+
except Exception as e:
|
243 |
+
logger.error(f"Unexpected LLM Error during path fallback (Attempt {attempt+1}): {e}", exc_info=True)
|
244 |
+
last_error = e
|
245 |
+
break # Don't retry on unexpected errors
|
246 |
+
|
247 |
+
logger.error(f"LLM path fallback failed after {max_retries} attempts. Last error: {last_error}")
|
248 |
+
return None
|
249 |
+
|
250 |
+
# Renamed and modified function for regex path extraction + LLM fallback - needs llm argument
|
251 |
+
def extract_dataset_path(query: str, llm: Optional[BaseChatModel] = None) -> Optional[str]:
|
252 |
+
"""
|
253 |
+
Extract dataset path from the query using regex patterns, with LLM fallback.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
query: The user's causal question text
|
257 |
+
llm: The shared LLM client instance for fallback.
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
String with dataset path or None if not found
|
261 |
+
"""
|
262 |
+
# --- Regex Part (existing logic) ---
|
263 |
+
# Check for common patterns indicating dataset paths
|
264 |
+
path_patterns = [
|
265 |
+
# More specific patterns first
|
266 |
+
r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?", # Handles subdirs in path
|
267 |
+
r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
|
268 |
+
# Simpler patterns
|
269 |
+
r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']", # Path in quotes
|
270 |
+
r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)", # More generic path-like structure ending in .csv
|
271 |
+
r"([^\"\'.,\s]+\.csv)" # Just a .csv file name (least specific)
|
272 |
+
]
|
273 |
+
|
274 |
+
for pattern in path_patterns:
|
275 |
+
matches = re.search(pattern, query, re.IGNORECASE)
|
276 |
+
if matches:
|
277 |
+
path = matches.group(1).strip()
|
278 |
+
|
279 |
+
# Basic check if it looks like a path
|
280 |
+
if '/' in path or '\\' in path or os.path.exists(path):
|
281 |
+
# Check if this is a valid file path immediately
|
282 |
+
if os.path.exists(path):
|
283 |
+
logger.info(f"Regex found existing path: {path}")
|
284 |
+
return path
|
285 |
+
|
286 |
+
# Check if it's in common data directories
|
287 |
+
data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
|
288 |
+
for data_dir in data_dir_paths:
|
289 |
+
potential_path = os.path.join(data_dir, os.path.basename(path))
|
290 |
+
if os.path.exists(potential_path):
|
291 |
+
logger.info(f"Regex found path in {data_dir}: {potential_path}")
|
292 |
+
return potential_path
|
293 |
+
|
294 |
+
# If not found but looks like a path, return it anyway - let downstream handle non-existence
|
295 |
+
logger.info(f"Regex found potential path (existence not verified): {path}")
|
296 |
+
return path
|
297 |
+
# Else: it might just be a word ending in .csv, ignore unless it exists
|
298 |
+
elif os.path.exists(path):
|
299 |
+
logger.info(f"Regex found existing path (simple pattern): {path}")
|
300 |
+
return path
|
301 |
+
|
302 |
+
# --- LLM Fallback ---
|
303 |
+
logger.info("Regex did not find dataset path. Trying LLM fallback...")
|
304 |
+
llm_fallback_path = _call_llm_for_path(query, llm=llm)
|
305 |
+
if llm_fallback_path:
|
306 |
+
# Optional: Add existence check here too? Or let downstream handle it.
|
307 |
+
# For now, return what LLM found.
|
308 |
+
return llm_fallback_path
|
309 |
+
|
310 |
+
logger.info("No dataset path found via regex or LLM fallback.")
|
311 |
+
return None
|
312 |
+
|
313 |
+
def parse_input(query: str, dataset_path_arg: Optional[str] = None, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
|
314 |
+
"""
|
315 |
+
Parse the user's causal query using LLM and regex.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
query: The user's causal question text.
|
319 |
+
dataset_path_arg: Path to dataset if provided directly as an argument.
|
320 |
+
dataset_info: Dictionary with dataset context (columns, types, etc.).
|
321 |
+
llm: The shared LLM client instance.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
Dict containing parsed query information.
|
325 |
+
"""
|
326 |
+
result = {
|
327 |
+
"original_query": query,
|
328 |
+
"dataset_path": dataset_path_arg, # Start with argument path
|
329 |
+
"query_type": "OTHER", # Default values
|
330 |
+
"extracted_variables": {},
|
331 |
+
"constraints": []
|
332 |
+
}
|
333 |
+
|
334 |
+
# --- 1. Use LLM for core NLP tasks ---
|
335 |
+
parsed_llm_info = _extract_query_information_with_llm(query, dataset_info, llm=llm)
|
336 |
+
|
337 |
+
if parsed_llm_info:
|
338 |
+
result["query_type"] = parsed_llm_info.query_type
|
339 |
+
result["extracted_variables"] = {k: v if v is not None else [] for k, v in parsed_llm_info.variables.model_dump().items()}
|
340 |
+
result["constraints"] = parsed_llm_info.constraints if parsed_llm_info.constraints is not None else []
|
341 |
+
llm_mentioned_path = parsed_llm_info.dataset_path_mentioned
|
342 |
+
else:
|
343 |
+
logger.warning("LLM-based query information extraction failed.")
|
344 |
+
llm_mentioned_path = None
|
345 |
+
# Consider falling back to old regex methods here if critical
|
346 |
+
# logger.info("Falling back to regex-based parsing (if implemented).")
|
347 |
+
|
348 |
+
# --- 2. Determine Dataset Path (Hybrid Approach) ---
|
349 |
+
final_dataset_path = dataset_path_arg # Priority 1: Explicit argument
|
350 |
+
|
351 |
+
# Pass llm instance to the path extractor for its fallback mechanism
|
352 |
+
path_extractor = partial(extract_dataset_path, llm=llm)
|
353 |
+
|
354 |
+
if not final_dataset_path:
|
355 |
+
# Priority 2: Path mentioned in query (extracted by main LLM call)
|
356 |
+
if llm_mentioned_path and os.path.exists(llm_mentioned_path):
|
357 |
+
logger.info(f"Using dataset path mentioned by LLM: {llm_mentioned_path}")
|
358 |
+
final_dataset_path = llm_mentioned_path
|
359 |
+
elif llm_mentioned_path: # Check data dirs if path not absolute
|
360 |
+
data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
|
361 |
+
base_name = os.path.basename(llm_mentioned_path)
|
362 |
+
for data_dir in data_dir_paths:
|
363 |
+
potential_path = os.path.join(data_dir, base_name)
|
364 |
+
if os.path.exists(potential_path):
|
365 |
+
logger.info(f"Using dataset path mentioned by LLM (found in {data_dir}): {potential_path}")
|
366 |
+
final_dataset_path = potential_path
|
367 |
+
break
|
368 |
+
if not final_dataset_path:
|
369 |
+
logger.warning(f"LLM mentioned path '{llm_mentioned_path}' but it was not found.")
|
370 |
+
|
371 |
+
if not final_dataset_path:
|
372 |
+
# Priority 3: Path extracted by dedicated Regex + LLM fallback function
|
373 |
+
logger.info("Attempting dedicated dataset path extraction (Regex + LLM Fallback)...")
|
374 |
+
extracted_path = path_extractor(query) # Call the partial function with llm bound
|
375 |
+
if extracted_path:
|
376 |
+
final_dataset_path = extracted_path
|
377 |
+
|
378 |
+
result["dataset_path"] = final_dataset_path
|
379 |
+
|
380 |
+
# Check if a path was found ultimately
|
381 |
+
if not result["dataset_path"]:
|
382 |
+
logger.warning("Could not determine dataset path from query or arguments.")
|
383 |
+
else:
|
384 |
+
logger.info(f"Final dataset path determined: {result['dataset_path']}")
|
385 |
+
|
386 |
+
return result
|
387 |
+
|
388 |
+
# --- Old Regex-based functions (Commented out or removed) ---
|
389 |
+
# def determine_query_type(query: str) -> str:
|
390 |
+
# ... (implementation removed)
|
391 |
+
|
392 |
+
# def extract_variables(query: str) -> Dict[str, Any]:
|
393 |
+
# ... (implementation removed)
|
394 |
+
|
395 |
+
# def detect_constraints(query: str) -> List[str]:
|
396 |
+
# ... (implementation removed)
|
397 |
+
# --- End Old Functions ---
|
398 |
+
|
399 |
+
# Renamed function for regex path extraction
|
400 |
+
def extract_dataset_path_regex(query: str) -> Optional[str]:
|
401 |
+
"""
|
402 |
+
Extract dataset path from the query using regex patterns.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
query: The user's causal question text
|
406 |
+
|
407 |
+
Returns:
|
408 |
+
String with dataset path or None if not found
|
409 |
+
"""
|
410 |
+
# Check for common patterns indicating dataset paths
|
411 |
+
path_patterns = [
|
412 |
+
# More specific patterns first
|
413 |
+
r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?", # Handles subdirs in path
|
414 |
+
r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
|
415 |
+
# Simpler patterns
|
416 |
+
r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']", # Path in quotes
|
417 |
+
r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)", # More generic path-like structure ending in .csv
|
418 |
+
r"([^\"\'.,\s]+\.csv)" # Just a .csv file name (least specific)
|
419 |
+
]
|
420 |
+
|
421 |
+
for pattern in path_patterns:
|
422 |
+
matches = re.search(pattern, query, re.IGNORECASE)
|
423 |
+
if matches:
|
424 |
+
path = matches.group(1).strip()
|
425 |
+
|
426 |
+
# Basic check if it looks like a path
|
427 |
+
if '/' in path or '\\' in path or os.path.exists(path):
|
428 |
+
# Check if this is a valid file path immediately
|
429 |
+
if os.path.exists(path):
|
430 |
+
logger.info(f"Regex found existing path: {path}")
|
431 |
+
return path
|
432 |
+
|
433 |
+
# Check if it's in common data directories
|
434 |
+
data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
|
435 |
+
# Also check relative to current dir (often useful)
|
436 |
+
# base_name = os.path.basename(path)
|
437 |
+
for data_dir in data_dir_paths:
|
438 |
+
potential_path = os.path.join(data_dir, os.path.basename(path))
|
439 |
+
if os.path.exists(potential_path):
|
440 |
+
logger.info(f"Regex found path in {data_dir}: {potential_path}")
|
441 |
+
return potential_path
|
442 |
+
|
443 |
+
# If not found but looks like a path, return it anyway - let downstream handle non-existence
|
444 |
+
logger.info(f"Regex found potential path (existence not verified): {path}")
|
445 |
+
return path
|
446 |
+
# Else: it might just be a word ending in .csv, ignore unless it exists
|
447 |
+
elif os.path.exists(path):
|
448 |
+
logger.info(f"Regex found existing path (simple pattern): {path}")
|
449 |
+
return path
|
450 |
+
|
451 |
+
# TODO: Optional: Add LLM fallback call here if regex fails
|
452 |
+
# if no path found:
|
453 |
+
# llm_fallback_path = call_llm_for_path(query)
|
454 |
+
# return llm_fallback_path
|
455 |
+
|
456 |
+
return None
|
auto_causal/components/method_validator.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Method validator component for causal inference methods.
|
3 |
+
|
4 |
+
This module validates the selected causal inference method against
|
5 |
+
dataset characteristics and available variables.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Dict, List, Any, Optional
|
9 |
+
|
10 |
+
|
11 |
+
def validate_method(method_info: Dict[str, Any], dataset_analysis: Dict[str, Any],
|
12 |
+
variables: Dict[str, Any]) -> Dict[str, Any]:
|
13 |
+
"""
|
14 |
+
Validate the selected causal method against dataset characteristics.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
method_info: Information about the selected method from decision_tree
|
18 |
+
dataset_analysis: Dataset analysis results from dataset_analyzer
|
19 |
+
variables: Identified variables from query_interpreter
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
Dict with validation results:
|
23 |
+
- valid: Boolean indicating if method is valid
|
24 |
+
- concerns: List of concerns/issues with the selected method
|
25 |
+
- alternative_suggestions: Alternative methods if the selected method is problematic
|
26 |
+
- recommended_method: Updated method recommendation if issues are found
|
27 |
+
"""
|
28 |
+
method = method_info.get("selected_method")
|
29 |
+
assumptions = method_info.get("method_assumptions", [])
|
30 |
+
|
31 |
+
# Get required variables
|
32 |
+
treatment = variables.get("treatment_variable")
|
33 |
+
outcome = variables.get("outcome_variable")
|
34 |
+
covariates = variables.get("covariates", [])
|
35 |
+
time_variable = variables.get("time_variable")
|
36 |
+
group_variable = variables.get("group_variable")
|
37 |
+
instrument_variable = variables.get("instrument_variable")
|
38 |
+
running_variable = variables.get("running_variable")
|
39 |
+
cutoff_value = variables.get("cutoff_value")
|
40 |
+
|
41 |
+
# Initialize validation result
|
42 |
+
validation_result = {
|
43 |
+
"valid": True,
|
44 |
+
"concerns": [],
|
45 |
+
"alternative_suggestions": [],
|
46 |
+
"recommended_method": method,
|
47 |
+
}
|
48 |
+
|
49 |
+
# Common validations for all methods
|
50 |
+
if treatment is None:
|
51 |
+
validation_result["valid"] = False
|
52 |
+
validation_result["concerns"].append("Treatment variable is not identified")
|
53 |
+
|
54 |
+
if outcome is None:
|
55 |
+
validation_result["valid"] = False
|
56 |
+
validation_result["concerns"].append("Outcome variable is not identified")
|
57 |
+
|
58 |
+
# Method-specific validations
|
59 |
+
if method == "propensity_score_matching":
|
60 |
+
validate_propensity_score_matching(validation_result, dataset_analysis, variables)
|
61 |
+
|
62 |
+
elif method == "regression_adjustment":
|
63 |
+
validate_regression_adjustment(validation_result, dataset_analysis, variables)
|
64 |
+
|
65 |
+
elif method == "instrumental_variable":
|
66 |
+
validate_instrumental_variable(validation_result, dataset_analysis, variables)
|
67 |
+
|
68 |
+
elif method == "difference_in_differences":
|
69 |
+
validate_difference_in_differences(validation_result, dataset_analysis, variables)
|
70 |
+
|
71 |
+
elif method == "regression_discontinuity_design":
|
72 |
+
validate_regression_discontinuity(validation_result, dataset_analysis, variables)
|
73 |
+
|
74 |
+
elif method == "backdoor_adjustment":
|
75 |
+
validate_backdoor_adjustment(validation_result, dataset_analysis, variables)
|
76 |
+
|
77 |
+
# If there are serious concerns, recommend alternatives
|
78 |
+
if not validation_result["valid"]:
|
79 |
+
validation_result["recommended_method"] = recommend_alternative(
|
80 |
+
method, validation_result["concerns"], method_info.get("alternatives", [])
|
81 |
+
)
|
82 |
+
|
83 |
+
# Make sure assumptions are listed in the validation result
|
84 |
+
validation_result["assumptions"] = assumptions
|
85 |
+
print("--------------------------")
|
86 |
+
print("Validation result:", validation_result)
|
87 |
+
print("--------------------------")
|
88 |
+
return validation_result
|
89 |
+
|
90 |
+
|
91 |
+
def validate_propensity_score_matching(validation_result: Dict[str, Any],
|
92 |
+
dataset_analysis: Dict[str, Any],
|
93 |
+
variables: Dict[str, Any]) -> None:
|
94 |
+
"""
|
95 |
+
Validate propensity score matching method requirements.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
validation_result: Current validation result to update
|
99 |
+
dataset_analysis: Dataset analysis results
|
100 |
+
variables: Identified variables
|
101 |
+
"""
|
102 |
+
treatment = variables.get("treatment_variable")
|
103 |
+
covariates = variables.get("covariates", [])
|
104 |
+
|
105 |
+
# Check if treatment is binary using column_categories
|
106 |
+
is_binary = dataset_analysis.get("column_categories", {}).get(treatment) == "binary"
|
107 |
+
|
108 |
+
# Fallback to check if the column has only two unique values (0 and 1)
|
109 |
+
if not is_binary:
|
110 |
+
column_types = dataset_analysis.get("column_types", {})
|
111 |
+
if column_types.get(treatment) == "int64" or column_types.get(treatment) == "int32":
|
112 |
+
# Assuming int type with only 0s and 1s is binary
|
113 |
+
is_binary = True
|
114 |
+
|
115 |
+
if not is_binary:
|
116 |
+
validation_result["valid"] = False
|
117 |
+
validation_result["concerns"].append(
|
118 |
+
"Treatment variable is not binary, which is required for propensity score matching"
|
119 |
+
)
|
120 |
+
|
121 |
+
# Check if there are sufficient covariates
|
122 |
+
if len(covariates) < 2:
|
123 |
+
validation_result["concerns"].append(
|
124 |
+
"Few covariates identified, which may limit the effectiveness of propensity score matching"
|
125 |
+
)
|
126 |
+
|
127 |
+
# Check for sufficient overlap
|
128 |
+
variable_relationships = dataset_analysis.get("variable_relationships", {})
|
129 |
+
treatment_imbalance = variable_relationships.get("treatment_imbalance", 0.5)
|
130 |
+
|
131 |
+
if treatment_imbalance < 0.1 or treatment_imbalance > 0.9:
|
132 |
+
validation_result["concerns"].append(
|
133 |
+
"Treatment groups are highly imbalanced, which may lead to poor matching quality"
|
134 |
+
)
|
135 |
+
validation_result["alternative_suggestions"].append("regression_adjustment")
|
136 |
+
|
137 |
+
|
138 |
+
def validate_regression_adjustment(validation_result: Dict[str, Any],
|
139 |
+
dataset_analysis: Dict[str, Any],
|
140 |
+
variables: Dict[str, Any]) -> None:
|
141 |
+
"""
|
142 |
+
Validate regression adjustment method requirements.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
validation_result: Current validation result to update
|
146 |
+
dataset_analysis: Dataset analysis results
|
147 |
+
variables: Identified variables
|
148 |
+
"""
|
149 |
+
outcome = variables.get("outcome_variable")
|
150 |
+
|
151 |
+
# Check outcome type for appropriate regression model
|
152 |
+
outcome_data = dataset_analysis.get("variable_types", {}).get(outcome, {})
|
153 |
+
outcome_type = outcome_data.get("type")
|
154 |
+
|
155 |
+
if outcome_type == "categorical" and outcome_data.get("n_categories", 0) > 2:
|
156 |
+
validation_result["concerns"].append(
|
157 |
+
"Outcome is categorical with multiple categories, which may require multinomial regression"
|
158 |
+
)
|
159 |
+
|
160 |
+
# Check for potential nonlinear relationships
|
161 |
+
nonlinear_relationships = dataset_analysis.get("nonlinear_relationships", False)
|
162 |
+
|
163 |
+
if nonlinear_relationships:
|
164 |
+
validation_result["concerns"].append(
|
165 |
+
"Potential nonlinear relationships detected, which may require more flexible models"
|
166 |
+
)
|
167 |
+
|
168 |
+
|
169 |
+
def validate_instrumental_variable(validation_result: Dict[str, Any],
|
170 |
+
dataset_analysis: Dict[str, Any],
|
171 |
+
variables: Dict[str, Any]) -> None:
|
172 |
+
"""
|
173 |
+
Validate instrumental variable method requirements.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
validation_result: Current validation result to update
|
177 |
+
dataset_analysis: Dataset analysis results
|
178 |
+
variables: Identified variables
|
179 |
+
"""
|
180 |
+
instrument_variable = variables.get("instrument_variable")
|
181 |
+
treatment = variables.get("treatment_variable")
|
182 |
+
|
183 |
+
if instrument_variable is None:
|
184 |
+
validation_result["valid"] = False
|
185 |
+
validation_result["concerns"].append(
|
186 |
+
"No instrumental variable identified, which is required for this method"
|
187 |
+
)
|
188 |
+
validation_result["alternative_suggestions"].append("propensity_score_matching")
|
189 |
+
return
|
190 |
+
|
191 |
+
# Check for instrument strength (correlation with treatment)
|
192 |
+
variable_relationships = dataset_analysis.get("variable_relationships", {})
|
193 |
+
instrument_correlation = next(
|
194 |
+
(corr.get("correlation", 0) for corr in variable_relationships.get("correlations", [])
|
195 |
+
if corr.get("var1") == instrument_variable and corr.get("var2") == treatment
|
196 |
+
or corr.get("var1") == treatment and corr.get("var2") == instrument_variable),
|
197 |
+
0
|
198 |
+
)
|
199 |
+
|
200 |
+
if abs(instrument_correlation) < 0.2:
|
201 |
+
validation_result["concerns"].append(
|
202 |
+
"Instrument appears weak (low correlation with treatment), which may lead to bias"
|
203 |
+
)
|
204 |
+
validation_result["alternative_suggestions"].append("propensity_score_matching")
|
205 |
+
|
206 |
+
|
207 |
+
def validate_difference_in_differences(validation_result: Dict[str, Any],
|
208 |
+
dataset_analysis: Dict[str, Any],
|
209 |
+
variables: Dict[str, Any]) -> None:
|
210 |
+
"""
|
211 |
+
Validate difference-in-differences method requirements.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
validation_result: Current validation result to update
|
215 |
+
dataset_analysis: Dataset analysis results
|
216 |
+
variables: Identified variables
|
217 |
+
"""
|
218 |
+
time_variable = variables.get("time_variable")
|
219 |
+
group_variable = variables.get("group_variable")
|
220 |
+
|
221 |
+
if time_variable is None:
|
222 |
+
validation_result["valid"] = False
|
223 |
+
validation_result["concerns"].append(
|
224 |
+
"No time variable identified, which is required for difference-in-differences"
|
225 |
+
)
|
226 |
+
validation_result["alternative_suggestions"].append("propensity_score_matching")
|
227 |
+
|
228 |
+
if group_variable is None:
|
229 |
+
validation_result["valid"] = False
|
230 |
+
validation_result["concerns"].append(
|
231 |
+
"No group variable identified, which is required for difference-in-differences"
|
232 |
+
)
|
233 |
+
validation_result["alternative_suggestions"].append("propensity_score_matching")
|
234 |
+
|
235 |
+
# Check for parallel trends
|
236 |
+
temporal_structure = dataset_analysis.get("temporal_structure", {})
|
237 |
+
parallel_trends = temporal_structure.get("parallel_trends", False)
|
238 |
+
|
239 |
+
if not parallel_trends:
|
240 |
+
validation_result["concerns"].append(
|
241 |
+
"No evidence of parallel trends, which is a key assumption for difference-in-differences"
|
242 |
+
)
|
243 |
+
validation_result["alternative_suggestions"].append("synthetic_control")
|
244 |
+
|
245 |
+
|
246 |
+
def validate_regression_discontinuity(validation_result: Dict[str, Any],
|
247 |
+
dataset_analysis: Dict[str, Any],
|
248 |
+
variables: Dict[str, Any]) -> None:
|
249 |
+
"""
|
250 |
+
Validate regression discontinuity method requirements.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
validation_result: Current validation result to update
|
254 |
+
dataset_analysis: Dataset analysis results
|
255 |
+
variables: Identified variables
|
256 |
+
"""
|
257 |
+
running_variable = variables.get("running_variable")
|
258 |
+
cutoff_value = variables.get("cutoff_value")
|
259 |
+
|
260 |
+
if running_variable is None:
|
261 |
+
validation_result["valid"] = False
|
262 |
+
validation_result["concerns"].append(
|
263 |
+
"No running variable identified, which is required for regression discontinuity"
|
264 |
+
)
|
265 |
+
validation_result["alternative_suggestions"].append("propensity_score_matching")
|
266 |
+
|
267 |
+
if cutoff_value is None:
|
268 |
+
validation_result["valid"] = False
|
269 |
+
validation_result["concerns"].append(
|
270 |
+
"No cutoff value identified, which is required for regression discontinuity"
|
271 |
+
)
|
272 |
+
validation_result["alternative_suggestions"].append("propensity_score_matching")
|
273 |
+
|
274 |
+
# Check for discontinuity at threshold
|
275 |
+
discontinuities = dataset_analysis.get("discontinuities", {})
|
276 |
+
has_discontinuity = discontinuities.get("has_discontinuities", False)
|
277 |
+
|
278 |
+
if not has_discontinuity:
|
279 |
+
validation_result["valid"] = False
|
280 |
+
validation_result["concerns"].append(
|
281 |
+
"No clear discontinuity detected at the threshold, which is necessary for this method"
|
282 |
+
)
|
283 |
+
validation_result["alternative_suggestions"].append("regression_adjustment")
|
284 |
+
|
285 |
+
def validate_backdoor_adjustment(validation_result: Dict[str, Any],
|
286 |
+
dataset_analysis: Dict[str, Any],
|
287 |
+
variables: Dict[str, Any]) -> None:
|
288 |
+
"""
|
289 |
+
Validate backdoor adjustment method requirements.
|
290 |
+
|
291 |
+
Args:
|
292 |
+
validation_result: Current validation result to update
|
293 |
+
dataset_analysis: Dataset analysis results
|
294 |
+
variables: Identified variables
|
295 |
+
"""
|
296 |
+
covariates = variables.get("covariates", [])
|
297 |
+
|
298 |
+
if len(covariates) == 0:
|
299 |
+
validation_result["valid"] = False
|
300 |
+
validation_result["concerns"].append(
|
301 |
+
"No covariates identified for backdoor adjustment"
|
302 |
+
)
|
303 |
+
validation_result["alternative_suggestions"].append("regression_adjustment")
|
304 |
+
|
305 |
+
|
306 |
+
def recommend_alternative(method: str, concerns: List[str], alternatives: List[str]) -> str:
|
307 |
+
"""
|
308 |
+
Recommend an alternative method if the current one has issues.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
method: Current method
|
312 |
+
concerns: List of concerns with the current method
|
313 |
+
alternatives: List of alternative methods suggested by the decision tree
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
String with the recommended method
|
317 |
+
"""
|
318 |
+
# If there are alternatives, recommend the first one
|
319 |
+
if alternatives:
|
320 |
+
return alternatives[0]
|
321 |
+
|
322 |
+
# If no alternatives, use regression adjustment as a fallback
|
323 |
+
if method != "regression_adjustment":
|
324 |
+
return "regression_adjustment"
|
325 |
+
|
326 |
+
# If regression adjustment is also problematic, use propensity score matching
|
327 |
+
return "propensity_score_matching"
|
auto_causal/components/output_formatter.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Output formatter component for causal inference results.
|
3 |
+
|
4 |
+
This module formats the results of causal analysis into a clear,
|
5 |
+
structured output for presentation to the user.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Dict, List, Any, Optional
|
9 |
+
import json # Add this import at the top of the file
|
10 |
+
|
11 |
+
# Import the new model
|
12 |
+
from auto_causal.models import FormattedOutput
|
13 |
+
|
14 |
+
# Add this module-level variable, typically near imports or at the top
|
15 |
+
CURRENT_OUTPUT_LOG_FILE = None
|
16 |
+
|
17 |
+
# Revert signature and logic to handle results and structured explanation
|
18 |
+
def format_output(
|
19 |
+
query: str,
|
20 |
+
method: str,
|
21 |
+
results: Dict[str, Any],
|
22 |
+
explanation: Dict[str, Any],
|
23 |
+
dataset_analysis: Optional[Dict[str, Any]] = None,
|
24 |
+
dataset_description: Optional[str] = None
|
25 |
+
) -> FormattedOutput:
|
26 |
+
"""
|
27 |
+
Format final results including numerical estimates and explanations.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
query: Original user query
|
31 |
+
method: Causal inference method used (string name)
|
32 |
+
results: Numerical results from method_executor_tool
|
33 |
+
explanation: Structured explanation object from explainer_tool
|
34 |
+
dataset_analysis: Optional dictionary of dataset analysis results
|
35 |
+
dataset_description: Optional string description of the dataset
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
Dict with formatted output fields ready for presentation.
|
39 |
+
"""
|
40 |
+
# Extract numerical results
|
41 |
+
effect_estimate = results.get("effect_estimate")
|
42 |
+
confidence_interval = results.get("confidence_interval")
|
43 |
+
p_value = results.get("p_value")
|
44 |
+
effect_se = results.get("standard_error") # Get SE if available
|
45 |
+
|
46 |
+
# Format method name for readability
|
47 |
+
method_name_formatted = _format_method_name(method)
|
48 |
+
|
49 |
+
# Extract explanation components (assuming explainer returns structured dict again)
|
50 |
+
# If explainer returns single string, adjust this
|
51 |
+
method_explanation_text = explanation.get("method_explanation", "")
|
52 |
+
interpretation_guide = explanation.get("interpretation_guide", "")
|
53 |
+
limitations = explanation.get("limitations", [])
|
54 |
+
assumptions_discussion = explanation.get("assumptions", "") # Assuming key is 'assumptions'
|
55 |
+
practical_implications = explanation.get("practical_implications", "")
|
56 |
+
# Add back final_explanation_text if explainer provides it
|
57 |
+
# final_explanation_text = explanation.get("final_explanation_text")
|
58 |
+
|
59 |
+
# Create summary using numerical results
|
60 |
+
ci_text = ""
|
61 |
+
if confidence_interval and confidence_interval[0] is not None and confidence_interval[1] is not None:
|
62 |
+
ci_text = f" (95% CI: [{confidence_interval[0]:.4f}, {confidence_interval[1]:.4f}])"
|
63 |
+
|
64 |
+
p_value_text = f", p={p_value:.4f}" if p_value is not None else ""
|
65 |
+
effect_text = f"{effect_estimate:.4f}" if effect_estimate is not None else "N/A"
|
66 |
+
|
67 |
+
summary = (
|
68 |
+
f"Based on {method_name_formatted}, the estimated causal effect is {effect_text}"
|
69 |
+
f"{ci_text}{p_value_text}. {_create_effect_interpretation(effect_estimate, p_value)}"
|
70 |
+
f" See details below regarding assumptions and limitations."
|
71 |
+
)
|
72 |
+
|
73 |
+
# Assemble formatted output dictionary
|
74 |
+
results_dict = {
|
75 |
+
"query": query,
|
76 |
+
"method_used": method_name_formatted,
|
77 |
+
"causal_effect": effect_estimate,
|
78 |
+
"standard_error": effect_se,
|
79 |
+
"confidence_interval": confidence_interval,
|
80 |
+
"p_value": p_value,
|
81 |
+
"summary": summary,
|
82 |
+
"method_explanation": method_explanation_text,
|
83 |
+
"interpretation_guide": interpretation_guide,
|
84 |
+
"limitations": limitations,
|
85 |
+
"assumptions": assumptions_discussion,
|
86 |
+
"practical_implications": practical_implications,
|
87 |
+
# "full_explanation_text": final_explanation_text # Optionally include combined text
|
88 |
+
}
|
89 |
+
final_results_dict = {key : results_dict[key] for key in {"query", "method_used", "causal_effect", "standard_error", "confidence_interval"}}
|
90 |
+
# print(final_results_dict)
|
91 |
+
|
92 |
+
# Validate and instantiate the Pydantic model
|
93 |
+
try:
|
94 |
+
formatted_output_model = FormattedOutput(**results_dict)
|
95 |
+
except Exception as e: # Catch validation errors specifically if needed
|
96 |
+
# Handle validation error - perhaps log and return a default or raise
|
97 |
+
print(f"Error creating FormattedOutput model: {e}") # Or use logger
|
98 |
+
# Decide on error handling: raise, return None, return default?
|
99 |
+
# For now, re-raising might be simplest if the structure is expected
|
100 |
+
raise ValueError(f"Failed to create FormattedOutput from results: {e}")
|
101 |
+
|
102 |
+
return formatted_output_model # Return the Pydantic model instance
|
103 |
+
|
104 |
+
|
105 |
+
def _format_method_name(method: str) -> str:
|
106 |
+
"""Format method name for readability."""
|
107 |
+
method_names = {
|
108 |
+
"propensity_score_matching": "Propensity Score Matching",
|
109 |
+
"regression_adjustment": "Regression Adjustment",
|
110 |
+
"instrumental_variable": "Instrumental Variable Analysis",
|
111 |
+
"difference_in_differences": "Difference-in-Differences",
|
112 |
+
"regression_discontinuity": "Regression Discontinuity Design",
|
113 |
+
"backdoor_adjustment": "Backdoor Adjustment",
|
114 |
+
"propensity_score_weighting": "Propensity Score Weighting"
|
115 |
+
}
|
116 |
+
return method_names.get(method, method.replace("_", " ").title())
|
117 |
+
|
118 |
+
# Reinstate helper function for interpretation
|
119 |
+
def _create_effect_interpretation(effect: Optional[float], p_value: Optional[float] = None) -> str:
|
120 |
+
"""Create a basic interpretation of the effect."""
|
121 |
+
if effect is None:
|
122 |
+
return "Effect estimate not available."
|
123 |
+
|
124 |
+
significance = ""
|
125 |
+
if p_value is not None:
|
126 |
+
significance = "statistically significant" if p_value < 0.05 else "not statistically significant"
|
127 |
+
|
128 |
+
magnitude = ""
|
129 |
+
if abs(effect) < 0.01:
|
130 |
+
magnitude = "no practical effect"
|
131 |
+
elif abs(effect) < 0.1:
|
132 |
+
magnitude = "a small effect"
|
133 |
+
elif abs(effect) < 0.5:
|
134 |
+
magnitude = "a moderate effect"
|
135 |
+
else:
|
136 |
+
magnitude = "a substantial effect"
|
137 |
+
|
138 |
+
return f"This suggests {magnitude}{f' and is {significance}' if significance else ''}."
|
auto_causal/components/query_interpreter.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Query interpreter component for causal inference.
|
3 |
+
|
4 |
+
This module provides functionality to match query concepts to actual dataset variables,
|
5 |
+
identifying treatment, outcome, and covariate variables for causal inference analysis.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import re
|
9 |
+
from typing import Dict, List, Any, Optional, Union, Tuple
|
10 |
+
import pandas as pd
|
11 |
+
import logging
|
12 |
+
import numpy as np
|
13 |
+
from auto_causal.config import get_llm_client
|
14 |
+
# Import LLM and message types
|
15 |
+
from langchain_core.language_models import BaseChatModel
|
16 |
+
from langchain_core.messages import HumanMessage
|
17 |
+
from langchain_core.exceptions import OutputParserException
|
18 |
+
# Import base Pydantic models needed directly
|
19 |
+
from pydantic import BaseModel, ValidationError
|
20 |
+
from dowhy import CausalModel
|
21 |
+
import json
|
22 |
+
|
23 |
+
# Import shared Pydantic models from the central location
|
24 |
+
from auto_causal.models import (
|
25 |
+
LLMSelectedVariable,
|
26 |
+
LLMSelectedCovariates,
|
27 |
+
LLMIVars,
|
28 |
+
LLMRDDVars,
|
29 |
+
LLMRCTCheck,
|
30 |
+
LLMTreatmentReferenceLevel,
|
31 |
+
LLMInteractionSuggestion,
|
32 |
+
LLMEstimand,
|
33 |
+
# LLMDIDCheck,
|
34 |
+
# LLMDiDTemporalVars,
|
35 |
+
# LLMDiDGroupVars,
|
36 |
+
# LLMRDDCheck,
|
37 |
+
# LLMRDDVarsExtended
|
38 |
+
)
|
39 |
+
|
40 |
+
# Import the new prompt templates
|
41 |
+
from auto_causal.prompts.method_identification_prompts import (
|
42 |
+
IV_IDENTIFICATION_PROMPT_TEMPLATE,
|
43 |
+
RDD_IDENTIFICATION_PROMPT_TEMPLATE,
|
44 |
+
RCT_IDENTIFICATION_PROMPT_TEMPLATE,
|
45 |
+
TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE,
|
46 |
+
INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE,
|
47 |
+
TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE,
|
48 |
+
OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE,
|
49 |
+
COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE,
|
50 |
+
ESTIMAND_PROMPT_TEMPLATE,
|
51 |
+
CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE,
|
52 |
+
DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE)
|
53 |
+
|
54 |
+
|
55 |
+
# Assume central models are defined elsewhere or keep local definitions for now
|
56 |
+
# from ..models import ...
|
57 |
+
|
58 |
+
# --- Pydantic models for LLM structured output ---
|
59 |
+
# REMOVED - Now defined in causalscientist/auto_causal/models.py
|
60 |
+
# class LLMSelectedVariable(BaseModel): ...
|
61 |
+
# class LLMSelectedCovariates(BaseModel): ...
|
62 |
+
# class LLMIVars(BaseModel): ...
|
63 |
+
# class LLMRDDVars(BaseModel): ...
|
64 |
+
# class LLMRCTCheck(BaseModel): ...
|
65 |
+
|
66 |
+
|
67 |
+
logger = logging.getLogger(__name__)
|
68 |
+
|
69 |
+
def infer_treatment_variable_type(treatment_variable: str, column_categories: Dict[str, str],
|
70 |
+
dataset_analysis: Dict[str, Any]) -> str:
|
71 |
+
"""
|
72 |
+
Determine treatment variable type from column category and unique value count
|
73 |
+
Args:
|
74 |
+
treatment_variable: name of the treatment variable
|
75 |
+
column_categories: mapping of column names to their categories
|
76 |
+
dataset_analysis: exploratory analysis results
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
str: type of the treatment variable (e.g., "binary", "continuous", etc
|
80 |
+
"""
|
81 |
+
|
82 |
+
treatment_variable_type = "unknown"
|
83 |
+
if treatment_variable and treatment_variable in column_categories:
|
84 |
+
category = column_categories[treatment_variable]
|
85 |
+
logger.info(f"Category for treatment '{treatment_variable}' is '{category}'.")
|
86 |
+
|
87 |
+
if category == "continuous_numeric":
|
88 |
+
treatment_variable_type = "continuous"
|
89 |
+
|
90 |
+
elif category == "discrete_numeric":
|
91 |
+
num_unique = dataset_analysis.get("column_nunique_counts", {}).get(treatment_variable, -1)
|
92 |
+
if num_unique > 10:
|
93 |
+
logger.info(f"'{treatment_variable}' has {num_unique} unique values, treating as continuous.")
|
94 |
+
treatment_variable_type = "continuous"
|
95 |
+
elif num_unique == 2:
|
96 |
+
logger.info(f"'{treatment_variable}' has 2 unique values, treating as binary.")
|
97 |
+
treatment_variable_type = "binary"
|
98 |
+
elif num_unique > 0:
|
99 |
+
logger.info(f"'{treatment_variable}' has {num_unique} unique values, treating as discrete_multi_value.")
|
100 |
+
treatment_variable_type = "discrete_multi_value"
|
101 |
+
else:
|
102 |
+
logger.info(f"'{treatment_variable}' unique value count unknown or too few.")
|
103 |
+
treatment_variable_type = "discrete_numeric_unknown_cardinality"
|
104 |
+
|
105 |
+
elif category in ["binary", "binary_categorical"]:
|
106 |
+
treatment_variable_type = "binary"
|
107 |
+
|
108 |
+
elif category in ["categorical", "categorical_numeric"]:
|
109 |
+
num_unique = dataset_analysis.get("column_nunique_counts", {}).get(treatment_variable, -1)
|
110 |
+
if num_unique == 2:
|
111 |
+
treatment_variable_type = "binary"
|
112 |
+
elif num_unique > 0:
|
113 |
+
treatment_variable_type = "categorical_multi_value"
|
114 |
+
else:
|
115 |
+
treatment_variable_type = "categorical_unknown_cardinality"
|
116 |
+
|
117 |
+
else:
|
118 |
+
logger.warning(f"Unmapped category '{category}' for '{treatment_variable}', setting as 'other'.")
|
119 |
+
treatment_variable_type = "other"
|
120 |
+
|
121 |
+
elif treatment_variable:
|
122 |
+
logger.warning(f"'{treatment_variable}' not found in column_categories.")
|
123 |
+
else:
|
124 |
+
logger.info("No treatment variable identified.")
|
125 |
+
|
126 |
+
logger.info(f"Final Determined Treatment Variable Type: {treatment_variable_type}")
|
127 |
+
return treatment_variable_type
|
128 |
+
|
129 |
+
def determine_treatment_reference_level(is_rct: Optional[bool], llm: Optional[BaseChatModel], treatment_variable: Optional[str],
|
130 |
+
query_text: str, dataset_description: Optional[str], file_path: Optional[str],
|
131 |
+
columns: List[str]) -> Optional[str]:
|
132 |
+
"""
|
133 |
+
Determines the treatment reference level
|
134 |
+
"""
|
135 |
+
|
136 |
+
# If LLM didn't explicitly say RCT, default to False or keep None?
|
137 |
+
# Let's default to False if LLM didn't provide a boolean value.
|
138 |
+
if is_rct is None: is_rct = False
|
139 |
+
treatment_reference_level = None
|
140 |
+
|
141 |
+
if llm and treatment_variable and treatment_variable in columns:
|
142 |
+
treatment_values_sample = []
|
143 |
+
if file_path:
|
144 |
+
try:
|
145 |
+
df = pd.read_csv(file_path)
|
146 |
+
if treatment_variable in df.columns:
|
147 |
+
unique_vals = df[treatment_variable].unique()
|
148 |
+
treatment_values_sample = [item.item() if hasattr(item, 'item') else item for item in unique_vals][:10]
|
149 |
+
if treatment_values_sample:
|
150 |
+
logger.info(f"Successfully read treatment values sample from dataset at '{file_path}' for variable '{treatment_variable}'.")
|
151 |
+
else:
|
152 |
+
logger.info(f"'{treatment_variable}' in '{file_path}' has no unique values or is empty.")
|
153 |
+
else:
|
154 |
+
logger.warning(f"'{treatment_variable}' not found in dataset columns at '{file_path}'.")
|
155 |
+
except FileNotFoundError:
|
156 |
+
logger.warning(f"File not found at: {file_path}")
|
157 |
+
except pd.errors.EmptyDataError:
|
158 |
+
logger.warning(f"Empty file at: {file_path}")
|
159 |
+
except Exception as e:
|
160 |
+
logger.warning(f"Error reading dataset at '{file_path}' for '{treatment_variable}': {e}")
|
161 |
+
|
162 |
+
if not treatment_values_sample:
|
163 |
+
logger.warning(f"No unique values found for treatment '{treatment_variable}'. LLM prompt will receive empty list.")
|
164 |
+
else:
|
165 |
+
logger.info(f"Final treatment values sample: {treatment_values_sample}")
|
166 |
+
|
167 |
+
try:
|
168 |
+
prompt = TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description or 'N/A', treatment_variable=treatment_variable, treatment_variable_values=treatment_values_sample)
|
169 |
+
ref_result = _call_llm_for_var(llm, prompt, LLMTreatmentReferenceLevel)
|
170 |
+
if ref_result and ref_result.reference_level:
|
171 |
+
if treatment_values_sample and ref_result.reference_level not in treatment_values_sample:
|
172 |
+
logger.warning(f"LLM reference level '{ref_result.reference_level}' not in sampled values for '{treatment_variable}'.")
|
173 |
+
treatment_reference_level = ref_result.reference_level
|
174 |
+
logger.info(f"LLM identified reference level: {treatment_reference_level} (Reason: {ref_result.reasoning})")
|
175 |
+
elif ref_result:
|
176 |
+
logger.info(f"LLM returned no reference level (Reason: {ref_result.reasoning})")
|
177 |
+
except Exception as e:
|
178 |
+
logger.error(f"LLM error for treatment reference level: {e}")
|
179 |
+
|
180 |
+
return treatment_reference_level
|
181 |
+
|
182 |
+
def identify_interaction_term(llm: Optional[BaseChatModel], treatment_variable: Optional[str], covariates: List[str],
|
183 |
+
column_categories: Dict[str, str], query_text: str,
|
184 |
+
dataset_description: Optional[str]) -> Tuple[bool, Optional[str]]:
|
185 |
+
"""
|
186 |
+
Identifies the interaction term based on the query and the dataset information
|
187 |
+
"""
|
188 |
+
|
189 |
+
interaction_term_suggested, interaction_variable_candidate = False, None
|
190 |
+
|
191 |
+
if llm and treatment_variable and covariates:
|
192 |
+
try:
|
193 |
+
covariates_list_str = "\n".join([f"- {cov}: {column_categories.get(cov, 'Unknown')}" for cov in covariates]) or "No covariates identified or available."
|
194 |
+
prompt = INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description or 'N/A', treatment_variable=treatment_variable, covariates_list_with_types=covariates_list_str)
|
195 |
+
result = _call_llm_for_var(llm, prompt, LLMInteractionSuggestion)
|
196 |
+
if result:
|
197 |
+
interaction_term_suggested = result.interaction_needed if result.interaction_needed is not None else False
|
198 |
+
if interaction_term_suggested and result.interaction_variable:
|
199 |
+
if result.interaction_variable in covariates:
|
200 |
+
interaction_variable_candidate = result.interaction_variable
|
201 |
+
logger.info(f"LLM suggested interaction: needed={interaction_term_suggested}, variable='{interaction_variable_candidate}' (Reason: {result.reasoning})")
|
202 |
+
else:
|
203 |
+
logger.warning(f"LLM suggested variable '{result.interaction_variable}' not in covariates {covariates}. Ignoring.")
|
204 |
+
interaction_term_suggested = False
|
205 |
+
elif interaction_term_suggested:
|
206 |
+
logger.info(f"LLM suggested interaction is needed but no variable provided (Reason: {result.reasoning})")
|
207 |
+
else:
|
208 |
+
logger.info(f"LLM suggested no interaction is needed (Reason: {result.reasoning})")
|
209 |
+
else:
|
210 |
+
logger.warning("LLM returned no result for interaction term suggestion.")
|
211 |
+
except Exception as e:
|
212 |
+
logger.error(f"LLM error during interaction term check: {e}")
|
213 |
+
|
214 |
+
return interaction_term_suggested, interaction_variable_candidate
|
215 |
+
|
216 |
+
|
217 |
+
def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any],
|
218 |
+
dataset_description: Optional[str] = None) -> Dict[str, Any]:
|
219 |
+
"""
|
220 |
+
Interpret query using hybrid heuristic/LLM approach to identify variables.
|
221 |
+
|
222 |
+
Args:
|
223 |
+
query_info: Information extracted from the user's query (text, hints).
|
224 |
+
dataset_analysis: Information about the dataset structure (columns, types, etc.).
|
225 |
+
dataset_description: Optional textual description of the dataset.
|
226 |
+
llm: Optional language model instance.
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
Dict containing identified variables (treatment, outcome, covariates, etc., and is_rct).
|
230 |
+
"""
|
231 |
+
|
232 |
+
logger.info("Interpreting query with hybrid approach...")
|
233 |
+
llm = get_llm_client()
|
234 |
+
|
235 |
+
query_text = query_info.get("query_text", "")
|
236 |
+
columns = dataset_analysis.get("columns", [])
|
237 |
+
column_categories = dataset_analysis.get("column_categories", {})
|
238 |
+
file_path = dataset_analysis["dataset_info"]["file_path"]
|
239 |
+
|
240 |
+
|
241 |
+
# --- Identify Treatment ---
|
242 |
+
treatment_hints = query_info.get("potential_treatments", [])
|
243 |
+
dataset_treatments = dataset_analysis.get("potential_treatments", [])
|
244 |
+
treatment_variable = _identify_variable_hybrid(role="treatment", query_hints=treatment_hints,
|
245 |
+
dataset_suggestions=dataset_treatments, columns=columns,
|
246 |
+
column_categories=column_categories,
|
247 |
+
prioritize_types=["binary", "binary_categorical", "discrete_numeric","continuous_numeric"], # Prioritize binary/discrete
|
248 |
+
query_text=query_text, dataset_description=dataset_description,llm=llm)
|
249 |
+
logger.info(f"Identified Treatment: {treatment_variable}")
|
250 |
+
treatment_variable_type = infer_treatment_variable_type(treatment_variable, column_categories, dataset_analysis)
|
251 |
+
|
252 |
+
|
253 |
+
# --- Identify Outcome ---
|
254 |
+
outcome_hints = query_info.get("outcome_hints", [])
|
255 |
+
dataset_outcomes = dataset_analysis.get("potential_outcomes", [])
|
256 |
+
outcome_variable = _identify_variable_hybrid(role="outcome", query_hints=outcome_hints, dataset_suggestions=dataset_outcomes,
|
257 |
+
columns=columns, column_categories=column_categories,
|
258 |
+
prioritize_types=["continuous_numeric", "discrete_numeric"], # Prioritize numeric
|
259 |
+
exclude_vars=[treatment_variable], # Exclude treatment
|
260 |
+
query_text=query_text, dataset_description=dataset_description, llm=llm)
|
261 |
+
logger.info(f"Identified Outcome: {outcome_variable}")
|
262 |
+
|
263 |
+
# --- Identify Covariates ---
|
264 |
+
covariate_hints = query_info.get("covariates_hints", [])
|
265 |
+
covariates = _identify_covariates_hybrid("covars", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
|
266 |
+
columns=columns, column_categories=column_categories, query_hints=covariate_hints,
|
267 |
+
query_text=query_text, dataset_description=dataset_description, llm=llm)
|
268 |
+
logger.info(f"Identified Covariates: {covariates}")
|
269 |
+
|
270 |
+
# --- Identify Confounders ---
|
271 |
+
confounder_hints = query_info.get("covariates_hints", [])
|
272 |
+
confounders = _identify_covariates_hybrid("confounders", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
|
273 |
+
columns=columns, column_categories=column_categories, query_hints=confounder_hints,
|
274 |
+
query_text=query_text, dataset_description=dataset_description, llm=llm)
|
275 |
+
logger.info(f"Identified Confounders: {confounders}")
|
276 |
+
|
277 |
+
# --- Identify Time/Group (from dataset analysis) ---
|
278 |
+
time_variable = None
|
279 |
+
group_variable = None
|
280 |
+
has_temporal = dataset_analysis.get("temporal_structure", {}).get("has_temporal_structure", False)
|
281 |
+
temporal_structure = dataset_analysis.get("temporal_structure", {})
|
282 |
+
if temporal_structure.get("has_temporal_structure", False):
|
283 |
+
time_variable = temporal_structure.get("time_column") or temporal_structure.get("temporal_columns", [None])[0]
|
284 |
+
if temporal_structure.get("is_panel_data", False):
|
285 |
+
group_variable = temporal_structure.get("id_column")
|
286 |
+
logger.info(f"Identified Time Var: {time_variable}, Group Var: {group_variable}, temporal structure: {temporal_structure}")
|
287 |
+
|
288 |
+
# --- Identify IV/RDD/RCT using LLM ---
|
289 |
+
instrument_variable = None
|
290 |
+
running_variable = None
|
291 |
+
cutoff_value = None
|
292 |
+
is_rct = None
|
293 |
+
smd_score = None
|
294 |
+
|
295 |
+
if llm:
|
296 |
+
try:
|
297 |
+
# Check for RCT
|
298 |
+
prompt_rct = _create_identify_prompt("whether data is from RCT", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
|
299 |
+
rct_result = _call_llm_for_var(llm, prompt_rct, LLMRCTCheck)
|
300 |
+
is_rct = rct_result.is_rct if rct_result else None
|
301 |
+
logger.info(f"LLM identified RCT: {is_rct}")
|
302 |
+
|
303 |
+
# Check for IV
|
304 |
+
prompt_iv = _create_identify_prompt("instrumental variable", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
|
305 |
+
iv_result = _call_llm_for_var(llm, prompt_iv, LLMIVars)
|
306 |
+
instrument_variable = iv_result.instrument_variable if iv_result else None
|
307 |
+
if instrument_variable not in columns:
|
308 |
+
instrument_variable = None
|
309 |
+
logger.info(f"LLM identified IV: {instrument_variable}")
|
310 |
+
|
311 |
+
# Check for RDD
|
312 |
+
prompt_rdd = _create_identify_prompt("regression discontinuity (running variable and cutoff)", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
|
313 |
+
rdd_result = _call_llm_for_var(llm, prompt_rdd, LLMRDDVars)
|
314 |
+
if rdd_result:
|
315 |
+
running_variable = rdd_result.running_variable
|
316 |
+
cutoff_value = rdd_result.cutoff_value
|
317 |
+
if running_variable not in columns or cutoff_value is None:
|
318 |
+
running_variable = None
|
319 |
+
cutoff_value = None
|
320 |
+
logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}")
|
321 |
+
|
322 |
+
## For graph based methods
|
323 |
+
exclude_cols = [treatment_variable, outcome_variable]
|
324 |
+
potential_covariates = [col for col in columns if col not in exclude_cols and col is not None]
|
325 |
+
usable_covariates = [col for col in potential_covariates if column_categories.get(col) not in ["text_or_other"]]
|
326 |
+
logger.info(f"Usable covariates for graph: {usable_covariates}")
|
327 |
+
|
328 |
+
estimand_prompt = ESTIMAND_PROMPT_TEMPLATE.format(query=query_text,dataset_description=dataset_description,
|
329 |
+
dataset_columns=usable_covariates,
|
330 |
+
treatment=treatment_variable, outcome=outcome_variable)
|
331 |
+
|
332 |
+
estimand_result = _call_llm_for_var(llm, estimand_prompt, LLMEstimand)
|
333 |
+
estimand = "ate" if "ate" in estimand_result.estimand.strip().lower() else "att"
|
334 |
+
logger.info(f"LLM identified estimand: {estimand}")
|
335 |
+
|
336 |
+
## Did Term
|
337 |
+
did_term_prompt = DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
|
338 |
+
column_info=columns, time_variable=time_variable,
|
339 |
+
group_variable=group_variable, column_types=column_categories)
|
340 |
+
did_term_result = _call_llm_for_var(llm, did_term_prompt, LLMRDDVars)
|
341 |
+
did_term_result = did_term_result.did_term if did_term_result in columns else None
|
342 |
+
logger.info(f"LLM identified DiD term: {did_term_result}")
|
343 |
+
|
344 |
+
|
345 |
+
|
346 |
+
#smd_score_all = compute_smd(dataset_analysis.get("data", pd.DataFrame()), treatment_variable, usable_covariates)
|
347 |
+
#smd_score = smd_score_all.get("ate", 0.0) if smd_score_all else 0.0
|
348 |
+
#logger.info(f"Computed SMD score: {smd_score}")
|
349 |
+
|
350 |
+
#logger.debug(f"Computed SMD score for {estimand}: {smd_score}")
|
351 |
+
|
352 |
+
|
353 |
+
except Exception as e:
|
354 |
+
logger.error(f"Error during LLM checks for IV/RDD/RCT: {e}")
|
355 |
+
|
356 |
+
|
357 |
+
|
358 |
+
# --- Identify Treatment Reference Level ---
|
359 |
+
treatment_reference_level = determine_treatment_reference_level(is_rct=is_rct, llm=llm, treatment_variable=treatment_variable,
|
360 |
+
query_text=query_text, dataset_description=dataset_description,
|
361 |
+
file_path=file_path, columns=columns)
|
362 |
+
|
363 |
+
# --- Identify Interaction Term Suggestion ---
|
364 |
+
interaction_term_suggested, interaction_variable_candidate = identify_interaction_term(llm=llm, treatment_variable=treatment_variable,
|
365 |
+
covariates=covariates,
|
366 |
+
column_categories=column_categories, query_text=query_text,
|
367 |
+
dataset_description=dataset_description)
|
368 |
+
|
369 |
+
|
370 |
+
# --- Consolidate ---
|
371 |
+
return {
|
372 |
+
"treatment_variable": treatment_variable,
|
373 |
+
"treatment_variable_type": treatment_variable_type,
|
374 |
+
"outcome_variable": outcome_variable,
|
375 |
+
"covariates": covariates,
|
376 |
+
"time_variable": time_variable,
|
377 |
+
"group_variable": group_variable,
|
378 |
+
"instrument_variable": instrument_variable,
|
379 |
+
"running_variable": running_variable,
|
380 |
+
"cutoff_value": cutoff_value,
|
381 |
+
"is_rct": is_rct,
|
382 |
+
"treatment_reference_level": treatment_reference_level,
|
383 |
+
"interaction_term_suggested": interaction_term_suggested,
|
384 |
+
"interaction_variable_candidate": interaction_variable_candidate,
|
385 |
+
"confounders": confounders,
|
386 |
+
"did_term": did_term_result
|
387 |
+
}
|
388 |
+
|
389 |
+
def compute_smd(df: pd.DataFrame, treat, covars_list) -> Dict[str, float]:
|
390 |
+
"""
|
391 |
+
Computed the standardized mean differences (SMD) for the treatment variable
|
392 |
+
Args:
|
393 |
+
df (pd.DataFrame): The dataset.
|
394 |
+
treat (str): Name of the binary treatment column (0/1).
|
395 |
+
covars_list (List[str]): List of covariate names to consider for SMD calculation
|
396 |
+
|
397 |
+
Returns:
|
398 |
+
Dict{str ->float}: the standardized mean difference (SMD)
|
399 |
+
"""
|
400 |
+
logger.info(f"Computing SMD for treatment variable '{treat}' with covariates: {covars_list}")
|
401 |
+
df_t = df[df[treat] == 1]
|
402 |
+
df_c = df[df[treat] == 0]
|
403 |
+
|
404 |
+
covariates = covars_list if covars_list else df.columns.tolist()
|
405 |
+
smd_ate = np.zeros(len(covariates))
|
406 |
+
smd_att = np.zeros(len(covariates))
|
407 |
+
|
408 |
+
for i, col in enumerate(covariates):
|
409 |
+
try:
|
410 |
+
m_t, m_c = df_t[col].mean(), df_c[col].mean()
|
411 |
+
s_t, s_c = df_t[col].std(ddof=0), df_c[col].std(ddof=0)
|
412 |
+
pooled = np.sqrt((s_t**2 + s_c**2) / 2)
|
413 |
+
|
414 |
+
ate_val = 0.0 if pooled == 0 else (m_t - m_c) / pooled
|
415 |
+
att_val = 0.0 if s_t == 0 else (m_t - m_c) / s_t
|
416 |
+
|
417 |
+
smd_ate.append(ate_val)
|
418 |
+
smd_att.append(att_val)
|
419 |
+
except Exception as e:
|
420 |
+
logger.warning(f"SMD computation failed for column '{col}': {e}")
|
421 |
+
continue
|
422 |
+
|
423 |
+
avg_ate = np.nanmean(np.abs(smd_ate))
|
424 |
+
avg_att = np.nanmean(np.abs(smd_att))
|
425 |
+
|
426 |
+
return {"ate":avg_ate, "att":avg_att}
|
427 |
+
|
428 |
+
|
429 |
+
|
430 |
+
# --- Helper Functions for Hybrid Identification ---
|
431 |
+
def _identify_variable_hybrid(role: str, query_hints: List[str], dataset_suggestions: List[str],
|
432 |
+
columns: List[str], column_categories: Dict[str, str],
|
433 |
+
prioritize_types: List[str], query_text: str,
|
434 |
+
dataset_description: Optional[str],llm: Optional[BaseChatModel],
|
435 |
+
exclude_vars: Optional[List[str]] = None) -> Optional[str]:
|
436 |
+
"""
|
437 |
+
Used to identify a variable from the avaiable information by prompting the LLM. In case of failure,
|
438 |
+
it will fallback to a programmatic selection (heuristics)
|
439 |
+
|
440 |
+
Args:
|
441 |
+
role: variable type (treatment or outcome)
|
442 |
+
query_hints: hints from the query for this variable
|
443 |
+
dataset_suggestions: dataset-specific suggestions for this variable
|
444 |
+
columns: list of available columns in the dataset
|
445 |
+
column_categories: mapping of column names to their categories
|
446 |
+
prioritize_types: types to prioritize for this variable
|
447 |
+
query_text: the original query text
|
448 |
+
dataset_description: description of the dataset
|
449 |
+
llm: language model
|
450 |
+
exclude_vars: list of variables to exclude from selection (e.g., treatment for outcome)
|
451 |
+
Returns:
|
452 |
+
str: name of the identified variable, or None if not found
|
453 |
+
"""
|
454 |
+
|
455 |
+
candidates = set()
|
456 |
+
available_columns = [c for c in columns if c not in (exclude_vars or [])]
|
457 |
+
if not available_columns: return None
|
458 |
+
|
459 |
+
# 1. Exact matches from hints
|
460 |
+
for hint in query_hints:
|
461 |
+
if hint in available_columns:
|
462 |
+
candidates.add(hint)
|
463 |
+
# 2. Add dataset suggestions
|
464 |
+
for sugg in dataset_suggestions:
|
465 |
+
if sugg in available_columns:
|
466 |
+
candidates.add(sugg)
|
467 |
+
|
468 |
+
# 3. Programmatic Filtering based on type
|
469 |
+
plausible_candidates = [c for c in candidates if column_categories.get(c) in prioritize_types]
|
470 |
+
|
471 |
+
if llm:
|
472 |
+
if role == "treatment":
|
473 |
+
prompt_template = TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE
|
474 |
+
elif role == "outcome":
|
475 |
+
prompt_template = OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE
|
476 |
+
else:
|
477 |
+
raise ValueError(f"Unsupported role for LLM variable identification: {role}")
|
478 |
+
|
479 |
+
prompt = prompt_template.format(query=query_text, description=dataset_description,
|
480 |
+
column_info=available_columns)
|
481 |
+
llm_choice = _call_llm_for_var(llm, prompt, LLMSelectedVariable)
|
482 |
+
|
483 |
+
if llm_choice and llm_choice.variable_name in available_columns:
|
484 |
+
logger.info(f"LLM selected {role}: {llm_choice.variable_name}")
|
485 |
+
return llm_choice.variable_name
|
486 |
+
else:
|
487 |
+
fallback = plausible_candidates[0] if plausible_candidates else None
|
488 |
+
logger.warning(f"LLM failed to select valid {role}. Falling back to: {fallback}")
|
489 |
+
return fallback
|
490 |
+
|
491 |
+
if plausible_candidates:
|
492 |
+
logger.info(f"No LLM provided. Using first plausible {role}: {plausible_candidates[0]}")
|
493 |
+
return plausible_candidates[0]
|
494 |
+
|
495 |
+
logger.warning(f"No plausible candidates for {role}. Cannot identify variable.")
|
496 |
+
return None
|
497 |
+
|
498 |
+
|
499 |
+
def _identify_covariates_hybrid(role, treatment_variable: Optional[str], outcome_variable: Optional[str],
|
500 |
+
columns: List[str], column_categories: Dict[str, str], query_hints: List[str],
|
501 |
+
query_text: str, dataset_description: Optional[str], llm: Optional[BaseChatModel]) -> List[str]:
|
502 |
+
"""
|
503 |
+
Prompts an LLM to identify the covariates
|
504 |
+
"""
|
505 |
+
|
506 |
+
# 1. Initial Programmatic Filtering
|
507 |
+
exclude_cols = [treatment_variable, outcome_variable]
|
508 |
+
potential_covariates = [col for col in columns if col not in exclude_cols and col is not None]
|
509 |
+
|
510 |
+
# Filter out unusable types
|
511 |
+
usable_covariates = [col for col in potential_covariates if column_categories.get(col) not in ["text_or_other"]]
|
512 |
+
logger.debug(f"Initial usable covariates: {usable_covariates}")
|
513 |
+
|
514 |
+
# 2. LLM Refinement (if LLM available)
|
515 |
+
if llm:
|
516 |
+
logger.info("Using LLM to refine covariate list...")
|
517 |
+
prompt = ""
|
518 |
+
if role == "covars":
|
519 |
+
prompt = COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE.format("covars", query=query_text, description=dataset_description,
|
520 |
+
column_info=", ".join(usable_covariates),
|
521 |
+
treatment=treatment_variable, outcome=outcome_variable)
|
522 |
+
elif role == "confounders":
|
523 |
+
prompt = CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
|
524 |
+
column_info=", ".join(usable_covariates),
|
525 |
+
treatment=treatment_variable, outcome=outcome_variable)
|
526 |
+
llm_selection = _call_llm_for_var(llm, prompt, LLMSelectedCovariates)
|
527 |
+
|
528 |
+
if llm_selection and llm_selection.covariates:
|
529 |
+
# Validate LLM output against available columns
|
530 |
+
valid_llm_covs = [c for c in llm_selection.covariates if c in usable_covariates]
|
531 |
+
if len(valid_llm_covs) < len(llm_selection.covariates):
|
532 |
+
logger.warning("LLM suggested covariates not found in initial usable list.")
|
533 |
+
if valid_llm_covs: # Use LLM selection if it's valid and non-empty
|
534 |
+
logger.info(f"LLM refined covariates to: {valid_llm_covs}")
|
535 |
+
return valid_llm_covs[:10] # Cap at 10
|
536 |
+
else:
|
537 |
+
logger.warning("LLM refinement failed or returned empty/invalid list. Falling back.")
|
538 |
+
else:
|
539 |
+
logger.warning("LLM refinement call failed or returned no covariates. Falling back.")
|
540 |
+
|
541 |
+
# 3. Fallback to Programmatic List (Capped)
|
542 |
+
logger.info(f"Using programmatically determined covariates (capped at 10): {usable_covariates[:10]}")
|
543 |
+
return usable_covariates[:10]
|
544 |
+
|
545 |
+
def _create_identify_prompt(target: str, query: str, description: Optional[str], columns: List[str],
|
546 |
+
categories: Dict[str,str], treatment: Optional[str], outcome: Optional[str]) -> str:
|
547 |
+
"""
|
548 |
+
Creates a prompt to ask LLM to identify specific roles like IV, RDD, or RCT by selecting and formatting a specific template
|
549 |
+
"""
|
550 |
+
column_info = "\n".join([f"- '{c}' (Type: {categories.get(c, 'Unknown')})" for c in columns])
|
551 |
+
|
552 |
+
# Select the appropriate detailed prompt template based on the target
|
553 |
+
if "instrumental variable" in target.lower():
|
554 |
+
template = IV_IDENTIFICATION_PROMPT_TEMPLATE
|
555 |
+
elif "regression discontinuity" in target.lower():
|
556 |
+
template = RDD_IDENTIFICATION_PROMPT_TEMPLATE
|
557 |
+
elif "rct" in target.lower():
|
558 |
+
template = RCT_IDENTIFICATION_PROMPT_TEMPLATE
|
559 |
+
else:
|
560 |
+
# Fallback or error? For now, let's raise an error if target is unexpected.
|
561 |
+
logger.error(f"Unsupported target for _create_identify_prompt: {target}")
|
562 |
+
raise ValueError(f"Unsupported target for specific identification prompt: {target}")
|
563 |
+
|
564 |
+
# Format the selected template with the provided context
|
565 |
+
prompt = template.format(query=query, description=description or 'N/A', column_info=column_info,
|
566 |
+
treatment=treatment or 'N/A', outcome=outcome or 'N/A')
|
567 |
+
return prompt
|
568 |
+
|
569 |
+
def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
|
570 |
+
"""Helper to call LLM with structured output and handle errors."""
|
571 |
+
try:
|
572 |
+
messages = [HumanMessage(content=prompt)]
|
573 |
+
structured_llm = llm.with_structured_output(pydantic_model)
|
574 |
+
parsed_result = structured_llm.invoke(messages)
|
575 |
+
return parsed_result
|
576 |
+
except (OutputParserException, ValidationError) as e:
|
577 |
+
logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {e}")
|
578 |
+
except Exception as e:
|
579 |
+
logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {e}", exc_info=True)
|
580 |
+
return None
|
auto_causal/components/state_manager.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
State management utilities for the auto_causal workflow.
|
3 |
+
|
4 |
+
This module provides utility functions to create standardized state updates
|
5 |
+
for passing between tools in the auto_causal agent workflow.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Dict, Any, Optional
|
9 |
+
|
10 |
+
def create_workflow_state_update(
|
11 |
+
current_step: str,
|
12 |
+
step_completed_flag: bool,
|
13 |
+
next_tool: str,
|
14 |
+
next_step_reason: str,
|
15 |
+
error: Optional[str] = None
|
16 |
+
) -> Dict[str, Any]:
|
17 |
+
"""
|
18 |
+
Create a standardized workflow state update dictionary.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
current_step: Current step in the workflow (e.g., "input_processing")
|
22 |
+
step_completed_flag: Flag indicating which step was completed (e.g., "query_parsed")
|
23 |
+
next_tool: Name of the next tool to call
|
24 |
+
next_step_reason: Reason message for the next step
|
25 |
+
error: Optional error message if the step failed
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Dictionary containing the workflow_state sub-dictionary
|
29 |
+
"""
|
30 |
+
state_update = {
|
31 |
+
"workflow_state": {
|
32 |
+
"current_step": current_step,
|
33 |
+
current_step + "_completed": step_completed_flag,
|
34 |
+
"next_tool": next_tool,
|
35 |
+
"next_step_reason": next_step_reason
|
36 |
+
}
|
37 |
+
}
|
38 |
+
if error:
|
39 |
+
state_update["workflow_state"]["error_message"] = error
|
40 |
+
return state_update
|
auto_causal/config.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# auto_causal/config.py
|
2 |
+
"""Central configuration for AutoCausal, including LLM client setup."""
|
3 |
+
|
4 |
+
import os
|
5 |
+
import logging
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
# Langchain imports
|
9 |
+
from langchain_core.language_models import BaseChatModel
|
10 |
+
from langchain_openai import ChatOpenAI # Default
|
11 |
+
from langchain_anthropic import ChatAnthropic # Example
|
12 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
+
# Add other providers if needed, e.g.:
|
14 |
+
# from langchain_community.chat_models import ChatOllama
|
15 |
+
from dotenv import load_dotenv
|
16 |
+
from langchain_deepseek import ChatDeepSeek
|
17 |
+
# Create a disk-backed SQLite cache:
|
18 |
+
# Import Together provider
|
19 |
+
from langchain_together import ChatTogether
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
# Load .env file when this module is loaded
|
24 |
+
load_dotenv()
|
25 |
+
|
26 |
+
def get_llm_client(provider: Optional[str] = None, model_name: Optional[str] = None, **kwargs) -> BaseChatModel:
|
27 |
+
"""Initializes and returns the chosen LLM client based on provider.
|
28 |
+
|
29 |
+
Reads provider, model, and API keys from environment variables if not passed directly.
|
30 |
+
Defaults to OpenAI GPT-4o-mini if no provider/model specified.
|
31 |
+
"""
|
32 |
+
# Prioritize arguments, then environment variables, then defaults
|
33 |
+
provider = provider or os.getenv("LLM_PROVIDER", "openai")
|
34 |
+
provider = provider.lower()
|
35 |
+
|
36 |
+
# Default model depends on provider
|
37 |
+
default_models = {
|
38 |
+
"openai": "gpt-4o-mini",
|
39 |
+
"anthropic": "claude-3-5-sonnet-latest",
|
40 |
+
"together": "deepseek-ai/DeepSeek-V3", # Default Together model
|
41 |
+
"gemini" : "gemini-2.5-flash",
|
42 |
+
"deepseek" : "deepseek-chat"
|
43 |
+
}
|
44 |
+
|
45 |
+
model_name = model_name or os.getenv("LLM_MODEL", default_models.get(provider, default_models["openai"]))
|
46 |
+
|
47 |
+
api_key = None
|
48 |
+
if model_name not in ['o3-mini', 'o3', 'o4-mini']:
|
49 |
+
kwargs.setdefault("temperature", 0) # Default temperature if not provided
|
50 |
+
|
51 |
+
logger.info(f"Initializing LLM client: Provider='{provider}', Model='{model_name}'")
|
52 |
+
|
53 |
+
try:
|
54 |
+
if provider == "openai":
|
55 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
56 |
+
if not api_key:
|
57 |
+
raise ValueError("OPENAI_API_KEY not found in environment.")
|
58 |
+
return ChatOpenAI(model=model_name, api_key=api_key, **kwargs)
|
59 |
+
|
60 |
+
elif provider == "anthropic":
|
61 |
+
api_key = os.getenv("ANTHROPIC_API_KEY")
|
62 |
+
if not api_key:
|
63 |
+
raise ValueError("ANTHROPIC_API_KEY not found in environment.")
|
64 |
+
return ChatAnthropic(model=model_name, api_key=api_key, **kwargs, streaming=False)
|
65 |
+
|
66 |
+
elif provider == "together":
|
67 |
+
api_key = os.getenv("TOGETHER_API_KEY")
|
68 |
+
if not api_key:
|
69 |
+
raise ValueError("TOGETHER_API_KEY not found in environment.")
|
70 |
+
return ChatTogether(model=model_name, api_key=api_key, **kwargs)
|
71 |
+
|
72 |
+
elif provider == "gemini":
|
73 |
+
api_key = os.getenv("GEMINI_API_KEY")
|
74 |
+
if not api_key:
|
75 |
+
raise ValueError("GEMINI_API_KEY not found in environment.")
|
76 |
+
return ChatGoogleGenerativeAI(model=model_name, **kwargs, function_calling="auto")
|
77 |
+
|
78 |
+
elif provider == "deepseek":
|
79 |
+
api_key = os.getenv("DEEPSEEK_API_KEY")
|
80 |
+
if not api_key:
|
81 |
+
raise ValueError("DEEPSEEK_API_KEY not found in environment.")
|
82 |
+
return ChatDeepSeek(model=model_name, **kwargs)
|
83 |
+
|
84 |
+
# Example for Ollama (ensure langchain_community is installed)
|
85 |
+
# elif provider == "ollama":
|
86 |
+
# try:
|
87 |
+
# from langchain_community.chat_models import ChatOllama
|
88 |
+
# return ChatOllama(model=model_name, **kwargs)
|
89 |
+
# except ImportError:
|
90 |
+
# raise ValueError("langchain_community needed for Ollama. Run `pip install langchain-community`")
|
91 |
+
|
92 |
+
else:
|
93 |
+
raise ValueError(f"Unsupported LLM provider: {provider}")
|
94 |
+
|
95 |
+
except Exception as e:
|
96 |
+
logger.error(f"Failed to initialize LLM (Provider: {provider}, Model: {model_name}): {e}")
|
97 |
+
raise # Re-raise the exception
|
auto_causal/methods/__init__.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Causal inference methods for the auto_causal module.
|
3 |
+
|
4 |
+
This package contains implementations of various causal inference methods
|
5 |
+
that can be selected and applied by the auto_causal pipeline.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from .causal_method import CausalMethod
|
9 |
+
from .propensity_score.matching import estimate_effect as psm_estimate_effect
|
10 |
+
from .propensity_score.weighting import estimate_effect as psw_estimate_effect
|
11 |
+
from .instrumental_variable.estimator import estimate_effect as iv_estimate_effect
|
12 |
+
from .difference_in_differences.estimator import estimate_effect as did_estimate_effect
|
13 |
+
from .diff_in_means.estimator import estimate_effect as dim_estimate_effect
|
14 |
+
from .linear_regression.estimator import estimate_effect as lr_estimate_effect
|
15 |
+
from .backdoor_adjustment.estimator import estimate_effect as ba_estimate_effect
|
16 |
+
from .regression_discontinuity.estimator import estimate_effect as rdd_estimate_effect
|
17 |
+
from .generalized_propensity_score.estimator import estimate_effect_gps
|
18 |
+
|
19 |
+
# Mapping of method names to their implementation functions
|
20 |
+
METHOD_MAPPING = {
|
21 |
+
"propensity_score_matching": psm_estimate_effect,
|
22 |
+
"propensity_score_weighting": psw_estimate_effect,
|
23 |
+
"instrumental_variable": iv_estimate_effect,
|
24 |
+
"difference_in_differences": did_estimate_effect,
|
25 |
+
"regression_discontinuity_design": rdd_estimate_effect,
|
26 |
+
"backdoor_adjustment": ba_estimate_effect,
|
27 |
+
"linear_regression": lr_estimate_effect,
|
28 |
+
"diff_in_means": dim_estimate_effect,
|
29 |
+
"generalized_propensity_score": estimate_effect_gps,
|
30 |
+
}
|
31 |
+
|
32 |
+
__all__ = [
|
33 |
+
"CausalMethod",
|
34 |
+
"psm_estimate_effect",
|
35 |
+
"psw_estimate_effect",
|
36 |
+
"iv_estimate_effect",
|
37 |
+
"did_estimate_effect",
|
38 |
+
"rdd_estimate_effect",
|
39 |
+
"dim_estimate_effect",
|
40 |
+
"lr_estimate_effect",
|
41 |
+
"ba_estimate_effect",
|
42 |
+
"METHOD_MAPPING",
|
43 |
+
"estimate_effect_gps",
|
44 |
+
]
|
auto_causal/methods/backdoor_adjustment/__init__.py
ADDED
File without changes
|
auto_causal/methods/backdoor_adjustment/diagnostics.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Diagnostic checks for Backdoor Adjustment models (typically OLS).
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, Any, List
|
6 |
+
import statsmodels.api as sm
|
7 |
+
from statsmodels.stats.diagnostic import het_breuschpagan
|
8 |
+
from statsmodels.stats.stattools import jarque_bera, durbin_watson
|
9 |
+
from statsmodels.regression.linear_model import RegressionResultsWrapper
|
10 |
+
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
11 |
+
import pandas as pd
|
12 |
+
import numpy as np
|
13 |
+
import logging
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
def run_backdoor_diagnostics(results: RegressionResultsWrapper, X: pd.DataFrame) -> Dict[str, Any]:
|
18 |
+
"""
|
19 |
+
Runs diagnostic checks on a fitted OLS model used for backdoor adjustment.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
results: A fitted statsmodels OLS results object.
|
23 |
+
X: The design matrix (including constant and all predictors) used.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
Dictionary containing diagnostic metrics.
|
27 |
+
"""
|
28 |
+
diagnostics = {}
|
29 |
+
details = {}
|
30 |
+
|
31 |
+
try:
|
32 |
+
details['r_squared'] = results.rsquared
|
33 |
+
details['adj_r_squared'] = results.rsquared_adj
|
34 |
+
details['f_statistic'] = results.fvalue
|
35 |
+
details['f_p_value'] = results.f_pvalue
|
36 |
+
details['n_observations'] = int(results.nobs)
|
37 |
+
details['degrees_of_freedom_resid'] = int(results.df_resid)
|
38 |
+
details['durbin_watson'] = durbin_watson(results.resid) if results.nobs > 5 else 'N/A (Too few obs)' # Autocorrelation
|
39 |
+
|
40 |
+
# --- Normality of Residuals (Jarque-Bera) ---
|
41 |
+
try:
|
42 |
+
if results.nobs >= 2:
|
43 |
+
jb_value, jb_p_value, skew, kurtosis = jarque_bera(results.resid)
|
44 |
+
details['residuals_normality_jb_stat'] = jb_value
|
45 |
+
details['residuals_normality_jb_p_value'] = jb_p_value
|
46 |
+
details['residuals_skewness'] = skew
|
47 |
+
details['residuals_kurtosis'] = kurtosis
|
48 |
+
details['residuals_normality_status'] = "Normal" if jb_p_value > 0.05 else "Non-Normal"
|
49 |
+
else:
|
50 |
+
details['residuals_normality_status'] = "N/A (Too few obs)"
|
51 |
+
except Exception as e:
|
52 |
+
logger.warning(f"Could not run Jarque-Bera test: {e}")
|
53 |
+
details['residuals_normality_status'] = "Test Failed"
|
54 |
+
|
55 |
+
# --- Homoscedasticity (Breusch-Pagan) ---
|
56 |
+
try:
|
57 |
+
if X.shape[0] > X.shape[1]: # Needs more observations than predictors
|
58 |
+
lm_stat, lm_p_value, f_stat, f_p_value = het_breuschpagan(results.resid, X)
|
59 |
+
details['homoscedasticity_bp_lm_stat'] = lm_stat
|
60 |
+
details['homoscedasticity_bp_lm_p_value'] = lm_p_value
|
61 |
+
details['homoscedasticity_status'] = "Homoscedastic" if lm_p_value > 0.05 else "Heteroscedastic"
|
62 |
+
else:
|
63 |
+
details['homoscedasticity_status'] = "N/A (Too few obs or too many predictors)"
|
64 |
+
except Exception as e:
|
65 |
+
logger.warning(f"Could not run Breusch-Pagan test: {e}")
|
66 |
+
details['homoscedasticity_status'] = "Test Failed"
|
67 |
+
|
68 |
+
# --- Multicollinearity (VIF - Placeholder/Basic) ---
|
69 |
+
# Full VIF requires calculating for each predictor vs others.
|
70 |
+
# Providing a basic status based on condition number as a proxy.
|
71 |
+
try:
|
72 |
+
cond_no = np.linalg.cond(results.model.exog)
|
73 |
+
details['model_condition_number'] = cond_no
|
74 |
+
if cond_no > 30:
|
75 |
+
details['multicollinearity_status'] = "High (Cond. No. > 30)"
|
76 |
+
elif cond_no > 10:
|
77 |
+
details['multicollinearity_status'] = "Moderate (Cond. No. > 10)"
|
78 |
+
else:
|
79 |
+
details['multicollinearity_status'] = "Low"
|
80 |
+
except Exception as e:
|
81 |
+
logger.warning(f"Could not calculate condition number: {e}")
|
82 |
+
details['multicollinearity_status'] = "Check Failed"
|
83 |
+
# details['VIF'] = "Not Fully Implemented"
|
84 |
+
|
85 |
+
# --- Linearity (Still requires visual inspection) ---
|
86 |
+
details['linearity_check'] = "Requires visual inspection (e.g., residual vs fitted plot)"
|
87 |
+
|
88 |
+
return {"status": "Success", "details": details}
|
89 |
+
|
90 |
+
except Exception as e:
|
91 |
+
logger.error(f"Error running Backdoor Adjustment diagnostics: {e}")
|
92 |
+
return {"status": "Failed", "error": str(e), "details": details}
|
auto_causal/methods/backdoor_adjustment/estimator.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Backdoor Adjustment Estimator using Regression.
|
3 |
+
|
4 |
+
Estimates the Average Treatment Effect (ATE) by regressing the outcome on the
|
5 |
+
treatment and a set of covariates assumed to satisfy the backdoor criterion.
|
6 |
+
"""
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
import statsmodels.api as sm
|
10 |
+
from typing import Dict, Any, List, Optional
|
11 |
+
import logging
|
12 |
+
from langchain.chat_models.base import BaseChatModel # For type hinting llm
|
13 |
+
|
14 |
+
# Import diagnostics and llm assist (placeholders for now)
|
15 |
+
from .diagnostics import run_backdoor_diagnostics
|
16 |
+
from .llm_assist import interpret_backdoor_results, identify_backdoor_set
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
def estimate_effect(
|
21 |
+
df: pd.DataFrame,
|
22 |
+
treatment: str,
|
23 |
+
outcome: str,
|
24 |
+
covariates: List[str], # Backdoor set - Required for this method
|
25 |
+
query: Optional[str] = None, # For potential LLM use
|
26 |
+
llm: Optional[BaseChatModel] = None, # For potential LLM use
|
27 |
+
**kwargs # To capture any other potential arguments
|
28 |
+
) -> Dict[str, Any]:
|
29 |
+
"""
|
30 |
+
Estimates the causal effect using Backdoor Adjustment (via OLS regression).
|
31 |
+
|
32 |
+
Assumes the provided `covariates` list satisfies the backdoor criterion.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
df: Input DataFrame.
|
36 |
+
treatment: Name of the treatment variable column.
|
37 |
+
outcome: Name of the outcome variable column.
|
38 |
+
covariates: List of covariate names forming the backdoor adjustment set.
|
39 |
+
query: Optional user query for context (e.g., for LLM).
|
40 |
+
llm: Optional Language Model instance.
|
41 |
+
**kwargs: Additional keyword arguments.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Dictionary containing estimation results:
|
45 |
+
- 'effect_estimate': The estimated coefficient for the treatment variable.
|
46 |
+
- 'p_value': The p-value associated with the treatment coefficient.
|
47 |
+
- 'confidence_interval': The 95% confidence interval for the effect.
|
48 |
+
- 'standard_error': The standard error of the treatment coefficient.
|
49 |
+
- 'formula': The regression formula used.
|
50 |
+
- 'model_summary': Summary object from statsmodels.
|
51 |
+
- 'diagnostics': Placeholder for diagnostic results.
|
52 |
+
- 'interpretation': LLM interpretation.
|
53 |
+
"""
|
54 |
+
if not covariates: # Check if the list is empty or None
|
55 |
+
raise ValueError("Backdoor Adjustment requires a non-empty list of covariates (adjustment set).")
|
56 |
+
|
57 |
+
required_cols = [treatment, outcome] + covariates
|
58 |
+
missing_cols = [col for col in required_cols if col not in df.columns]
|
59 |
+
if missing_cols:
|
60 |
+
raise ValueError(f"Missing required columns for Backdoor Adjustment: {missing_cols}")
|
61 |
+
|
62 |
+
# Prepare data for statsmodels (add constant, handle potential NaNs)
|
63 |
+
df_analysis = df[required_cols].dropna()
|
64 |
+
if df_analysis.empty:
|
65 |
+
raise ValueError("No data remaining after dropping NaNs for required columns.")
|
66 |
+
|
67 |
+
X = df_analysis[[treatment] + covariates]
|
68 |
+
X = sm.add_constant(X) # Add intercept
|
69 |
+
y = df_analysis[outcome]
|
70 |
+
|
71 |
+
# Build the formula string for reporting
|
72 |
+
formula = f"{outcome} ~ {treatment} + " + " + ".join(covariates) + " + const"
|
73 |
+
logger.info(f"Running Backdoor Adjustment regression: {formula}")
|
74 |
+
|
75 |
+
try:
|
76 |
+
model = sm.OLS(y, X)
|
77 |
+
results = model.fit()
|
78 |
+
|
79 |
+
effect_estimate = results.params[treatment]
|
80 |
+
p_value = results.pvalues[treatment]
|
81 |
+
conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
|
82 |
+
std_err = results.bse[treatment]
|
83 |
+
|
84 |
+
# Run diagnostics (Placeholders)
|
85 |
+
# Pass the full design matrix X for potential VIF checks etc.
|
86 |
+
diag_results = run_backdoor_diagnostics(results, X)
|
87 |
+
|
88 |
+
# Get interpretation
|
89 |
+
interpretation = interpret_backdoor_results(results, diag_results, treatment, covariates, llm=llm)
|
90 |
+
|
91 |
+
return {
|
92 |
+
'effect_estimate': effect_estimate,
|
93 |
+
'p_value': p_value,
|
94 |
+
'confidence_interval': conf_int,
|
95 |
+
'standard_error': std_err,
|
96 |
+
'formula': formula,
|
97 |
+
'model_summary': results.summary(),
|
98 |
+
'diagnostics': diag_results,
|
99 |
+
'interpretation': interpretation,
|
100 |
+
'method_used': 'Backdoor Adjustment (OLS)'
|
101 |
+
}
|
102 |
+
|
103 |
+
except Exception as e:
|
104 |
+
logger.error(f"Backdoor Adjustment failed: {e}")
|
105 |
+
raise
|
auto_causal/methods/backdoor_adjustment/llm_assist.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM assistance functions for Backdoor Adjustment analysis.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import List, Dict, Any, Optional
|
6 |
+
import logging
|
7 |
+
|
8 |
+
# Imported for type hinting
|
9 |
+
from langchain.chat_models.base import BaseChatModel
|
10 |
+
from statsmodels.regression.linear_model import RegressionResultsWrapper
|
11 |
+
|
12 |
+
# Import shared LLM helpers
|
13 |
+
from auto_causal.utils.llm_helpers import call_llm_with_json_output
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
def identify_backdoor_set(
|
18 |
+
df_cols: List[str],
|
19 |
+
treatment: str,
|
20 |
+
outcome: str,
|
21 |
+
query: Optional[str] = None,
|
22 |
+
existing_covariates: Optional[List[str]] = None, # Allow user to provide some
|
23 |
+
llm: Optional[BaseChatModel] = None
|
24 |
+
) -> List[str]:
|
25 |
+
"""
|
26 |
+
Use LLM to suggest a potential backdoor adjustment set (confounders).
|
27 |
+
|
28 |
+
Tries to identify variables that affect both treatment and outcome.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
df_cols: List of available column names in the dataset.
|
32 |
+
treatment: Treatment variable name.
|
33 |
+
outcome: Outcome variable name.
|
34 |
+
query: User's causal query text (provides context).
|
35 |
+
existing_covariates: Covariates already considered/provided by user.
|
36 |
+
llm: Optional LLM model instance.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
List of suggested variable names for the backdoor adjustment set.
|
40 |
+
"""
|
41 |
+
if llm is None:
|
42 |
+
logger.warning("No LLM provided for backdoor set identification.")
|
43 |
+
return existing_covariates or []
|
44 |
+
|
45 |
+
# Exclude treatment and outcome from potential confounders
|
46 |
+
potential_confounders = [c for c in df_cols if c not in [treatment, outcome]]
|
47 |
+
if not potential_confounders:
|
48 |
+
return existing_covariates or []
|
49 |
+
|
50 |
+
prompt = f"""
|
51 |
+
You are assisting with identifying a backdoor adjustment set for causal inference.
|
52 |
+
The goal is to find observed variables that confound the relationship between the treatment and outcome.
|
53 |
+
Assume the causal effect of '{treatment}' on '{outcome}' is of interest.
|
54 |
+
|
55 |
+
User query context (optional): {query}
|
56 |
+
Available variables in the dataset (excluding treatment and outcome): {potential_confounders}
|
57 |
+
Variables already specified as covariates by user (if any): {existing_covariates}
|
58 |
+
|
59 |
+
Based *only* on the variable names and the query context, identify which of the available variables are likely to be common causes (confounders) of both '{treatment}' and '{outcome}'.
|
60 |
+
These variables should be included in the backdoor adjustment set.
|
61 |
+
Consider variables that likely occurred *before* or *at the same time as* the treatment.
|
62 |
+
|
63 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
64 |
+
{{
|
65 |
+
"suggested_backdoor_set": ["confounder1", "confounder2", ...]
|
66 |
+
}}
|
67 |
+
Include variables from the user-provided list if they seem appropriate as confounders.
|
68 |
+
If no plausible confounders are identified among the available variables, return an empty list.
|
69 |
+
"""
|
70 |
+
|
71 |
+
response = call_llm_with_json_output(llm, prompt)
|
72 |
+
|
73 |
+
suggested_set = []
|
74 |
+
if response and "suggested_backdoor_set" in response and isinstance(response["suggested_backdoor_set"], list):
|
75 |
+
# Basic validation
|
76 |
+
valid_vars = [item for item in response["suggested_backdoor_set"] if isinstance(item, str)]
|
77 |
+
if len(valid_vars) != len(response["suggested_backdoor_set"]):
|
78 |
+
logger.warning("LLM returned non-string items in suggested_backdoor_set list.")
|
79 |
+
suggested_set = valid_vars
|
80 |
+
else:
|
81 |
+
logger.warning(f"Failed to get valid backdoor set recommendations from LLM. Response: {response}")
|
82 |
+
|
83 |
+
# Combine with existing covariates, removing duplicates
|
84 |
+
final_set = list(dict.fromkeys((existing_covariates or []) + suggested_set))
|
85 |
+
return final_set
|
86 |
+
|
87 |
+
def interpret_backdoor_results(
|
88 |
+
results: RegressionResultsWrapper,
|
89 |
+
diagnostics: Dict[str, Any],
|
90 |
+
treatment_var: str,
|
91 |
+
covariates: List[str],
|
92 |
+
llm: Optional[BaseChatModel] = None
|
93 |
+
) -> str:
|
94 |
+
"""
|
95 |
+
Use LLM to interpret Backdoor Adjustment results.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
results: Fitted statsmodels OLS results object.
|
99 |
+
diagnostics: Dictionary of diagnostic results.
|
100 |
+
treatment_var: Name of the treatment variable.
|
101 |
+
covariates: List of covariates used in the adjustment set.
|
102 |
+
llm: Optional LLM model instance.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
String containing natural language interpretation.
|
106 |
+
"""
|
107 |
+
default_interpretation = "LLM interpretation not available for Backdoor Adjustment."
|
108 |
+
if llm is None:
|
109 |
+
logger.info("LLM not provided for Backdoor Adjustment interpretation.")
|
110 |
+
return default_interpretation
|
111 |
+
|
112 |
+
try:
|
113 |
+
# --- Prepare summary for LLM ---
|
114 |
+
results_summary = {}
|
115 |
+
diag_details = diagnostics.get('details', {})
|
116 |
+
|
117 |
+
effect = results.params.get(treatment_var)
|
118 |
+
pval = results.pvalues.get(treatment_var)
|
119 |
+
|
120 |
+
results_summary['Treatment Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
|
121 |
+
results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
|
122 |
+
try:
|
123 |
+
conf_int = results.conf_int().loc[treatment_var]
|
124 |
+
results_summary['95% Confidence Interval'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
|
125 |
+
except KeyError:
|
126 |
+
results_summary['95% Confidence Interval'] = "Not Found"
|
127 |
+
except Exception as ci_e:
|
128 |
+
results_summary['95% Confidence Interval'] = f"Error ({ci_e})"
|
129 |
+
|
130 |
+
results_summary['Adjustment Set (Covariates Used)'] = covariates
|
131 |
+
results_summary['Model R-squared'] = f"{diagnostics.get('details', {}).get('r_squared', 'N/A'):.3f}" if isinstance(diagnostics.get('details', {}).get('r_squared'), (int, float)) else "N/A"
|
132 |
+
|
133 |
+
diag_summary = {}
|
134 |
+
if diagnostics.get("status") == "Success":
|
135 |
+
diag_summary['Residuals Normality Status'] = diag_details.get('residuals_normality_status', 'N/A')
|
136 |
+
diag_summary['Homoscedasticity Status'] = diag_details.get('homoscedasticity_status', 'N/A')
|
137 |
+
diag_summary['Multicollinearity Status'] = diag_details.get('multicollinearity_status', 'N/A')
|
138 |
+
else:
|
139 |
+
diag_summary['Status'] = diagnostics.get("status", "Unknown")
|
140 |
+
|
141 |
+
# --- Construct Prompt ---
|
142 |
+
prompt = f"""
|
143 |
+
You are assisting with interpreting Backdoor Adjustment (Regression) results.
|
144 |
+
The key assumption is that the specified adjustment set (covariates) blocks all confounding paths between the treatment ('{treatment_var}') and outcome.
|
145 |
+
|
146 |
+
Results Summary:
|
147 |
+
{results_summary}
|
148 |
+
|
149 |
+
Diagnostics Summary (OLS model checks):
|
150 |
+
{diag_summary}
|
151 |
+
|
152 |
+
Explain these results in 2-4 concise sentences. Focus on:
|
153 |
+
1. The estimated average treatment effect after adjusting for the specified covariates (magnitude, direction, statistical significance based on p-value < 0.05).
|
154 |
+
2. **Crucially, mention that this estimate relies heavily on the assumption that the included covariates ('{str(covariates)[:100]}...') are sufficient to control for confounding (i.e., satisfy the backdoor criterion).**
|
155 |
+
3. Briefly mention any major OLS diagnostic issues noted (e.g., non-normal residuals, heteroscedasticity, high multicollinearity).
|
156 |
+
|
157 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
158 |
+
{{
|
159 |
+
"interpretation": "<your concise interpretation text>"
|
160 |
+
}}
|
161 |
+
"""
|
162 |
+
|
163 |
+
# --- Call LLM ---
|
164 |
+
response = call_llm_with_json_output(llm, prompt)
|
165 |
+
|
166 |
+
# --- Process Response ---
|
167 |
+
if response and isinstance(response, dict) and \
|
168 |
+
"interpretation" in response and isinstance(response["interpretation"], str):
|
169 |
+
return response["interpretation"]
|
170 |
+
else:
|
171 |
+
logger.warning(f"Failed to get valid interpretation from LLM for Backdoor Adj. Response: {response}")
|
172 |
+
return default_interpretation
|
173 |
+
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"Error during LLM interpretation for Backdoor Adj: {e}")
|
176 |
+
return f"Error generating interpretation: {e}"
|
auto_causal/methods/causal_method.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Abstract base class for all causal inference methods.
|
3 |
+
|
4 |
+
This module defines the interface that all causal inference methods
|
5 |
+
must implement, ensuring consistent behavior across different methods.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from abc import ABC, abstractmethod
|
9 |
+
from typing import Dict, List, Any
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
|
13 |
+
class CausalMethod(ABC):
|
14 |
+
"""Base class for all causal inference methods.
|
15 |
+
|
16 |
+
This abstract class defines the required methods that all causal
|
17 |
+
inference implementations must provide. It ensures a consistent
|
18 |
+
interface across different methods like propensity score matching,
|
19 |
+
instrumental variables, etc.
|
20 |
+
|
21 |
+
Each implementation should handle the specifics of the causal
|
22 |
+
inference method while conforming to this interface.
|
23 |
+
"""
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def validate_assumptions(self, df: pd.DataFrame, treatment: str,
|
27 |
+
outcome: str, covariates: List[str]) -> Dict[str, Any]:
|
28 |
+
"""Validate method assumptions against the dataset.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
df: DataFrame containing the dataset
|
32 |
+
treatment: Name of the treatment variable column
|
33 |
+
outcome: Name of the outcome variable column
|
34 |
+
covariates: List of covariate column names
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
Dict containing validation results with keys:
|
38 |
+
- assumptions_valid (bool): Whether all assumptions are met
|
39 |
+
- failed_assumptions (List[str]): List of failed assumptions
|
40 |
+
- warnings (List[str]): List of warnings
|
41 |
+
- suggestions (List[str]): Suggestions for addressing issues
|
42 |
+
"""
|
43 |
+
pass
|
44 |
+
|
45 |
+
@abstractmethod
|
46 |
+
def estimate_effect(self, df: pd.DataFrame, treatment: str,
|
47 |
+
outcome: str, covariates: List[str]) -> Dict[str, Any]:
|
48 |
+
"""Estimate causal effect using this method.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
df: DataFrame containing the dataset
|
52 |
+
treatment: Name of the treatment variable column
|
53 |
+
outcome: Name of the outcome variable column
|
54 |
+
covariates: List of covariate column names
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Dict containing estimation results with keys:
|
58 |
+
- effect_estimate (float): Estimated causal effect
|
59 |
+
- confidence_interval (tuple): Confidence interval (lower, upper)
|
60 |
+
- p_value (float): P-value of the estimate
|
61 |
+
- additional_metrics (Dict): Any method-specific metrics
|
62 |
+
"""
|
63 |
+
pass
|
64 |
+
|
65 |
+
@abstractmethod
|
66 |
+
def generate_code(self, dataset_path: str, treatment: str,
|
67 |
+
outcome: str, covariates: List[str]) -> str:
|
68 |
+
"""Generate executable code for this causal method.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
dataset_path: Path to the dataset file
|
72 |
+
treatment: Name of the treatment variable column
|
73 |
+
outcome: Name of the outcome variable column
|
74 |
+
covariates: List of covariate column names
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
String containing executable Python code implementing this method
|
78 |
+
"""
|
79 |
+
pass
|
80 |
+
|
81 |
+
@abstractmethod
|
82 |
+
def explain(self) -> str:
|
83 |
+
"""Explain this causal method, its assumptions, and when to use it.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
String with detailed explanation of the method
|
87 |
+
"""
|
88 |
+
pass
|
auto_causal/methods/diff_in_means/__init__.py
ADDED
File without changes
|
auto_causal/methods/diff_in_means/diagnostics.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Basic descriptive statistics for Difference in Means.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, Any
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
import logging
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
def run_dim_diagnostics(df: pd.DataFrame, treatment: str, outcome: str) -> Dict[str, Any]:
|
13 |
+
"""
|
14 |
+
Calculates basic descriptive statistics for treatment and control groups.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
df: Input DataFrame (should already be filtered for NaNs in treatment/outcome).
|
18 |
+
treatment: Name of the binary treatment variable column.
|
19 |
+
outcome: Name of the outcome variable column.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
Dictionary containing group means, standard deviations, and counts.
|
23 |
+
"""
|
24 |
+
details = {}
|
25 |
+
try:
|
26 |
+
grouped = df.groupby(treatment)[outcome]
|
27 |
+
stats = grouped.agg(['mean', 'std', 'count'])
|
28 |
+
|
29 |
+
# Ensure both groups (0 and 1) are present if possible
|
30 |
+
control_stats = stats.loc[0].to_dict() if 0 in stats.index else {'mean': np.nan, 'std': np.nan, 'count': 0}
|
31 |
+
treated_stats = stats.loc[1].to_dict() if 1 in stats.index else {'mean': np.nan, 'std': np.nan, 'count': 0}
|
32 |
+
|
33 |
+
details['control_group_stats'] = control_stats
|
34 |
+
details['treated_group_stats'] = treated_stats
|
35 |
+
|
36 |
+
if control_stats['count'] == 0 or treated_stats['count'] == 0:
|
37 |
+
logger.warning("One or both treatment groups have zero observations.")
|
38 |
+
return {"status": "Warning - Empty Group(s)", "details": details}
|
39 |
+
|
40 |
+
# Simple check for variance difference (Levene's test could be added)
|
41 |
+
control_std = control_stats.get('std', 0)
|
42 |
+
treated_std = treated_stats.get('std', 0)
|
43 |
+
if control_std > 0 and treated_std > 0:
|
44 |
+
ratio = (control_std**2) / (treated_std**2)
|
45 |
+
details['variance_ratio_control_div_treated'] = ratio
|
46 |
+
if ratio > 4 or ratio < 0.25: # Rule of thumb
|
47 |
+
details['variance_homogeneity_status'] = "Potentially Unequal (ratio > 4 or < 0.25)"
|
48 |
+
else:
|
49 |
+
details['variance_homogeneity_status'] = "Likely Similar"
|
50 |
+
else:
|
51 |
+
details['variance_homogeneity_status'] = "Could not calculate (zero variance in a group)"
|
52 |
+
|
53 |
+
return {"status": "Success", "details": details}
|
54 |
+
|
55 |
+
except KeyError as ke:
|
56 |
+
logger.error(f"KeyError during diagnostics: {ke}. Treatment levels might not be 0/1.")
|
57 |
+
return {"status": "Failed", "error": f"Treatment levels might not be 0/1: {ke}", "details": details}
|
58 |
+
except Exception as e:
|
59 |
+
logger.error(f"Error running Difference in Means diagnostics: {e}")
|
60 |
+
return {"status": "Failed", "error": str(e), "details": details}
|
auto_causal/methods/diff_in_means/estimator.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Difference in Means / Simple Linear Regression Estimator.
|
3 |
+
|
4 |
+
Estimates the Average Treatment Effect (ATE) by comparing the mean outcome
|
5 |
+
between the treated and control groups. This is equivalent to a simple OLS
|
6 |
+
regression of the outcome on the treatment indicator.
|
7 |
+
|
8 |
+
Assumes no confounding (e.g., suitable for RCT data).
|
9 |
+
"""
|
10 |
+
import pandas as pd
|
11 |
+
import statsmodels.api as sm
|
12 |
+
import numpy as np
|
13 |
+
import warnings
|
14 |
+
from typing import Dict, Any, Optional
|
15 |
+
import logging
|
16 |
+
from langchain.chat_models.base import BaseChatModel # For type hinting llm
|
17 |
+
|
18 |
+
from .diagnostics import run_dim_diagnostics
|
19 |
+
from .llm_assist import interpret_dim_results
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
def estimate_effect(
|
24 |
+
df: pd.DataFrame,
|
25 |
+
treatment: str,
|
26 |
+
outcome: str,
|
27 |
+
query: Optional[str] = None, # For potential LLM use
|
28 |
+
llm: Optional[BaseChatModel] = None, # For potential LLM use
|
29 |
+
**kwargs # To capture any other potential arguments (e.g., covariates - which are ignored)
|
30 |
+
) -> Dict[str, Any]:
|
31 |
+
"""
|
32 |
+
Estimates the causal effect using Difference in Means (via OLS).
|
33 |
+
|
34 |
+
Ignores any provided covariates.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
df: Input DataFrame.
|
38 |
+
treatment: Name of the binary treatment variable column (should be 0 or 1).
|
39 |
+
outcome: Name of the outcome variable column.
|
40 |
+
query: Optional user query for context.
|
41 |
+
llm: Optional Language Model instance.
|
42 |
+
**kwargs: Additional keyword arguments (ignored).
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Dictionary containing estimation results:
|
46 |
+
- 'effect_estimate': The difference in means (treatment coefficient).
|
47 |
+
- 'p_value': The p-value associated with the difference.
|
48 |
+
- 'confidence_interval': The 95% confidence interval for the difference.
|
49 |
+
- 'standard_error': The standard error of the difference.
|
50 |
+
- 'formula': The regression formula used.
|
51 |
+
- 'model_summary': Summary object from statsmodels.
|
52 |
+
- 'diagnostics': Basic group statistics.
|
53 |
+
- 'interpretation': LLM interpretation.
|
54 |
+
"""
|
55 |
+
required_cols = [treatment, outcome]
|
56 |
+
missing_cols = [col for col in required_cols if col not in df.columns]
|
57 |
+
if missing_cols:
|
58 |
+
raise ValueError(f"Missing required columns: {missing_cols}")
|
59 |
+
|
60 |
+
# Validate treatment is binary (or close to it)
|
61 |
+
treat_vals = df[treatment].dropna().unique()
|
62 |
+
if not np.all(np.isin(treat_vals, [0, 1])):
|
63 |
+
warnings.warn(f"Treatment column '{treatment}' contains values other than 0 and 1: {treat_vals}. Proceeding, but results may be unreliable.", UserWarning)
|
64 |
+
# Optional: could raise ValueError here if strict binary is required
|
65 |
+
|
66 |
+
# Prepare data for statsmodels (add constant, handle potential NaNs)
|
67 |
+
df_analysis = df[required_cols].dropna()
|
68 |
+
if df_analysis.empty:
|
69 |
+
raise ValueError("No data remaining after dropping NaNs for required columns.")
|
70 |
+
|
71 |
+
X = df_analysis[[treatment]]
|
72 |
+
X = sm.add_constant(X) # Add intercept
|
73 |
+
y = df_analysis[outcome]
|
74 |
+
|
75 |
+
formula = f"{outcome} ~ {treatment} + const"
|
76 |
+
logger.info(f"Running Difference in Means regression: {formula}")
|
77 |
+
|
78 |
+
try:
|
79 |
+
model = sm.OLS(y, X)
|
80 |
+
results = model.fit()
|
81 |
+
|
82 |
+
effect_estimate = results.params[treatment]
|
83 |
+
p_value = results.pvalues[treatment]
|
84 |
+
conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
|
85 |
+
std_err = results.bse[treatment]
|
86 |
+
|
87 |
+
# Run basic diagnostics (group means, stds, counts)
|
88 |
+
diag_results = run_dim_diagnostics(df_analysis, treatment, outcome)
|
89 |
+
|
90 |
+
# Get interpretation
|
91 |
+
interpretation = interpret_dim_results(results, diag_results, treatment, llm=llm)
|
92 |
+
|
93 |
+
return {
|
94 |
+
'effect_estimate': effect_estimate,
|
95 |
+
'p_value': p_value,
|
96 |
+
'confidence_interval': conf_int,
|
97 |
+
'standard_error': std_err,
|
98 |
+
'formula': formula,
|
99 |
+
'model_summary': results.summary(),
|
100 |
+
'diagnostics': diag_results,
|
101 |
+
'interpretation': interpretation,
|
102 |
+
'method_used': 'Difference in Means (OLS)'
|
103 |
+
}
|
104 |
+
|
105 |
+
except Exception as e:
|
106 |
+
logger.error(f"Difference in Means failed: {e}")
|
107 |
+
raise
|
auto_causal/methods/diff_in_means/llm_assist.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM assistance functions for Difference in Means analysis.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, Any, Optional
|
6 |
+
import logging
|
7 |
+
|
8 |
+
# Imported for type hinting
|
9 |
+
from langchain.chat_models.base import BaseChatModel
|
10 |
+
from statsmodels.regression.linear_model import RegressionResultsWrapper
|
11 |
+
|
12 |
+
# Import shared LLM helpers
|
13 |
+
from auto_causal.utils.llm_helpers import call_llm_with_json_output
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
def interpret_dim_results(
|
18 |
+
results: RegressionResultsWrapper,
|
19 |
+
diagnostics: Dict[str, Any],
|
20 |
+
treatment_var: str,
|
21 |
+
llm: Optional[BaseChatModel] = None
|
22 |
+
) -> str:
|
23 |
+
"""
|
24 |
+
Use LLM to interpret Difference in Means results.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
results: Fitted statsmodels OLS results object (from outcome ~ treatment).
|
28 |
+
diagnostics: Dictionary of diagnostic results (group stats).
|
29 |
+
treatment_var: Name of the treatment variable.
|
30 |
+
llm: Optional LLM model instance.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
String containing natural language interpretation.
|
34 |
+
"""
|
35 |
+
default_interpretation = "LLM interpretation not available for Difference in Means."
|
36 |
+
if llm is None:
|
37 |
+
logger.info("LLM not provided for Difference in Means interpretation.")
|
38 |
+
return default_interpretation
|
39 |
+
|
40 |
+
try:
|
41 |
+
# --- Prepare summary for LLM ---
|
42 |
+
results_summary = {}
|
43 |
+
diag_details = diagnostics.get('details', {})
|
44 |
+
control_stats = diag_details.get('control_group_stats', {})
|
45 |
+
treated_stats = diag_details.get('treated_group_stats', {})
|
46 |
+
|
47 |
+
effect = results.params.get(treatment_var)
|
48 |
+
pval = results.pvalues.get(treatment_var)
|
49 |
+
|
50 |
+
results_summary['Effect Estimate (Difference in Means)'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
|
51 |
+
results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
|
52 |
+
try:
|
53 |
+
conf_int = results.conf_int().loc[treatment_var]
|
54 |
+
results_summary['95% Confidence Interval'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
|
55 |
+
except KeyError:
|
56 |
+
results_summary['95% Confidence Interval'] = "Not Found"
|
57 |
+
except Exception as ci_e:
|
58 |
+
results_summary['95% Confidence Interval'] = f"Error ({ci_e})"
|
59 |
+
|
60 |
+
results_summary['Control Group Mean Outcome'] = f"{control_stats.get('mean', 'N/A'):.3f}" if isinstance(control_stats.get('mean'), (int, float)) else str(control_stats.get('mean'))
|
61 |
+
results_summary['Treated Group Mean Outcome'] = f"{treated_stats.get('mean', 'N/A'):.3f}" if isinstance(treated_stats.get('mean'), (int, float)) else str(treated_stats.get('mean'))
|
62 |
+
results_summary['Control Group Size'] = control_stats.get('count', 'N/A')
|
63 |
+
results_summary['Treated Group Size'] = treated_stats.get('count', 'N/A')
|
64 |
+
|
65 |
+
# --- Construct Prompt ---
|
66 |
+
prompt = f"""
|
67 |
+
You are assisting with interpreting Difference in Means results, likely from an RCT.
|
68 |
+
|
69 |
+
Results Summary:
|
70 |
+
{results_summary}
|
71 |
+
|
72 |
+
Explain these results in 1-3 concise sentences. Focus on:
|
73 |
+
1. The estimated average treatment effect (magnitude, direction, statistical significance based on p-value < 0.05).
|
74 |
+
2. Compare the mean outcomes between the treated and control groups.
|
75 |
+
|
76 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
77 |
+
{{
|
78 |
+
"interpretation": "<your concise interpretation text>"
|
79 |
+
}}
|
80 |
+
"""
|
81 |
+
|
82 |
+
# --- Call LLM ---
|
83 |
+
response = call_llm_with_json_output(llm, prompt)
|
84 |
+
|
85 |
+
# --- Process Response ---
|
86 |
+
if response and isinstance(response, dict) and \
|
87 |
+
"interpretation" in response and isinstance(response["interpretation"], str):
|
88 |
+
return response["interpretation"]
|
89 |
+
else:
|
90 |
+
logger.warning(f"Failed to get valid interpretation from LLM for Difference in Means. Response: {response}")
|
91 |
+
return default_interpretation
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
logger.error(f"Error during LLM interpretation for Difference in Means: {e}")
|
95 |
+
return f"Error generating interpretation: {e}"
|
auto_causal/methods/difference_in_differences/diagnostics.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Diagnostic functions for Difference-in-Differences method."""
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from typing import Dict, Any, Optional, List
|
6 |
+
import logging
|
7 |
+
import statsmodels.formula.api as smf # Import statsmodels
|
8 |
+
from patsy import PatsyError # To catch formula errors
|
9 |
+
|
10 |
+
# Import helper function from estimator -> Change to utils
|
11 |
+
from .utils import create_post_indicator
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
def validate_parallel_trends(df: pd.DataFrame, time_var: str, outcome: str,
|
16 |
+
group_indicator_col: str, treatment_period_start: Any,
|
17 |
+
dataset_description: Optional[str] = None,
|
18 |
+
time_varying_covariates: Optional[List[str]] = None) -> Dict[str, Any]:
|
19 |
+
"""Validates the parallel trends assumption using pre-treatment data.
|
20 |
+
|
21 |
+
Regresses the outcome on group-specific time trends before the treatment period.
|
22 |
+
Tests if the interaction terms between group and pre-treatment time periods are jointly significant.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
df: DataFrame containing the data.
|
26 |
+
time_var: Name of the time variable column.
|
27 |
+
outcome: Name of the outcome variable column.
|
28 |
+
group_indicator_col: Name of the binary treatment group indicator column (0/1).
|
29 |
+
treatment_period_start: The time period value when treatment starts.
|
30 |
+
dataset_description: Optional dictionary for additional dataset description.
|
31 |
+
time_varying_covariates: Optional list of time-varying covariates to include.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
Dictionary with validation results.
|
35 |
+
"""
|
36 |
+
logger.info("Validating parallel trends...")
|
37 |
+
validation_result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
|
38 |
+
|
39 |
+
try:
|
40 |
+
# Filter pre-treatment data
|
41 |
+
pre_df = df[df[time_var] < treatment_period_start].copy()
|
42 |
+
|
43 |
+
if len(pre_df) < 20 or pre_df[group_indicator_col].nunique() < 2 or pre_df[time_var].nunique() < 2:
|
44 |
+
validation_result["details"] = "Insufficient pre-treatment data or variation to perform test."
|
45 |
+
logger.warning(validation_result["details"])
|
46 |
+
# Assume valid if cannot test? Or invalid? Let's default to True if we can't test
|
47 |
+
validation_result["valid"] = True
|
48 |
+
validation_result["details"] += " Defaulting to assuming parallel trends (unable to test)."
|
49 |
+
return validation_result
|
50 |
+
|
51 |
+
# Check if group indicator is binary
|
52 |
+
if pre_df[group_indicator_col].nunique() > 2:
|
53 |
+
validation_result["details"] = f"Group indicator '{group_indicator_col}' has more than 2 unique values. Using simple visual assessment."
|
54 |
+
logger.warning(validation_result["details"])
|
55 |
+
# Use visual assessment method instead (check if trends look roughly parallel)
|
56 |
+
validation_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
|
57 |
+
# Ensure p_value is set
|
58 |
+
if validation_result["p_value"] is None:
|
59 |
+
validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
|
60 |
+
return validation_result
|
61 |
+
|
62 |
+
# Use a robust approach first - test for pre-trend differences using a simpler model
|
63 |
+
try:
|
64 |
+
# Create a linear time trend
|
65 |
+
pre_df['time_trend'] = pre_df[time_var].astype(float)
|
66 |
+
|
67 |
+
# Create interaction between trend and group
|
68 |
+
pre_df['group_trend'] = pre_df['time_trend'] * pre_df[group_indicator_col].astype(float)
|
69 |
+
|
70 |
+
# Simple regression with linear trend interaction
|
71 |
+
simple_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}') + time_trend + group_trend"
|
72 |
+
simple_model = smf.ols(simple_formula, data=pre_df)
|
73 |
+
simple_results = simple_model.fit()
|
74 |
+
|
75 |
+
# Check if trend interaction coefficient is significant
|
76 |
+
group_trend_pvalue = simple_results.pvalues['group_trend']
|
77 |
+
|
78 |
+
# If p > 0.05, trends are not significantly different
|
79 |
+
validation_result["valid"] = group_trend_pvalue > 0.05
|
80 |
+
validation_result["p_value"] = group_trend_pvalue
|
81 |
+
validation_result["details"] = f"Simple linear trend test: p-value for group-trend interaction: {group_trend_pvalue:.4f}. Parallel trends: {validation_result['valid']}."
|
82 |
+
logger.info(validation_result["details"])
|
83 |
+
|
84 |
+
# If we've successfully validated with the simple approach, return
|
85 |
+
return validation_result
|
86 |
+
|
87 |
+
except Exception as e:
|
88 |
+
logger.warning(f"Simple trend test failed: {e}. Trying alternative approach.")
|
89 |
+
# Continue to more complex method if simple method fails
|
90 |
+
|
91 |
+
# Try more complex approach with period-specific interactions
|
92 |
+
try:
|
93 |
+
# Create period dummies to avoid issues with categorical variables
|
94 |
+
time_periods = sorted(pre_df[time_var].unique())
|
95 |
+
|
96 |
+
# Create dummy variables for time periods (except first)
|
97 |
+
for period in time_periods[1:]:
|
98 |
+
period_col = f'period_{period}'
|
99 |
+
pre_df[period_col] = (pre_df[time_var] == period).astype(int)
|
100 |
+
|
101 |
+
# Create interaction with group
|
102 |
+
pre_df[f'group_x_{period_col}'] = pre_df[period_col] * pre_df[group_indicator_col].astype(float)
|
103 |
+
|
104 |
+
# Construct formula with manual dummies
|
105 |
+
interaction_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}')"
|
106 |
+
|
107 |
+
# Add period dummies except first (reference)
|
108 |
+
for period in time_periods[1:]:
|
109 |
+
period_col = f'period_{period}'
|
110 |
+
interaction_formula += f" + {period_col}"
|
111 |
+
|
112 |
+
# Add interactions
|
113 |
+
interaction_terms = []
|
114 |
+
for period in time_periods[1:]:
|
115 |
+
interaction_col = f'group_x_period_{period}'
|
116 |
+
interaction_formula += f" + {interaction_col}"
|
117 |
+
interaction_terms.append(interaction_col)
|
118 |
+
|
119 |
+
# Add covariates if provided
|
120 |
+
if time_varying_covariates:
|
121 |
+
for cov in time_varying_covariates:
|
122 |
+
interaction_formula += f" + Q('{cov}')"
|
123 |
+
|
124 |
+
# Fit model
|
125 |
+
complex_model = smf.ols(interaction_formula, data=pre_df)
|
126 |
+
complex_results = complex_model.fit()
|
127 |
+
|
128 |
+
# Test joint significance of interaction terms
|
129 |
+
if interaction_terms:
|
130 |
+
from statsmodels.formula.api import ols
|
131 |
+
from statsmodels.stats.anova import anova_lm
|
132 |
+
|
133 |
+
# Create models with and without interactions
|
134 |
+
formula_with = interaction_formula
|
135 |
+
formula_without = interaction_formula
|
136 |
+
for term in interaction_terms:
|
137 |
+
formula_without = formula_without.replace(f" + {term}", "")
|
138 |
+
|
139 |
+
model_with = smf.ols(formula_with, data=pre_df).fit()
|
140 |
+
model_without = smf.ols(formula_without, data=pre_df).fit()
|
141 |
+
|
142 |
+
# Compare models
|
143 |
+
try:
|
144 |
+
from scipy import stats
|
145 |
+
df_model = len(interaction_terms)
|
146 |
+
df_residual = model_with.df_resid
|
147 |
+
f_value = ((model_without.ssr - model_with.ssr) / df_model) / (model_with.ssr / df_residual)
|
148 |
+
p_value = 1 - stats.f.cdf(f_value, df_model, df_residual)
|
149 |
+
|
150 |
+
validation_result["valid"] = p_value > 0.05
|
151 |
+
validation_result["p_value"] = p_value
|
152 |
+
validation_result["details"] = f"Manual F-test for pre-treatment interactions: F({df_model}, {df_residual})={f_value:.4f}, p={p_value:.4f}. Parallel trends: {validation_result['valid']}."
|
153 |
+
logger.info(validation_result["details"])
|
154 |
+
|
155 |
+
except Exception as e:
|
156 |
+
logger.warning(f"Manual F-test failed: {e}. Using individual coefficient significance.")
|
157 |
+
|
158 |
+
# If F-test fails, check individual coefficients
|
159 |
+
significant_interactions = 0
|
160 |
+
for term in interaction_terms:
|
161 |
+
if term in complex_results.pvalues and complex_results.pvalues[term] < 0.05:
|
162 |
+
significant_interactions += 1
|
163 |
+
|
164 |
+
validation_result["valid"] = significant_interactions == 0
|
165 |
+
# Set a dummy p-value based on proportion of significant interactions
|
166 |
+
if len(interaction_terms) > 0:
|
167 |
+
validation_result["p_value"] = 1.0 - (significant_interactions / len(interaction_terms))
|
168 |
+
else:
|
169 |
+
validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
|
170 |
+
validation_result["details"] = f"{significant_interactions} out of {len(interaction_terms)} pre-treatment interactions are significant at p<0.05. Parallel trends: {validation_result['valid']}."
|
171 |
+
logger.info(validation_result["details"])
|
172 |
+
else:
|
173 |
+
validation_result["valid"] = True
|
174 |
+
validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
|
175 |
+
validation_result["details"] = "No pre-treatment interaction terms could be tested. Defaulting to assuming parallel trends."
|
176 |
+
logger.warning(validation_result["details"])
|
177 |
+
|
178 |
+
except Exception as e:
|
179 |
+
logger.warning(f"Complex trend test failed: {e}. Falling back to visual assessment.")
|
180 |
+
tmp_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
|
181 |
+
# Copy over values from visual assessment ensuring p_value is set
|
182 |
+
validation_result.update(tmp_result)
|
183 |
+
# Ensure p_value is set
|
184 |
+
if validation_result["p_value"] is None:
|
185 |
+
validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
|
186 |
+
|
187 |
+
except Exception as e:
|
188 |
+
error_msg = f"Error during parallel trends validation: {e}"
|
189 |
+
logger.error(error_msg, exc_info=True)
|
190 |
+
validation_result["details"] = error_msg
|
191 |
+
validation_result["error"] = str(e)
|
192 |
+
# Default to assuming valid if test fails completely
|
193 |
+
validation_result["valid"] = True
|
194 |
+
validation_result["p_value"] = 1.0 # Default to 1.0 if test fails
|
195 |
+
validation_result["details"] += " Defaulting to assuming parallel trends (test failed)."
|
196 |
+
|
197 |
+
return validation_result
|
198 |
+
|
199 |
+
def assess_trends_visually(df: pd.DataFrame, time_var: str, outcome: str,
|
200 |
+
group_indicator_col: str) -> Dict[str, Any]:
|
201 |
+
"""Simple visual assessment of parallel trends by comparing group means over time.
|
202 |
+
|
203 |
+
This is a fallback method when statistical tests fail.
|
204 |
+
"""
|
205 |
+
result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
|
206 |
+
|
207 |
+
try:
|
208 |
+
# Group by time and treatment group, calculate means
|
209 |
+
grouped = df.groupby([time_var, group_indicator_col])[outcome].mean().reset_index()
|
210 |
+
|
211 |
+
# Pivot to get time series for each group
|
212 |
+
if df[group_indicator_col].nunique() <= 10: # Only if reasonable number of groups
|
213 |
+
pivot = grouped.pivot(index=time_var, columns=group_indicator_col, values=outcome)
|
214 |
+
|
215 |
+
# Calculate slopes between consecutive periods for each group
|
216 |
+
slopes = {}
|
217 |
+
time_values = sorted(df[time_var].unique())
|
218 |
+
|
219 |
+
if len(time_values) >= 3: # Need at least 3 periods to compare slopes
|
220 |
+
for group in pivot.columns:
|
221 |
+
group_slopes = []
|
222 |
+
for i in range(len(time_values) - 1):
|
223 |
+
t1, t2 = time_values[i], time_values[i+1]
|
224 |
+
if t1 in pivot.index and t2 in pivot.index:
|
225 |
+
slope = (pivot.loc[t2, group] - pivot.loc[t1, group]) / (t2 - t1)
|
226 |
+
group_slopes.append(slope)
|
227 |
+
if group_slopes:
|
228 |
+
slopes[group] = group_slopes
|
229 |
+
|
230 |
+
# Compare slopes between groups
|
231 |
+
if len(slopes) >= 2:
|
232 |
+
slope_diffs = []
|
233 |
+
groups = list(slopes.keys())
|
234 |
+
for i in range(len(slopes[groups[0]])):
|
235 |
+
if i < len(slopes[groups[1]]):
|
236 |
+
slope_diffs.append(abs(slopes[groups[0]][i] - slopes[groups[1]][i]))
|
237 |
+
|
238 |
+
# If average slope difference is small relative to outcome scale
|
239 |
+
outcome_scale = df[outcome].std()
|
240 |
+
avg_slope_diff = sum(slope_diffs) / len(slope_diffs) if slope_diffs else 0
|
241 |
+
relative_diff = avg_slope_diff / outcome_scale if outcome_scale > 0 else 0
|
242 |
+
|
243 |
+
result["valid"] = relative_diff < 0.2 # Threshold for "parallel enough"
|
244 |
+
# Set p-value based on relative difference
|
245 |
+
result["p_value"] = 1.0 - (relative_diff * 5) if relative_diff < 0.2 else 0.04
|
246 |
+
result["details"] = f"Visual assessment: relative slope difference = {relative_diff:.4f}. Parallel trends: {result['valid']}."
|
247 |
+
else:
|
248 |
+
result["valid"] = True
|
249 |
+
result["p_value"] = 1.0
|
250 |
+
result["details"] = "Visual assessment: insufficient group data for comparison. Defaulting to assuming parallel trends."
|
251 |
+
else:
|
252 |
+
result["valid"] = True
|
253 |
+
result["p_value"] = 1.0
|
254 |
+
result["details"] = "Visual assessment: insufficient time periods for comparison. Defaulting to assuming parallel trends."
|
255 |
+
else:
|
256 |
+
result["valid"] = True
|
257 |
+
result["p_value"] = 1.0
|
258 |
+
result["details"] = f"Visual assessment: too many groups ({df[group_indicator_col].nunique()}) for visual comparison. Defaulting to assuming parallel trends."
|
259 |
+
|
260 |
+
except Exception as e:
|
261 |
+
result["error"] = str(e)
|
262 |
+
result["valid"] = True
|
263 |
+
result["p_value"] = 1.0
|
264 |
+
result["details"] = f"Visual assessment failed: {e}. Defaulting to assuming parallel trends."
|
265 |
+
|
266 |
+
logger.info(result["details"])
|
267 |
+
return result
|
268 |
+
|
269 |
+
def run_placebo_test(df: pd.DataFrame, time_var: str, group_var: str, outcome: str,
|
270 |
+
treated_unit_indicator: str, covariates: List[str],
|
271 |
+
treatment_period_start: Any,
|
272 |
+
placebo_period_start: Any) -> Dict[str, Any]:
|
273 |
+
"""Runs a placebo test for DiD by assigning a fake earlier treatment period.
|
274 |
+
|
275 |
+
Re-runs the DiD estimation using the placebo period and checks if the effect is non-significant.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
df: Original DataFrame.
|
279 |
+
time_var: Name of the time variable column.
|
280 |
+
group_var: Name of the unit/group ID column (for clustering SE).
|
281 |
+
outcome: Name of the outcome variable column.
|
282 |
+
treated_unit_indicator: Name of the binary treatment group indicator column (0/1).
|
283 |
+
covariates: List of covariate names.
|
284 |
+
treatment_period_start: The actual treatment start period.
|
285 |
+
placebo_period_start: The fake treatment start period (must be before actual start).
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
Dictionary with placebo test results.
|
289 |
+
"""
|
290 |
+
logger.info(f"Running placebo test assigning treatment start at {placebo_period_start}...")
|
291 |
+
placebo_result = {"passed": False, "effect_estimate": None, "p_value": None, "details": "", "error": None}
|
292 |
+
|
293 |
+
if placebo_period_start >= treatment_period_start:
|
294 |
+
error_msg = "Placebo period must be before the actual treatment period."
|
295 |
+
logger.error(error_msg)
|
296 |
+
placebo_result["error"] = error_msg
|
297 |
+
placebo_result["details"] = error_msg
|
298 |
+
return placebo_result
|
299 |
+
|
300 |
+
try:
|
301 |
+
df_placebo = df.copy()
|
302 |
+
# Create placebo post and interaction terms
|
303 |
+
post_placebo_col = 'post_placebo'
|
304 |
+
interaction_placebo_col = 'did_interaction_placebo'
|
305 |
+
|
306 |
+
df_placebo[post_placebo_col] = create_post_indicator(df_placebo, time_var, placebo_period_start)
|
307 |
+
df_placebo[interaction_placebo_col] = df_placebo[treated_unit_indicator] * df_placebo[post_placebo_col]
|
308 |
+
|
309 |
+
# Construct formula for placebo regression
|
310 |
+
formula = f"`{outcome}` ~ `{treated_unit_indicator}` + `{post_placebo_col}` + `{interaction_placebo_col}`"
|
311 |
+
if covariates:
|
312 |
+
formula += f" + {' + '.join([f'`{c}`' for c in covariates])}"
|
313 |
+
formula += f" + C(`{group_var}`) + C(`{time_var}`)" # Include FEs
|
314 |
+
|
315 |
+
logger.debug(f"Placebo test formula: {formula}")
|
316 |
+
|
317 |
+
# Fit the placebo model with clustered SE
|
318 |
+
ols_model = smf.ols(formula=formula, data=df_placebo)
|
319 |
+
results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_placebo[group_var]})
|
320 |
+
|
321 |
+
# Check the significance of the placebo interaction term
|
322 |
+
placebo_effect = float(results.params[interaction_placebo_col])
|
323 |
+
placebo_p_value = float(results.pvalues[interaction_placebo_col])
|
324 |
+
|
325 |
+
# Test passes if the placebo effect is not statistically significant (e.g., p > 0.1)
|
326 |
+
passed_test = placebo_p_value > 0.10
|
327 |
+
|
328 |
+
placebo_result["passed"] = passed_test
|
329 |
+
placebo_result["effect_estimate"] = placebo_effect
|
330 |
+
placebo_result["p_value"] = placebo_p_value
|
331 |
+
placebo_result["details"] = f"Placebo treatment effect estimated at {placebo_effect:.4f} (p={placebo_p_value:.4f}). Test passed: {passed_test}."
|
332 |
+
logger.info(placebo_result["details"])
|
333 |
+
|
334 |
+
except (KeyError, PatsyError, ValueError, Exception) as e:
|
335 |
+
error_msg = f"Error during placebo test execution: {e}"
|
336 |
+
logger.error(error_msg, exc_info=True)
|
337 |
+
placebo_result["details"] = error_msg
|
338 |
+
placebo_result["error"] = str(e)
|
339 |
+
|
340 |
+
return placebo_result
|
341 |
+
|
342 |
+
# TODO: Add function for Event Study plot (plot_event_study)
|
343 |
+
# This would involve estimating effects for leads and lags around the treatment period.
|
344 |
+
|
345 |
+
# Add other diagnostic functions as needed (e.g., plot_event_study)
|
auto_causal/methods/difference_in_differences/estimator.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Difference-in-Differences Estimator using DoWhy with Statsmodels fallback.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
from typing import Dict, List, Optional, Any, Tuple
|
9 |
+
from auto_causal.config import get_llm_client # IMPORT LLM Client Factory
|
10 |
+
|
11 |
+
# DoWhy imports (Commented out for simplification)
|
12 |
+
# from dowhy import CausalModel
|
13 |
+
# from dowhy.causal_estimators import CausalEstimator
|
14 |
+
# from dowhy.causal_estimator import CausalEstimate
|
15 |
+
# Statsmodels import for estimation
|
16 |
+
import statsmodels.formula.api as smf
|
17 |
+
|
18 |
+
# Local imports
|
19 |
+
from .llm_assist import (
|
20 |
+
identify_time_variable,
|
21 |
+
determine_treatment_period,
|
22 |
+
identify_treatment_group,
|
23 |
+
interpret_did_results
|
24 |
+
)
|
25 |
+
from .diagnostics import validate_parallel_trends # Import diagnostics
|
26 |
+
# Import from the new utils module
|
27 |
+
from .utils import create_post_indicator
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
# --- Helper functions moved from old file ---
|
32 |
+
def format_did_results(statsmodels_results: Any, interaction_term_key: str,
|
33 |
+
validation_results: Dict[str, Any],
|
34 |
+
method_details: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
35 |
+
'''Formats the DiD results from statsmodels results into a standard dictionary.'''
|
36 |
+
|
37 |
+
try:
|
38 |
+
# Use the interaction_term_key passed directly
|
39 |
+
effect = float(statsmodels_results.params[interaction_term_key])
|
40 |
+
stderr = float(statsmodels_results.bse[interaction_term_key])
|
41 |
+
pval = float(statsmodels_results.pvalues[interaction_term_key])
|
42 |
+
ci = statsmodels_results.conf_int().loc[interaction_term_key].values.tolist()
|
43 |
+
ci_lower, ci_upper = float(ci[0]), float(ci[1])
|
44 |
+
logger.info(f"Extracted effect for '{interaction_term_key}'")
|
45 |
+
|
46 |
+
except KeyError:
|
47 |
+
logger.error(f"Interaction term '{interaction_term_key}' not found in statsmodels results. Available params: {statsmodels_results.params.index.tolist()}")
|
48 |
+
# Fallback to NaN if term not found
|
49 |
+
effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
|
50 |
+
except Exception as e:
|
51 |
+
logger.error(f"Error extracting results from statsmodels object: {e}")
|
52 |
+
effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
|
53 |
+
|
54 |
+
# Create a standardized results dictionary
|
55 |
+
results = {
|
56 |
+
"effect_estimate": effect,
|
57 |
+
"standard_error": stderr,
|
58 |
+
"p_value": pval,
|
59 |
+
"confidence_interval": [ci_lower, ci_upper],
|
60 |
+
"diagnostics": validation_results,
|
61 |
+
"parameters": parameters,
|
62 |
+
"details": str(statsmodels_results.summary())
|
63 |
+
}
|
64 |
+
|
65 |
+
return results
|
66 |
+
|
67 |
+
# Comment out unused DoWhy result formatter
|
68 |
+
# def format_dowhy_results(estimate: CausalEstimate,
|
69 |
+
# validation_results: Dict[str, Any],
|
70 |
+
# parameters: Dict[str, Any]) -> Dict[str, Any]:
|
71 |
+
# '''Formats the DiD results from DoWhy causal estimate into a standard dictionary.'''
|
72 |
+
|
73 |
+
# try:
|
74 |
+
# # Extract values from DoWhy estimate
|
75 |
+
# effect = float(estimate.value)
|
76 |
+
# stderr = float(estimate.get_standard_error()) if hasattr(estimate, 'get_standard_error') else np.nan
|
77 |
+
# ci_lower, ci_upper = estimate.get_confidence_intervals() if hasattr(estimate, 'get_confidence_intervals') else (np.nan, np.nan)
|
78 |
+
# # Extract p-value if available, otherwise use NaN
|
79 |
+
# pval = estimate.get_significance_test_results().get('p_value', np.nan) if hasattr(estimate, 'get_significance_test_results') else np.nan
|
80 |
+
|
81 |
+
# # Get available details from estimate
|
82 |
+
# details = str(estimate)
|
83 |
+
# if hasattr(estimate, 'summary'):
|
84 |
+
# details = str(estimate.summary())
|
85 |
+
|
86 |
+
# logger.info(f"Extracted effect from DoWhy estimate: {effect}")
|
87 |
+
|
88 |
+
# except Exception as e:
|
89 |
+
# logger.error(f"Error extracting results from DoWhy estimate: {e}")
|
90 |
+
# effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
|
91 |
+
# details = f"Error extracting DoWhy results: {e}"
|
92 |
+
|
93 |
+
# # Create a standardized results dictionary
|
94 |
+
# results = {
|
95 |
+
# "effect_estimate": effect,
|
96 |
+
# "effect_se": stderr,
|
97 |
+
# "p_value": pval,
|
98 |
+
# "confidence_interval": [ci_lower, ci_upper],
|
99 |
+
# "diagnostics": validation_results,
|
100 |
+
# "parameters": parameters,
|
101 |
+
# "details": details,
|
102 |
+
# "estimator": "dowhy"
|
103 |
+
# }
|
104 |
+
|
105 |
+
# return results
|
106 |
+
|
107 |
+
# --- Main `estimate_effect` function ---
|
108 |
+
|
109 |
+
def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
|
110 |
+
covariates: List[str],
|
111 |
+
dataset_description: Optional[str] = None,
|
112 |
+
query: Optional[str] = None,
|
113 |
+
**kwargs) -> Dict[str, Any]:
|
114 |
+
"""Difference-in-Differences estimation using DoWhy with Statsmodels fallback.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
df: Dataset containing causal variables
|
118 |
+
treatment: Name of treatment variable (or variable indicating treated group)
|
119 |
+
outcome: Name of outcome variable
|
120 |
+
covariates: List of covariate names
|
121 |
+
dataset_description: Optional dictionary describing the dataset
|
122 |
+
**kwargs: Method-specific parameters (e.g., time_var, group_var, query, llm instance if needed)
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
Dictionary with effect estimate and diagnostics
|
126 |
+
"""
|
127 |
+
query = kwargs.get('query_str')
|
128 |
+
# llm_instance = kwargs.get('llm') # Pass llm if helpers need it
|
129 |
+
df_processed = df.copy() # Work on a copy
|
130 |
+
|
131 |
+
logger.info("Starting DiD estimation using DoWhy with Statsmodels fallback...")
|
132 |
+
|
133 |
+
# --- Step 1: Identify Key Variables (using LLM Assist placeholders) ---
|
134 |
+
# Pass llm_instance to helpers if they are implemented to use it
|
135 |
+
llm_instance = get_llm_client() # Get llm instance if passed
|
136 |
+
time_var = kwargs.get('time_variable', identify_time_variable(df_processed, query, dataset_description, llm=llm_instance))
|
137 |
+
if time_var is None:
|
138 |
+
raise ValueError("Time variable could not be identified for DiD.")
|
139 |
+
if time_var not in df_processed.columns:
|
140 |
+
raise ValueError(f"Identified time variable '{time_var}' not found in DataFrame.")
|
141 |
+
|
142 |
+
# Determine the variable that identifies the panel unit (for grouping/FE)
|
143 |
+
group_var = kwargs.get('group_variable', identify_treatment_group(df_processed, treatment, query, dataset_description, llm=llm_instance))
|
144 |
+
if group_var is None:
|
145 |
+
raise ValueError("Group/Unit variable could not be identified for DiD.")
|
146 |
+
if group_var not in df_processed.columns:
|
147 |
+
raise ValueError(f"Identified group/unit variable '{group_var}' not found in DataFrame.")
|
148 |
+
|
149 |
+
# Check outcome exists before proceeding further
|
150 |
+
if outcome not in df_processed.columns:
|
151 |
+
raise ValueError(f"Outcome variable '{outcome}' not found in DataFrame.")
|
152 |
+
|
153 |
+
# Determine treatment period start
|
154 |
+
treatment_period = kwargs.get('treatment_period_start', kwargs.get('treatment_period',
|
155 |
+
determine_treatment_period(df_processed, time_var, treatment, query, dataset_description, llm=llm_instance)))
|
156 |
+
|
157 |
+
# --- Identify the TRUE binary treatment group indicator column ---
|
158 |
+
treated_group_col_for_formula = None
|
159 |
+
|
160 |
+
# Priority 1: Check if the 'treatment' argument itself is a valid binary indicator
|
161 |
+
if treatment in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed[treatment]):
|
162 |
+
unique_treat_vals = set(df_processed[treatment].dropna().unique())
|
163 |
+
if unique_treat_vals.issubset({0, 1}):
|
164 |
+
treated_group_col_for_formula = treatment
|
165 |
+
logger.info(f"Using the provided 'treatment' argument '{treatment}' as binary group indicator.")
|
166 |
+
|
167 |
+
# Priority 2: Check if a column explicitly named 'group' exists and is binary
|
168 |
+
if treated_group_col_for_formula is None and 'group' in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed['group']):
|
169 |
+
unique_group_vals = set(df_processed['group'].dropna().unique())
|
170 |
+
if unique_group_vals.issubset({0, 1}):
|
171 |
+
treated_group_col_for_formula = 'group'
|
172 |
+
logger.info(f"Using column 'group' as binary group indicator.")
|
173 |
+
|
174 |
+
# Priority 3: Fallback - Search other columns (excluding known roles and time-related ones)
|
175 |
+
if treated_group_col_for_formula is None:
|
176 |
+
logger.warning(f"Provided 'treatment' arg '{treatment}' is not binary 0/1 and no 'group' column found. Searching other columns...")
|
177 |
+
potential_group_cols = []
|
178 |
+
# Exclude outcome, time var, unit ID var, and common time indicators like 'post'
|
179 |
+
excluded_cols = [outcome, time_var, group_var, 'post', 'is_post_treatment', 'did_interaction']
|
180 |
+
for col_name in df_processed.columns:
|
181 |
+
if col_name in excluded_cols:
|
182 |
+
continue
|
183 |
+
try:
|
184 |
+
col_data = df_processed[col_name]
|
185 |
+
# Ensure we are working with a Series
|
186 |
+
if isinstance(col_data, pd.DataFrame):
|
187 |
+
if col_data.shape[1] == 1:
|
188 |
+
col_data = col_data.iloc[:, 0] # Extract the Series
|
189 |
+
else:
|
190 |
+
logger.warning(f"Skipping multi-column DataFrame slice for '{col_name}'.")
|
191 |
+
continue
|
192 |
+
|
193 |
+
# Check if the Series can be interpreted as binary 0/1
|
194 |
+
if not pd.api.types.is_numeric_dtype(col_data) and not pd.api.types.is_bool_dtype(col_data):
|
195 |
+
continue # Skip non-numeric/non-boolean columns
|
196 |
+
|
197 |
+
unique_vals = set(col_data.dropna().unique())
|
198 |
+
# Simplified check: directly test if unique values are a subset of {0, 1}
|
199 |
+
if unique_vals.issubset({0, 1}):
|
200 |
+
logger.info(f" Found potential binary indicator: {col_name}")
|
201 |
+
potential_group_cols.append(col_name)
|
202 |
+
|
203 |
+
except AttributeError as ae:
|
204 |
+
# Catch attribute errors likely due to unexpected types
|
205 |
+
logger.warning(f"Attribute error checking column '{col_name}': {ae}. Skipping.")
|
206 |
+
except Exception as e:
|
207 |
+
logger.warning(f"Unexpected error checking column '{col_name}' during group ID search: {e}")
|
208 |
+
|
209 |
+
if potential_group_cols:
|
210 |
+
treated_group_col_for_formula = potential_group_cols[0] # Take the first suitable one found
|
211 |
+
logger.info(f"Using column '{treated_group_col_for_formula}' found during search as binary group indicator.")
|
212 |
+
else:
|
213 |
+
# Final fallback: Use the originally identified group_var, but warn heavily
|
214 |
+
treated_group_col_for_formula = group_var
|
215 |
+
logger.error(f"CRITICAL WARNING: Could not find suitable binary treatment group indicator. Using '{group_var}', but this is likely incorrect and will produce invalid DiD estimates.")
|
216 |
+
|
217 |
+
# --- Final Check ---
|
218 |
+
if treated_group_col_for_formula not in df_processed.columns:
|
219 |
+
# This case should ideally not happen with the logic above but added defensively
|
220 |
+
raise ValueError(f"Determined treatment group column '{treated_group_col_for_formula}' not found in DataFrame.")
|
221 |
+
if df_processed[treated_group_col_for_formula].nunique(dropna=True) > 2:
|
222 |
+
logger.warning(f"Selected treatment group column '{treated_group_col_for_formula}' is not binary (has {df_processed[treated_group_col_for_formula].nunique()} unique values). DiD requires binary treatment group.")
|
223 |
+
|
224 |
+
# --- Step 2: Create Indicator Variables ---
|
225 |
+
post_indicator_col = 'post'
|
226 |
+
if post_indicator_col not in df_processed.columns:
|
227 |
+
# Create the post indicator if it doesn't exist
|
228 |
+
df_processed[post_indicator_col] = create_post_indicator(df_processed, time_var, treatment_period)
|
229 |
+
|
230 |
+
# Interaction term is treatment group * post
|
231 |
+
interaction_term_col = 'did_interaction' # Keep explicit interaction term
|
232 |
+
df_processed[interaction_term_col] = df_processed[treated_group_col_for_formula] * df_processed[post_indicator_col]
|
233 |
+
|
234 |
+
# --- Step 3: Validate Parallel Trends (using the group column) ---
|
235 |
+
parallel_trends_validation = validate_parallel_trends(df_processed, time_var, outcome,
|
236 |
+
treated_group_col_for_formula, treatment_period, dataset_description)
|
237 |
+
# Note: The validation result is currently just a placeholder
|
238 |
+
if not parallel_trends_validation.get('valid', False):
|
239 |
+
logger.warning("Parallel trends assumption potentially violated (based on placeholder check). Proceeding with estimation, but results may be biased.")
|
240 |
+
# Add this info to the final results diagnostics
|
241 |
+
|
242 |
+
# --- Step 4: Prepare for Statsmodels Estimation ---
|
243 |
+
# (DoWhy section commented out for simplicity)
|
244 |
+
# all_common_causes = covariates + [time_var, group_var] # group_var is unit ID
|
245 |
+
# use_dowhy_estimate = False
|
246 |
+
# dowhy_estimate = None
|
247 |
+
|
248 |
+
# try:
|
249 |
+
# # Create DoWhy CausalModel
|
250 |
+
# model = CausalModel(
|
251 |
+
# data=df_processed,
|
252 |
+
# treatment=treated_group_col_for_formula, # Use group indicator here
|
253 |
+
# outcome=outcome,
|
254 |
+
# common_causes=all_common_causes,
|
255 |
+
# )
|
256 |
+
# logger.info("DoWhy CausalModel created for DiD estimation.")
|
257 |
+
|
258 |
+
# # Identify estimand
|
259 |
+
# identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
|
260 |
+
# logger.info(f"DoWhy identified estimand: {identified_estimand.estimand_type}")
|
261 |
+
|
262 |
+
# # Try to estimate using DiD estimator if available in DoWhy
|
263 |
+
# try:
|
264 |
+
# logger.info("Attempting to use DoWhy's DiD estimator...")
|
265 |
+
|
266 |
+
# # Debug info - print DataFrame info to help diagnose possible issues
|
267 |
+
# logger.debug(f"DataFrame shape before DoWhy DiD: {df_processed.shape}")
|
268 |
+
# # ... (rest of DoWhy debug logs commented out) ...
|
269 |
+
|
270 |
+
# # Create params dictionary for DoWhy DiD estimator
|
271 |
+
# did_params = {
|
272 |
+
# 'time_var': time_var,
|
273 |
+
# 'treatment_period': treatment_period,
|
274 |
+
# 'unit_var': group_var
|
275 |
+
# }
|
276 |
+
|
277 |
+
# # Add control variables if available
|
278 |
+
# if covariates:
|
279 |
+
# did_params['control_vars'] = covariates
|
280 |
+
|
281 |
+
# logger.debug(f"DoWhy DiD params: {did_params}")
|
282 |
+
|
283 |
+
# # Try to use DiD estimator from DoWhy (requires recent version of DoWhy)
|
284 |
+
# if hasattr(model, 'estimate_effect'):
|
285 |
+
# try:
|
286 |
+
# # First check if difference_in_differences method is available
|
287 |
+
# available_methods = model.get_available_effect_estimators() if hasattr(model, 'get_available_effect_estimators') else []
|
288 |
+
# logger.debug(f"Available DoWhy estimators: {available_methods}")
|
289 |
+
|
290 |
+
# if "difference_in_differences" not in str(available_methods):
|
291 |
+
# logger.warning("'difference_in_differences' estimator not found in available DoWhy estimators. Falling back to statsmodels.")
|
292 |
+
# else:
|
293 |
+
# # Try the estimation with more error handling
|
294 |
+
# logger.info("Calling DoWhy DiD estimator...")
|
295 |
+
# estimate = model.estimate_effect(
|
296 |
+
# identified_estimand,
|
297 |
+
# method_name="difference_in_differences",
|
298 |
+
# method_params=did_params
|
299 |
+
# )
|
300 |
+
|
301 |
+
# if estimate:
|
302 |
+
# # Extra check to verify estimate has expected attributes
|
303 |
+
# if hasattr(estimate, 'value') and not pd.isna(estimate.value):
|
304 |
+
# dowhy_estimate = estimate
|
305 |
+
# use_dowhy_estimate = True
|
306 |
+
# logger.info(f"Successfully used DoWhy's DiD estimator. Effect estimate: {estimate.value}")
|
307 |
+
# else:
|
308 |
+
# logger.warning(f"DoWhy's DiD estimator returned invalid estimate: {estimate}. Falling back to statsmodels.")
|
309 |
+
# else:
|
310 |
+
# logger.warning("DoWhy's DiD estimator returned None. Falling back to statsmodels.")
|
311 |
+
# except IndexError as idx_err:
|
312 |
+
# # Handle specific IndexError that's occurring
|
313 |
+
# logger.error(f"IndexError in DoWhy DiD estimator: {idx_err}. Check input data structure.")
|
314 |
+
# # Trace more details about the error
|
315 |
+
# import traceback
|
316 |
+
# logger.error(f"Error traceback: {traceback.format_exc()}")
|
317 |
+
# logger.warning("Falling back to statsmodels due to IndexError in DoWhy.")
|
318 |
+
# else:
|
319 |
+
# logger.warning("DoWhy model does not have estimate_effect method. Falling back to statsmodels.")
|
320 |
+
|
321 |
+
# except (ImportError, AttributeError) as e:
|
322 |
+
# logger.warning(f"DoWhy DiD estimator not available or not implemented: {e}. Falling back to statsmodels.")
|
323 |
+
# except ValueError as ve:
|
324 |
+
# logger.error(f"ValueError in DoWhy DiD estimator: {ve}. Likely issue with data formatting. Falling back to statsmodels.")
|
325 |
+
# except Exception as e:
|
326 |
+
# logger.error(f"Error using DoWhy's DiD estimator: {e}. Falling back to statsmodels.")
|
327 |
+
# # Add traceback for better debugging
|
328 |
+
# import traceback
|
329 |
+
# logger.error(f"Full error traceback: {traceback.format_exc()}")
|
330 |
+
|
331 |
+
# except Exception as e:
|
332 |
+
# logger.error(f"Failed to create DoWhy CausalModel: {e}", exc_info=True)
|
333 |
+
# # model = None # Set model to None if creation fails
|
334 |
+
|
335 |
+
# Create parameters dictionary for formatting results
|
336 |
+
parameters = {
|
337 |
+
"time_var": time_var,
|
338 |
+
"group_var": group_var, # Unit ID
|
339 |
+
"treatment_indicator": treated_group_col_for_formula, # Group indicator used in formula basis
|
340 |
+
"post_indicator": post_indicator_col,
|
341 |
+
"treatment_period_start": treatment_period,
|
342 |
+
"covariates": covariates,
|
343 |
+
}
|
344 |
+
|
345 |
+
# Group diagnostics for formatting
|
346 |
+
did_diagnostics = {
|
347 |
+
"parallel_trends": parallel_trends_validation,
|
348 |
+
# "placebo_test": run_placebo_test(...)
|
349 |
+
}
|
350 |
+
|
351 |
+
# If DoWhy estimation was successful, use those results (Section Commented Out)
|
352 |
+
# if use_dowhy_estimate and dowhy_estimate:
|
353 |
+
# logger.info("Using DoWhy DiD estimation results.")
|
354 |
+
# parameters["estimation_method"] = "DoWhy Difference-in-Differences"
|
355 |
+
|
356 |
+
# # Format the results
|
357 |
+
# formatted_results = format_dowhy_results(dowhy_estimate, did_diagnostics, parameters)
|
358 |
+
# else:
|
359 |
+
|
360 |
+
# --- Step 5: Use Statsmodels OLS ---
|
361 |
+
logger.info("Determining Statsmodels OLS formula based on number of time periods...")
|
362 |
+
|
363 |
+
num_time_periods = df_processed[time_var].nunique()
|
364 |
+
|
365 |
+
interaction_term_key_for_results: str
|
366 |
+
method_details_str: str
|
367 |
+
formula: str
|
368 |
+
|
369 |
+
if num_time_periods == 2:
|
370 |
+
logger.info(
|
371 |
+
f"Number of unique time periods is 2. Using 2x2 DiD formula: "
|
372 |
+
f"{outcome} ~ {treated_group_col_for_formula} * {post_indicator_col}"
|
373 |
+
)
|
374 |
+
# For 2x2 DiD: outcome ~ group * post_indicator
|
375 |
+
# The interaction term A:B in statsmodels gives the DiD estimate.
|
376 |
+
formula_core = f"{treated_group_col_for_formula} * {post_indicator_col}"
|
377 |
+
interaction_term_key_for_results = f"{treated_group_col_for_formula}:{post_indicator_col}"
|
378 |
+
|
379 |
+
formula_parts = [formula_core]
|
380 |
+
main_model_terms = {outcome, treated_group_col_for_formula, post_indicator_col}
|
381 |
+
|
382 |
+
if covariates:
|
383 |
+
filtered_covs = [
|
384 |
+
c for c in covariates if c not in main_model_terms
|
385 |
+
]
|
386 |
+
if filtered_covs:
|
387 |
+
formula_parts.extend(filtered_covs)
|
388 |
+
|
389 |
+
formula = f"{outcome} ~ {' + '.join(formula_parts)}"
|
390 |
+
parameters["estimation_method"] = "Statsmodels OLS for 2x2 DiD (Group * Post interaction)"
|
391 |
+
method_details_str = "DiD via Statsmodels 2x2 (Group * Post interaction)"
|
392 |
+
|
393 |
+
else: # num_time_periods > 2
|
394 |
+
logger.info(
|
395 |
+
f"Number of unique time periods is {num_time_periods} (>2). "
|
396 |
+
f"Using TWFE DiD formula: {outcome} ~ {interaction_term_col} + C({group_var}) + C({time_var})"
|
397 |
+
)
|
398 |
+
# For TWFE: outcome ~ actual_treatment_variable + UnitFE + TimeFE
|
399 |
+
# actual_treatment_variable is interaction_term_col (e.g., treated_group * post_indicator)
|
400 |
+
# UnitFE is C(group_var), TimeFE is C(time_var)
|
401 |
+
formula_parts = [
|
402 |
+
interaction_term_col,
|
403 |
+
f"C({group_var})",
|
404 |
+
f"C({time_var})"
|
405 |
+
]
|
406 |
+
interaction_term_key_for_results = interaction_term_col
|
407 |
+
main_model_terms = {outcome, interaction_term_col, group_var, time_var}
|
408 |
+
|
409 |
+
if covariates:
|
410 |
+
filtered_covs = [
|
411 |
+
c for c in covariates if c not in main_model_terms
|
412 |
+
]
|
413 |
+
if filtered_covs:
|
414 |
+
formula_parts.extend(filtered_covs)
|
415 |
+
|
416 |
+
formula = f"{outcome} ~ {' + '.join(formula_parts)}"
|
417 |
+
parameters["estimation_method"] = "Statsmodels OLS with TWFE (C() Notation)"
|
418 |
+
method_details_str = "DiD via Statsmodels TWFE (C() Notation)"
|
419 |
+
|
420 |
+
try:
|
421 |
+
logger.info(f"Using formula: {formula}")
|
422 |
+
logger.debug(f"Data head for statsmodels:\n{df_processed.head().to_string()}")
|
423 |
+
logger.debug(f"Regression DataFrame shape: {df_processed.shape}, Columns: {df_processed.columns.tolist()}")
|
424 |
+
|
425 |
+
ols_model = smf.ols(formula=formula, data=df_processed)
|
426 |
+
if group_var not in df_processed.columns:
|
427 |
+
# This check is mainly for clustering but good to ensure group_var exists.
|
428 |
+
# For 2x2, group_var (unit ID) might not be in formula but needed for clustering.
|
429 |
+
raise ValueError(f"Clustering variable '{group_var}' (panel unit ID) not found in regression data.")
|
430 |
+
logger.debug(f"Clustering standard errors by: {group_var}")
|
431 |
+
results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_processed[group_var]})
|
432 |
+
|
433 |
+
logger.info("Statsmodels estimation complete.")
|
434 |
+
logger.info(f"Statsmodels Results Summary:\n{results.summary()}")
|
435 |
+
|
436 |
+
logger.debug(f"Extracting results using interaction term key: {interaction_term_key_for_results}")
|
437 |
+
|
438 |
+
parameters["final_formula"] = formula
|
439 |
+
parameters["interaction_term_coefficient_name"] = interaction_term_key_for_results
|
440 |
+
|
441 |
+
formatted_results = format_did_results(results, interaction_term_key_for_results,
|
442 |
+
did_diagnostics,
|
443 |
+
method_details=method_details_str,
|
444 |
+
parameters=parameters)
|
445 |
+
formatted_results["estimator"] = "statsmodels"
|
446 |
+
|
447 |
+
except Exception as e:
|
448 |
+
logger.error(f"Statsmodels OLS estimation failed: {e}", exc_info=True)
|
449 |
+
raise ValueError(f"DiD estimation failed (both DoWhy and Statsmodels): {e}")
|
450 |
+
|
451 |
+
|
452 |
+
|
453 |
+
|
454 |
+
# --- Add Interpretation --- (Now add interpretation to the formatted results)
|
455 |
+
try:
|
456 |
+
# Use the llm_instance fetched earlier
|
457 |
+
interpretation = interpret_did_results(formatted_results, did_diagnostics, dataset_description, llm=llm_instance)
|
458 |
+
formatted_results['interpretation'] = interpretation
|
459 |
+
except Exception as interp_e:
|
460 |
+
logger.error(f"DiD Interpretation failed: {interp_e}")
|
461 |
+
formatted_results['interpretation'] = "Interpretation failed."
|
462 |
+
|
463 |
+
return formatted_results
|
auto_causal/methods/difference_in_differences/llm_assist.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""LLM Assist functions for Difference-in-Differences method."""
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from typing import Optional, Any, Dict, Union
|
6 |
+
import logging
|
7 |
+
from pydantic import BaseModel, Field, ValidationError
|
8 |
+
from langchain_core.messages import HumanMessage
|
9 |
+
from langchain_core.exceptions import OutputParserException
|
10 |
+
|
11 |
+
# Import shared types if needed
|
12 |
+
from langchain_core.language_models import BaseChatModel
|
13 |
+
|
14 |
+
# Import shared LLM helpers
|
15 |
+
from auto_causal.utils.llm_helpers import call_llm_with_json_output
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
# Placeholder LLM/Helper Functions
|
20 |
+
|
21 |
+
# --- Pydantic model for LLM time variable extraction ---
|
22 |
+
class LLMTimeVar(BaseModel):
|
23 |
+
time_variable_name: Optional[str] = Field(None, description="The column name identified as the primary time variable.")
|
24 |
+
|
25 |
+
|
26 |
+
def identify_time_variable(df: pd.DataFrame,
|
27 |
+
query: Optional[str] = None,
|
28 |
+
dataset_description: Optional[str] = None,
|
29 |
+
llm: Optional[BaseChatModel] = None) -> Optional[str]:
|
30 |
+
'''Identifies the most likely time variable.
|
31 |
+
|
32 |
+
Current Implementation: Heuristic based on column names, with LLM fallback.
|
33 |
+
Future: Refine LLM prompt and parsing.
|
34 |
+
'''
|
35 |
+
# 1. Heuristic based on common time-related keywords
|
36 |
+
time_patterns = ['time', 'year', 'date', 'period', 'month', 'day']
|
37 |
+
columns = df.columns.tolist()
|
38 |
+
for col in columns:
|
39 |
+
if any(pattern in col.lower() for pattern in time_patterns):
|
40 |
+
logger.info(f"Identified '{col}' as time variable (heuristic).")
|
41 |
+
return col
|
42 |
+
|
43 |
+
# 2. LLM Fallback if heuristic fails and LLM is provided
|
44 |
+
if llm and query:
|
45 |
+
logger.warning("Heuristic failed for time variable. Trying LLM fallback...")
|
46 |
+
# --- Example: Add dataset description context ---
|
47 |
+
context_str = ""
|
48 |
+
if dataset_description:
|
49 |
+
# col_types = dataset_description.get('column_types', {}) # Description is now a string
|
50 |
+
context_str += f"\nDataset Description: {dataset_description}"
|
51 |
+
# Add other relevant info like sample values if available
|
52 |
+
# ------------------------------------------------
|
53 |
+
prompt = f"""Given the user query and the available data columns, identify the single most likely column representing the primary time dimension (e.g., year, date, period).
|
54 |
+
|
55 |
+
User Query: "{query}"
|
56 |
+
Available Columns: {columns}{context_str}
|
57 |
+
|
58 |
+
Respond ONLY with a JSON object containing the identified column name using the key 'time_variable_name'. If no suitable time variable is found, return null for the value.
|
59 |
+
Example: {{"time_variable_name": "Year"}} or {{"time_variable_name": null}}"""
|
60 |
+
|
61 |
+
messages = [HumanMessage(content=prompt)]
|
62 |
+
structured_llm = llm.with_structured_output(LLMTimeVar)
|
63 |
+
|
64 |
+
try:
|
65 |
+
parsed_result = structured_llm.invoke(messages)
|
66 |
+
llm_identified_col = parsed_result.time_variable_name
|
67 |
+
|
68 |
+
if llm_identified_col and llm_identified_col in columns:
|
69 |
+
logger.info(f"Identified '{llm_identified_col}' as time variable (LLM fallback).")
|
70 |
+
return llm_identified_col
|
71 |
+
elif llm_identified_col:
|
72 |
+
logger.warning(f"LLM fallback identified '{llm_identified_col}' but it's not in the columns. Ignoring.")
|
73 |
+
else:
|
74 |
+
logger.info("LLM fallback did not identify a time variable.")
|
75 |
+
|
76 |
+
except (OutputParserException, ValidationError) as e:
|
77 |
+
logger.error(f"LLM fallback for time variable failed parsing/validation: {e}")
|
78 |
+
except Exception as e:
|
79 |
+
logger.error(f"LLM fallback for time variable failed unexpectedly: {e}", exc_info=True)
|
80 |
+
|
81 |
+
logger.warning("Could not identify time variable using heuristics or LLM fallback.")
|
82 |
+
return None
|
83 |
+
|
84 |
+
# --- Pydantic model for LLM treatment period extraction ---
|
85 |
+
class LLMTreatmentPeriod(BaseModel):
|
86 |
+
treatment_start_period: Optional[Union[str, int, float]] = Field(None, description="The time period value (as string) when treatment is believed to start based on the query.")
|
87 |
+
|
88 |
+
def determine_treatment_period(df: pd.DataFrame, time_var: str, treatment: str,
|
89 |
+
query: Optional[str] = None,
|
90 |
+
dataset_description: Optional[str] = None,
|
91 |
+
llm: Optional[BaseChatModel] = None) -> Any:
|
92 |
+
'''Determines the period when treatment starts.
|
93 |
+
|
94 |
+
Tries LLM first if available, then falls back to heuristic.
|
95 |
+
'''
|
96 |
+
if time_var not in df.columns:
|
97 |
+
raise ValueError(f"Time variable '{time_var}' not found in DataFrame.")
|
98 |
+
|
99 |
+
unique_times_sorted = np.sort(df[time_var].dropna().unique())
|
100 |
+
if len(unique_times_sorted) < 2:
|
101 |
+
raise ValueError("Need at least two time periods for DiD")
|
102 |
+
|
103 |
+
# --- Try LLM First (if available) ---
|
104 |
+
llm_period = None
|
105 |
+
if llm and query:
|
106 |
+
logger.info("Attempting LLM call to determine treatment period start...")
|
107 |
+
# Provide sorted unique times for context
|
108 |
+
times_str = ", ".join(map(str, unique_times_sorted)) if len(unique_times_sorted) < 20 else f"{unique_times_sorted[0]}...{unique_times_sorted[-1]}"
|
109 |
+
# --- Example: Add dataset description context ---
|
110 |
+
context_str = ""
|
111 |
+
if dataset_description:
|
112 |
+
# Example: Show summary stats for time var if helpful
|
113 |
+
# time_stats = dataset_description.get('summary_stats', {}).get(time_var) # Cannot get from string
|
114 |
+
context_str += f"\nDataset Description: {dataset_description}"
|
115 |
+
# ------------------------------------------------
|
116 |
+
prompt = f"""Based on the user query and the observed time periods, determine the specific period value when the treatment ('{treatment}') likely started.
|
117 |
+
|
118 |
+
User Query: "{query}"
|
119 |
+
Time Variable Name: '{time_var}'
|
120 |
+
Observed Time Periods (sorted): [{times_str}]{context_str}
|
121 |
+
|
122 |
+
Respond ONLY with a JSON object containing the identified start period using the key 'treatment_start_period'. The value should be one of the observed periods if possible. If the query doesn't specify a start period, return null.
|
123 |
+
Example: {{"treatment_start_period": 2015}} or {{"treatment_start_period": null}}"""
|
124 |
+
|
125 |
+
messages = [HumanMessage(content=prompt)]
|
126 |
+
structured_llm = llm.with_structured_output(LLMTreatmentPeriod)
|
127 |
+
|
128 |
+
try:
|
129 |
+
parsed_result = structured_llm.invoke(messages)
|
130 |
+
potential_period = parsed_result.treatment_start_period
|
131 |
+
|
132 |
+
# Validate if the period exists in the data (might need type conversion)
|
133 |
+
if potential_period is not None:
|
134 |
+
# Try converting LLM output type to match data type if needed
|
135 |
+
try:
|
136 |
+
series_dtype = df[time_var].dtype
|
137 |
+
converted_period = pd.Series([potential_period]).astype(series_dtype).iloc[0]
|
138 |
+
except Exception:
|
139 |
+
converted_period = potential_period # Use raw if conversion fails
|
140 |
+
|
141 |
+
if converted_period in unique_times_sorted:
|
142 |
+
llm_period = converted_period
|
143 |
+
logger.info(f"LLM identified treatment period start: {llm_period}")
|
144 |
+
else:
|
145 |
+
logger.warning(f"LLM identified period '{potential_period}' (converted: '{converted_period}'), but it's not in the observed time periods. Ignoring LLM result.")
|
146 |
+
else:
|
147 |
+
logger.info("LLM did not identify a specific treatment start period from the query.")
|
148 |
+
|
149 |
+
except (OutputParserException, ValidationError) as e:
|
150 |
+
logger.error(f"LLM fallback for treatment period failed parsing/validation: {e}")
|
151 |
+
except Exception as e:
|
152 |
+
logger.error(f"LLM fallback for treatment period failed unexpectedly: {e}", exc_info=True)
|
153 |
+
|
154 |
+
if llm_period is not None:
|
155 |
+
return llm_period
|
156 |
+
|
157 |
+
# --- Fallback to Heuristic ---
|
158 |
+
logger.warning("Using heuristic (median time) to determine treatment period start.")
|
159 |
+
treatment_period_start = None
|
160 |
+
try:
|
161 |
+
if pd.api.types.is_numeric_dtype(df[time_var]):
|
162 |
+
median_time = np.median(unique_times_sorted)
|
163 |
+
possible_starts = unique_times_sorted[unique_times_sorted > median_time]
|
164 |
+
if len(possible_starts) > 0:
|
165 |
+
treatment_period_start = possible_starts[0]
|
166 |
+
else:
|
167 |
+
treatment_period_start = unique_times_sorted[-1]
|
168 |
+
logger.warning(f"Could not determine treatment start > median time. Defaulting to last period: {treatment_period_start}")
|
169 |
+
else: # Assume sortable categories or dates
|
170 |
+
median_idx = len(unique_times_sorted) // 2
|
171 |
+
if median_idx < len(unique_times_sorted):
|
172 |
+
treatment_period_start = unique_times_sorted[median_idx]
|
173 |
+
else:
|
174 |
+
treatment_period_start = unique_times_sorted[0]
|
175 |
+
|
176 |
+
if treatment_period_start is not None:
|
177 |
+
logger.info(f"Determined treatment period start: {treatment_period_start} (heuristic: median time).")
|
178 |
+
return treatment_period_start
|
179 |
+
else:
|
180 |
+
raise ValueError("Could not determine treatment start period using heuristic.")
|
181 |
+
|
182 |
+
except Exception as e:
|
183 |
+
logger.error(f"Error in heuristic for treatment period: {e}")
|
184 |
+
raise ValueError(f"Could not determine treatment start period using heuristic: {e}")
|
185 |
+
|
186 |
+
# --- Pydantic model for LLM group variable extraction ---
|
187 |
+
class LLMGroupVar(BaseModel):
|
188 |
+
group_variable_name: Optional[str] = Field(None, description="The column name identifying the panel unit (e.g., state, individual, firm).")
|
189 |
+
|
190 |
+
def identify_treatment_group(df: pd.DataFrame, treatment_var: str,
|
191 |
+
query: Optional[str] = None,
|
192 |
+
dataset_description: Optional[str] = None,
|
193 |
+
llm: Optional[BaseChatModel] = None) -> Optional[str]:
|
194 |
+
'''Identifies the variable indicating the treated group/unit ID.
|
195 |
+
|
196 |
+
Tries heuristic check for non-binary treatment_var first, then LLM,
|
197 |
+
then falls back to assuming treatment_var is the group/unit identifier.
|
198 |
+
'''
|
199 |
+
columns = df.columns.tolist()
|
200 |
+
if treatment_var not in columns:
|
201 |
+
logger.error(f"Treatment variable '{treatment_var}' provided to identify_treatment_group not found in DataFrame.")
|
202 |
+
# Fallback: Look for common ID names if specified treatment is missing
|
203 |
+
id_keywords = ['id', 'unit', 'group', 'entity', 'state', 'firm']
|
204 |
+
for col in columns:
|
205 |
+
if any(keyword in col.lower() for keyword in id_keywords):
|
206 |
+
logger.warning(f"Specified treatment '{treatment_var}' not found. Falling back to potential ID column '{col}' as group identifier.")
|
207 |
+
return col
|
208 |
+
return None # Give up if no likely ID column found
|
209 |
+
|
210 |
+
# --- Heuristic: Check if treatment_var is non-binary, if so, look for ID columns ---
|
211 |
+
is_potentially_binary = False
|
212 |
+
if pd.api.types.is_numeric_dtype(df[treatment_var]):
|
213 |
+
unique_vals = set(df[treatment_var].dropna().unique())
|
214 |
+
if unique_vals.issubset({0, 1}):
|
215 |
+
is_potentially_binary = True
|
216 |
+
|
217 |
+
if not is_potentially_binary:
|
218 |
+
logger.info(f"Provided treatment variable '{treatment_var}' is not binary (0/1). Searching for a separate group/unit ID column heuristically.")
|
219 |
+
id_keywords = ['id', 'unit', 'group', 'entity', 'state', 'firm']
|
220 |
+
# Prioritize 'group' or 'unit' if available
|
221 |
+
for keyword in ['group', 'unit']:
|
222 |
+
for col in columns:
|
223 |
+
if keyword == col.lower():
|
224 |
+
logger.info(f"Heuristically identified '{col}' as group/unit ID (treatment '{treatment_var}' was non-binary)." )
|
225 |
+
return col
|
226 |
+
# Then check other keywords
|
227 |
+
for col in columns:
|
228 |
+
if col != treatment_var and any(keyword in col.lower() for keyword in id_keywords):
|
229 |
+
logger.info(f"Heuristically identified '{col}' as group/unit ID (treatment '{treatment_var}' was non-binary)." )
|
230 |
+
return col
|
231 |
+
logger.warning("Heuristic search for group/unit ID failed when treatment was non-binary.")
|
232 |
+
|
233 |
+
# --- LLM Attempt (if heuristic didn't find an alternative or wasn't needed) ---
|
234 |
+
# Useful if query context helps disambiguate (e.g., "effect across states")
|
235 |
+
if llm and query:
|
236 |
+
logger.info("Attempting LLM call to identify group/unit variable...")
|
237 |
+
# --- Example: Add dataset description context ---
|
238 |
+
context_str = ""
|
239 |
+
if dataset_description:
|
240 |
+
# col_types = dataset_description.get('column_types', {}) # Description is now a string
|
241 |
+
context_str += f"\nDataset Description: {dataset_description}"
|
242 |
+
# ------------------------------------------------
|
243 |
+
prompt = f"""Given the user query and data columns, identify the single column that most likely represents the unique identifier for the panel units (e.g., state, individual, firm, unit ID), distinct from the treatment status indicator ('{treatment_var}').
|
244 |
+
|
245 |
+
User Query: "{query}"
|
246 |
+
Treatment Variable Mentioned: '{treatment_var}'
|
247 |
+
Available Columns: {columns}{context_str}
|
248 |
+
|
249 |
+
Respond ONLY with a JSON object containing the identified unit identifier column name using the key 'group_variable_name'. If the best identifier seems to be the treatment variable itself or none is suitable, return null.
|
250 |
+
Example: {{"group_variable_name": "state_id"}} or {{"group_variable_name": null}}"""
|
251 |
+
|
252 |
+
messages = [HumanMessage(content=prompt)]
|
253 |
+
structured_llm = llm.with_structured_output(LLMGroupVar)
|
254 |
+
|
255 |
+
try:
|
256 |
+
parsed_result = structured_llm.invoke(messages)
|
257 |
+
llm_identified_col = parsed_result.group_variable_name
|
258 |
+
|
259 |
+
if llm_identified_col and llm_identified_col in columns:
|
260 |
+
logger.info(f"Identified '{llm_identified_col}' as group/unit variable (LLM).")
|
261 |
+
return llm_identified_col
|
262 |
+
elif llm_identified_col:
|
263 |
+
logger.warning(f"LLM identified '{llm_identified_col}' but it's not in the columns. Ignoring.")
|
264 |
+
else:
|
265 |
+
logger.info("LLM did not identify a separate group/unit variable.")
|
266 |
+
|
267 |
+
except (OutputParserException, ValidationError) as e:
|
268 |
+
logger.error(f"LLM call for group/unit variable failed parsing/validation: {e}")
|
269 |
+
except Exception as e:
|
270 |
+
logger.error(f"LLM call for group/unit variable failed unexpectedly: {e}", exc_info=True)
|
271 |
+
|
272 |
+
# --- Final Fallback ---
|
273 |
+
logger.info(f"Defaulting to using provided treatment variable '{treatment_var}' as the group/unit identifier.")
|
274 |
+
return treatment_var
|
275 |
+
|
276 |
+
# --- Add interpret_did_results function ---
|
277 |
+
|
278 |
+
def interpret_did_results(
|
279 |
+
results: Dict[str, Any],
|
280 |
+
diagnostics: Optional[Dict[str, Any]],
|
281 |
+
dataset_description: Optional[str] = None,
|
282 |
+
llm: Optional[BaseChatModel] = None
|
283 |
+
) -> str:
|
284 |
+
"""Use LLM to interpret Difference-in-Differences results."""
|
285 |
+
default_interpretation = "LLM interpretation not available for DiD."
|
286 |
+
if llm is None:
|
287 |
+
logger.info("LLM not provided for DiD interpretation.")
|
288 |
+
return default_interpretation
|
289 |
+
|
290 |
+
try:
|
291 |
+
# --- Prepare summary for LLM ---
|
292 |
+
results_summary = {}
|
293 |
+
params = results.get('parameters', {})
|
294 |
+
diag_details = diagnostics.get('details', {}) if diagnostics else {}
|
295 |
+
parallel_trends = diag_details.get('parallel_trends', {})
|
296 |
+
|
297 |
+
effect = results.get('effect_estimate')
|
298 |
+
pval = results.get('p_value')
|
299 |
+
ci = results.get('confidence_interval')
|
300 |
+
|
301 |
+
results_summary['Method Used'] = results.get('method_details', 'Difference-in-Differences')
|
302 |
+
results_summary['Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
|
303 |
+
results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
|
304 |
+
if isinstance(ci, (list, tuple)) and len(ci) == 2:
|
305 |
+
results_summary['Confidence Interval'] = f"[{ci[0]:.3f}, {ci[1]:.3f}]"
|
306 |
+
else:
|
307 |
+
results_summary['Confidence Interval'] = str(ci) if ci is not None else "N/A"
|
308 |
+
|
309 |
+
results_summary['Time Variable'] = params.get('time_var', 'N/A')
|
310 |
+
results_summary['Group/Unit Variable'] = params.get('group_var', 'N/A')
|
311 |
+
results_summary['Treatment Indicator Used'] = params.get('treatment_indicator', 'N/A')
|
312 |
+
results_summary['Treatment Start Period'] = params.get('treatment_period_start', 'N/A')
|
313 |
+
results_summary['Covariates Included'] = params.get('covariates', [])
|
314 |
+
|
315 |
+
diag_summary = {}
|
316 |
+
diag_summary['Parallel Trends Assumption Status'] = "Passed (Placeholder)" if parallel_trends.get('valid', False) else "Failed/Unknown (Placeholder)"
|
317 |
+
if not parallel_trends.get('valid', False) and parallel_trends.get('details') != "Placeholder validation":
|
318 |
+
diag_summary['Parallel Trends Details'] = parallel_trends.get('details', 'N/A')
|
319 |
+
|
320 |
+
# --- Example: Add dataset description context ---
|
321 |
+
context_str = ""
|
322 |
+
if dataset_description:
|
323 |
+
# context_str += f"\nDataset Context: {dataset_description.get('summary', 'N/A')}" # Use string directly
|
324 |
+
context_str += f"\n\nDataset Context Provided:\n{dataset_description}"
|
325 |
+
# ------------------------------------------------
|
326 |
+
|
327 |
+
# --- Construct Prompt ---
|
328 |
+
prompt = f"""
|
329 |
+
You are assisting with interpreting Difference-in-Differences (DiD) results.
|
330 |
+
{context_str} # Add context here
|
331 |
+
|
332 |
+
Estimation Results Summary:
|
333 |
+
{results_summary}
|
334 |
+
|
335 |
+
Diagnostics Summary:
|
336 |
+
{diag_summary}
|
337 |
+
|
338 |
+
Explain these DiD results in 2-4 concise sentences. Focus on:
|
339 |
+
1. The estimated average treatment effect on the treated (magnitude, direction, statistical significance based on p-value < 0.05).
|
340 |
+
2. The status of the parallel trends assumption (mentioning it's a key assumption for DiD).
|
341 |
+
3. Note that the estimation controlled for unit and time fixed effects, and potentially covariates {results_summary['Covariates Included']}
|
342 |
+
|
343 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
344 |
+
{{
|
345 |
+
"interpretation": "<your concise interpretation text>"
|
346 |
+
}}
|
347 |
+
"""
|
348 |
+
|
349 |
+
# --- Call LLM ---
|
350 |
+
response = call_llm_with_json_output(llm, prompt)
|
351 |
+
|
352 |
+
# --- Process Response ---
|
353 |
+
if response and isinstance(response, dict) and \
|
354 |
+
"interpretation" in response and isinstance(response["interpretation"], str):
|
355 |
+
return response["interpretation"]
|
356 |
+
else:
|
357 |
+
logger.warning(f"Failed to get valid interpretation from LLM for DiD. Response: {response}")
|
358 |
+
return default_interpretation
|
359 |
+
|
360 |
+
except Exception as e:
|
361 |
+
logger.error(f"Error during LLM interpretation for DiD: {e}")
|
362 |
+
return f"Error generating interpretation: {e}"
|
auto_causal/methods/difference_in_differences/utils.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Utility functions for Difference-in-Differences
|
2 |
+
import pandas as pd
|
3 |
+
import logging
|
4 |
+
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
|
7 |
+
def create_post_indicator(df: pd.DataFrame, time_var: str, treatment_period_start: any) -> pd.Series:
|
8 |
+
"""Creates the post-treatment indicator variable.
|
9 |
+
Checks if time_var is already a 0/1 indicator; otherwise, compares to treatment_period_start.
|
10 |
+
"""
|
11 |
+
try:
|
12 |
+
time_var_series = df[time_var]
|
13 |
+
# Ensure numeric for checks and direct comparison
|
14 |
+
if pd.api.types.is_bool_dtype(time_var_series):
|
15 |
+
time_var_series = time_var_series.astype(int)
|
16 |
+
|
17 |
+
# Check if it's already a binary 0/1 indicator
|
18 |
+
if pd.api.types.is_numeric_dtype(time_var_series):
|
19 |
+
unique_vals = set(time_var_series.dropna().unique())
|
20 |
+
if unique_vals == {0, 1}:
|
21 |
+
logger.info(f"Time variable '{time_var}' is already a binary 0/1 indicator. Using it directly as post indicator.")
|
22 |
+
return time_var_series.astype(int)
|
23 |
+
else:
|
24 |
+
# Numeric, but not 0/1, so compare with treatment_period_start
|
25 |
+
logger.info(f"Time variable '{time_var}' is numeric. Comparing with treatment_period_start: {treatment_period_start}")
|
26 |
+
return (time_var_series >= treatment_period_start).astype(int)
|
27 |
+
else:
|
28 |
+
# Non-numeric and not boolean, will likely fall into TypeError for datetime conversion
|
29 |
+
# This else block might not be strictly necessary if TypeError is caught below
|
30 |
+
# but added for logical completeness before attempting datetime conversion.
|
31 |
+
pass # Let it fall through to TypeError if not numeric here
|
32 |
+
|
33 |
+
# If we reached here, it means it wasn't numeric or bool, try direct comparison which will likely raise TypeError
|
34 |
+
# and be caught by the except block for datetime conversion if applicable.
|
35 |
+
# This line is kept to ensure non-numeric non-datetime-like strings also trigger the except.
|
36 |
+
return (df[time_var] >= treatment_period_start).astype(int)
|
37 |
+
|
38 |
+
except TypeError:
|
39 |
+
# If direct comparison fails (e.g., comparing datetime with int/str, or non-numeric string with number),
|
40 |
+
# attempt to convert both to datetime objects for comparison.
|
41 |
+
logger.info(f"Direct comparison/numeric check failed for time_var '{time_var}'. Attempting datetime conversion.")
|
42 |
+
try:
|
43 |
+
time_series_dt = pd.to_datetime(df[time_var], errors='coerce')
|
44 |
+
# Try to convert treatment_period_start to datetime if it's not already
|
45 |
+
# This handles cases where treatment_period_start might be a date string
|
46 |
+
try:
|
47 |
+
treatment_start_dt = pd.to_datetime(treatment_period_start)
|
48 |
+
except Exception as e_conv:
|
49 |
+
logger.error(f"Could not convert treatment_period_start '{treatment_period_start}' to datetime: {e_conv}")
|
50 |
+
raise TypeError(f"treatment_period_start '{treatment_period_start}' could not be converted to a comparable datetime format.")
|
51 |
+
|
52 |
+
if time_series_dt.isna().all(): # if all values are NaT after conversion
|
53 |
+
raise ValueError(f"Time variable '{time_var}' could not be converted to datetime (all values NaT).")
|
54 |
+
if pd.isna(treatment_start_dt):
|
55 |
+
raise ValueError(f"Treatment start period '{treatment_period_start}' converted to NaT.")
|
56 |
+
|
57 |
+
logger.info(f"Comparing time_var '{time_var}' (as datetime) with treatment_start_dt '{treatment_start_dt}' (as datetime).")
|
58 |
+
return (time_series_dt >= treatment_start_dt).astype(int)
|
59 |
+
except Exception as e:
|
60 |
+
logger.error(f"Failed to compare time variable '{time_var}' with treatment start '{treatment_period_start}' using datetime logic: {e}", exc_info=True)
|
61 |
+
raise TypeError(f"Could not compare time variable '{time_var}' with treatment start '{treatment_period_start}'. Ensure they are comparable or convertible to datetime. Error: {e}")
|
62 |
+
except Exception as ex:
|
63 |
+
# Catch any other unexpected errors during the initial numeric processing
|
64 |
+
logger.error(f"Unexpected error processing time_var '{time_var}' for post indicator: {ex}", exc_info=True)
|
65 |
+
raise TypeError(f"Unexpected error processing time_var '{time_var}': {ex}")
|
auto_causal/methods/generalized_propensity_score/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Generalized Propensity Score (GPS) method for continuous treatments.
|
3 |
+
"""
|
auto_causal/methods/generalized_propensity_score/diagnostics.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Diagnostic checks for the Generalized Propensity Score (GPS) method.
|
3 |
+
"""
|
4 |
+
from typing import Dict, List, Any
|
5 |
+
import pandas as pd
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
import statsmodels.api as sm
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
def assess_gps_balance(
|
13 |
+
df_with_gps: pd.DataFrame,
|
14 |
+
treatment_var: str,
|
15 |
+
covariate_vars: List[str],
|
16 |
+
gps_col_name: str,
|
17 |
+
**kwargs: Any
|
18 |
+
) -> Dict[str, Any]:
|
19 |
+
"""
|
20 |
+
Assesses the balance of covariates conditional on the estimated GPS.
|
21 |
+
|
22 |
+
This function is typically called after GPS estimation to validate the
|
23 |
+
assumption that covariates are independent of treatment conditional on GPS.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
df_with_gps: DataFrame containing the original data plus the estimated GPS column.
|
27 |
+
treatment_var: The name of the continuous treatment variable column.
|
28 |
+
covariate_vars: A list of covariate column names to check for balance.
|
29 |
+
gps_col_name: The name of the column containing the estimated GPS values.
|
30 |
+
**kwargs: Additional arguments (e.g., number of strata for checking balance).
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
A dictionary containing balance statistics and summaries. For example:
|
34 |
+
{
|
35 |
+
"overall_balance_metric": 0.05,
|
36 |
+
"covariate_balance": {
|
37 |
+
"cov1": {"statistic": 0.03, "p_value": 0.5, "balanced": True},
|
38 |
+
"cov2": {"statistic": 0.12, "p_value": 0.02, "balanced": False}
|
39 |
+
},
|
40 |
+
"summary": "Balance assessment complete."
|
41 |
+
}
|
42 |
+
"""
|
43 |
+
logger.info(f"Assessing GPS balance for covariates: {covariate_vars}")
|
44 |
+
|
45 |
+
# Default to 5 strata (quintiles) if not specified
|
46 |
+
num_strata = kwargs.get('num_strata', 5)
|
47 |
+
if not isinstance(num_strata, int) or num_strata <= 1:
|
48 |
+
logger.warning(f"Invalid num_strata ({num_strata}), defaulting to 5.")
|
49 |
+
num_strata = 5
|
50 |
+
|
51 |
+
balance_results = {}
|
52 |
+
overall_summary = {
|
53 |
+
"num_strata_used": num_strata,
|
54 |
+
"covariates_tested": len(covariate_vars),
|
55 |
+
"warnings": [],
|
56 |
+
"all_strata_coefficients": {cov: [] for cov in covariate_vars},
|
57 |
+
"all_strata_p_values": {cov: [] for cov in covariate_vars}
|
58 |
+
}
|
59 |
+
|
60 |
+
if df_with_gps[gps_col_name].isnull().all():
|
61 |
+
logger.error(f"All GPS scores in column '{gps_col_name}' are NaN. Cannot perform balance assessment.")
|
62 |
+
overall_summary["error"] = "All GPS scores are NaN."
|
63 |
+
return {
|
64 |
+
"error": "All GPS scores are NaN.",
|
65 |
+
"summary": "Balance assessment failed."
|
66 |
+
}
|
67 |
+
|
68 |
+
try:
|
69 |
+
# Create GPS strata (e.g., quintiles)
|
70 |
+
# Ensure unique bin edges for qcut, duplicates='drop' will handle cases with sparse GPS values
|
71 |
+
# but might result in fewer than num_strata if GPS distribution is highly skewed or has few unique values.
|
72 |
+
try:
|
73 |
+
df_with_gps['gps_stratum'] = pd.qcut(df_with_gps[gps_col_name], num_strata, labels=False, duplicates='drop')
|
74 |
+
actual_num_strata = df_with_gps['gps_stratum'].nunique()
|
75 |
+
if actual_num_strata < num_strata and actual_num_strata > 0:
|
76 |
+
logger.warning(f"Requested {num_strata} strata, but due to GPS distribution, only {actual_num_strata} could be formed.")
|
77 |
+
overall_summary["warnings"].append(f"Only {actual_num_strata} strata formed out of {num_strata} requested.")
|
78 |
+
overall_summary["actual_num_strata_formed"] = actual_num_strata
|
79 |
+
except ValueError as ve:
|
80 |
+
logger.error(f"Could not create strata using pd.qcut due to: {ve}. This might happen if GPS has too few unique values.")
|
81 |
+
logger.info("Attempting to use unique GPS values as strata if count is low.")
|
82 |
+
unique_gps_count = df_with_gps[gps_col_name].nunique()
|
83 |
+
if unique_gps_count <= num_strata * 2 and unique_gps_count > 1: # Arbitrary threshold to try unique values as strata
|
84 |
+
strata_map = {val: i for i, val in enumerate(df_with_gps[gps_col_name].unique())}
|
85 |
+
df_with_gps['gps_stratum'] = df_with_gps[gps_col_name].map(strata_map)
|
86 |
+
actual_num_strata = df_with_gps['gps_stratum'].nunique()
|
87 |
+
overall_summary["actual_num_strata_formed"] = actual_num_strata
|
88 |
+
overall_summary["warnings"].append(f"Used {actual_num_strata} unique GPS values as strata due to qcut error.")
|
89 |
+
else:
|
90 |
+
overall_summary["error"] = f"Failed to create GPS strata: {ve}. GPS may have too few unique values."
|
91 |
+
return {
|
92 |
+
"error": overall_summary["error"],
|
93 |
+
"summary": "Balance assessment failed due to strata creation issues."
|
94 |
+
}
|
95 |
+
|
96 |
+
if df_with_gps['gps_stratum'].isnull().all():
|
97 |
+
logger.error("Stratum assignment resulted in all NaNs.")
|
98 |
+
overall_summary["error"] = "Stratum assignment resulted in all NaNs."
|
99 |
+
return {"error": overall_summary["error"], "summary": "Balance assessment failed."}
|
100 |
+
|
101 |
+
|
102 |
+
for cov in covariate_vars:
|
103 |
+
balance_results[cov] = {
|
104 |
+
"strata_details": [],
|
105 |
+
"mean_abs_coefficient": None,
|
106 |
+
"num_significant_strata_p005": 0,
|
107 |
+
"balanced_heuristic": True # Assume balanced until proven otherwise
|
108 |
+
}
|
109 |
+
coeffs_for_cov = []
|
110 |
+
p_values_for_cov = []
|
111 |
+
|
112 |
+
for stratum_idx in sorted(df_with_gps['gps_stratum'].dropna().unique()):
|
113 |
+
stratum_data = df_with_gps[df_with_gps['gps_stratum'] == stratum_idx]
|
114 |
+
stratum_detail = {"stratum_index": int(stratum_idx), "n_obs": len(stratum_data)}
|
115 |
+
|
116 |
+
if len(stratum_data) < 10: # Need a minimum number of observations for stable regression
|
117 |
+
stratum_detail["status"] = "Skipped (too few observations)"
|
118 |
+
stratum_detail["coefficient_on_treatment"] = np.nan
|
119 |
+
stratum_detail["p_value_on_treatment"] = np.nan
|
120 |
+
balance_results[cov]["strata_details"].append(stratum_detail)
|
121 |
+
continue
|
122 |
+
|
123 |
+
# Ensure covariate and treatment have variance within the stratum
|
124 |
+
if stratum_data[cov].nunique() < 2 or stratum_data[treatment_var].nunique() < 2:
|
125 |
+
stratum_detail["status"] = "Skipped (no variance in cov or treatment)"
|
126 |
+
stratum_detail["coefficient_on_treatment"] = np.nan
|
127 |
+
stratum_detail["p_value_on_treatment"] = np.nan
|
128 |
+
balance_results[cov]["strata_details"].append(stratum_detail)
|
129 |
+
continue
|
130 |
+
|
131 |
+
try:
|
132 |
+
X_balance = sm.add_constant(stratum_data[[treatment_var]])
|
133 |
+
y_balance = stratum_data[cov]
|
134 |
+
|
135 |
+
# Drop NaNs for this specific regression within stratum
|
136 |
+
temp_df = pd.concat([y_balance, X_balance], axis=1).dropna()
|
137 |
+
if len(temp_df) < X_balance.shape[1] +1: # Check for enough data points after NaNs for regression
|
138 |
+
stratum_detail["status"] = "Skipped (too few non-NaN obs for regression)"
|
139 |
+
stratum_detail["coefficient_on_treatment"] = np.nan
|
140 |
+
stratum_detail["p_value_on_treatment"] = np.nan
|
141 |
+
balance_results[cov]["strata_details"].append(stratum_detail)
|
142 |
+
continue
|
143 |
+
|
144 |
+
y_balance_fit = temp_df[cov]
|
145 |
+
X_balance_fit = temp_df[[col for col in temp_df.columns if col != cov]]
|
146 |
+
|
147 |
+
balance_model = sm.OLS(y_balance_fit, X_balance_fit).fit()
|
148 |
+
coeff = balance_model.params.get(treatment_var, np.nan)
|
149 |
+
p_value = balance_model.pvalues.get(treatment_var, np.nan)
|
150 |
+
|
151 |
+
coeffs_for_cov.append(coeff)
|
152 |
+
p_values_for_cov.append(p_value)
|
153 |
+
overall_summary["all_strata_coefficients"][cov].append(coeff)
|
154 |
+
overall_summary["all_strata_p_values"][cov].append(p_value)
|
155 |
+
|
156 |
+
stratum_detail["status"] = "Analyzed"
|
157 |
+
stratum_detail["coefficient_on_treatment"] = coeff
|
158 |
+
stratum_detail["p_value_on_treatment"] = p_value
|
159 |
+
if not pd.isna(p_value) and p_value < 0.05:
|
160 |
+
balance_results[cov]["num_significant_strata_p005"] += 1
|
161 |
+
balance_results[cov]["balanced_heuristic"] = False # If any stratum is unbalanced
|
162 |
+
|
163 |
+
except Exception as e_bal:
|
164 |
+
logger.debug(f"Balance check regression failed for {cov} in stratum {stratum_idx}: {e_bal}")
|
165 |
+
stratum_detail["status"] = f"Error: {str(e_bal)}"
|
166 |
+
stratum_detail["coefficient_on_treatment"] = np.nan
|
167 |
+
stratum_detail["p_value_on_treatment"] = np.nan
|
168 |
+
|
169 |
+
balance_results[cov]["strata_details"].append(stratum_detail)
|
170 |
+
|
171 |
+
if coeffs_for_cov:
|
172 |
+
balance_results[cov]["mean_abs_coefficient"] = np.nanmean(np.abs(coeffs_for_cov))
|
173 |
+
else:
|
174 |
+
balance_results[cov]["mean_abs_coefficient"] = np.nan # No strata were analyzable
|
175 |
+
|
176 |
+
overall_summary["num_covariates_potentially_imbalanced_p005"] = sum(
|
177 |
+
1 for cov_data in balance_results.values() if not cov_data["balanced_heuristic"]
|
178 |
+
)
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
logger.error(f"Error during GPS balance assessment: {e}", exc_info=True)
|
182 |
+
overall_summary["error"] = f"Overall assessment error: {str(e)}"
|
183 |
+
return {
|
184 |
+
"error": str(e),
|
185 |
+
"balance_results": balance_results,
|
186 |
+
"summary_stats": overall_summary,
|
187 |
+
"summary": "Balance assessment failed due to an unexpected error."
|
188 |
+
}
|
189 |
+
|
190 |
+
logger.info("GPS balance assessment complete.")
|
191 |
+
|
192 |
+
return {
|
193 |
+
"balance_results_per_covariate": balance_results,
|
194 |
+
"summary_stats": overall_summary,
|
195 |
+
"summary": "GPS balance assessment finished. Review strata details and mean absolute coefficients."
|
196 |
+
}
|
auto_causal/methods/generalized_propensity_score/estimator.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Core estimation logic for the Generalized Propensity Score (GPS) method.
|
3 |
+
"""
|
4 |
+
from typing import Dict, List, Any
|
5 |
+
import pandas as pd
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
import statsmodels.api as sm
|
9 |
+
|
10 |
+
from .diagnostics import assess_gps_balance # Import for balance check
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
def estimate_effect_gps(
|
15 |
+
df: pd.DataFrame,
|
16 |
+
treatment: str,
|
17 |
+
outcome: str,
|
18 |
+
covariates: List[str],
|
19 |
+
**kwargs: Any
|
20 |
+
) -> Dict[str, Any]:
|
21 |
+
"""
|
22 |
+
Estimates the causal effect using the Generalized Propensity Score method
|
23 |
+
for continuous treatments.
|
24 |
+
|
25 |
+
This function will be called by the method_executor_tool.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
df: The input DataFrame.
|
29 |
+
treatment: The name of the continuous treatment variable column.
|
30 |
+
outcome: The name of the outcome variable column.
|
31 |
+
covariates: A list of covariate column names.
|
32 |
+
**kwargs: Additional arguments for controlling the estimation, including:
|
33 |
+
- gps_model_spec (dict): Specification for the GPS model (T ~ X).
|
34 |
+
- outcome_model_spec (dict): Specification for the outcome model (Y ~ T, GPS).
|
35 |
+
- t_values_range (list or dict): Specification for treatment levels for ADRF.
|
36 |
+
- n_bootstraps (int): Number of bootstrap replications for SEs.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
A dictionary containing the estimation results, including:
|
40 |
+
- "effect_estimate": Typically the ADRF or a specific contrast.
|
41 |
+
- "standard_error": Standard error for the primary effect estimate.
|
42 |
+
- "confidence_interval": Confidence interval for the primary estimate.
|
43 |
+
- "adrf_curve": Data representing the Average Dose-Response Function.
|
44 |
+
- "specific_contrasts": Any calculated specific contrasts.
|
45 |
+
- "diagnostics": Results from diagnostic checks (e.g., balance).
|
46 |
+
- "method_details": Description of the method and models used.
|
47 |
+
- "parameters_used": Dictionary of parameters used.
|
48 |
+
"""
|
49 |
+
logger.info(f"Starting GPS estimation for treatment '{treatment}', outcome '{outcome}'.")
|
50 |
+
|
51 |
+
# --- Parameter Extraction and Defaults ---
|
52 |
+
gps_model_spec = kwargs.get('gps_model_spec', {"type": "linear"})
|
53 |
+
outcome_model_spec = kwargs.get('outcome_model_spec', {"type": "polynomial", "degree": 2, "interaction": True})
|
54 |
+
|
55 |
+
# Get t_values for ADRF from llm_assist or kwargs, default to 10 points over observed range
|
56 |
+
# For simplicity, we'll use a simple range here. In a full impl, this might call llm_assist.
|
57 |
+
t_values_for_adrf = kwargs.get('t_values_for_adrf')
|
58 |
+
if t_values_for_adrf is None:
|
59 |
+
min_t_obs = df[treatment].min()
|
60 |
+
max_t_obs = df[treatment].max()
|
61 |
+
if pd.isna(min_t_obs) or pd.isna(max_t_obs) or min_t_obs == max_t_obs:
|
62 |
+
logger.warning(f"Cannot determine a valid range for treatment '{treatment}' for ADRF. Using limited points.")
|
63 |
+
t_values_for_adrf = sorted(list(df[treatment].dropna().unique()))[:10] # Fallback
|
64 |
+
else:
|
65 |
+
t_values_for_adrf = np.linspace(min_t_obs, max_t_obs, 10).tolist()
|
66 |
+
|
67 |
+
n_bootstraps = kwargs.get('n_bootstraps', 0) # Default to 0, meaning no bootstrap for now
|
68 |
+
|
69 |
+
logger.info(f"Using GPS model spec: {gps_model_spec}")
|
70 |
+
logger.info(f"Using outcome model spec: {outcome_model_spec}")
|
71 |
+
logger.info(f"Evaluating ADRF at t-values: {t_values_for_adrf}")
|
72 |
+
|
73 |
+
try:
|
74 |
+
# 2. Estimate GPS Values
|
75 |
+
df_with_gps, gps_estimation_diagnostics = _estimate_gps_values(
|
76 |
+
df.copy(), treatment, covariates, gps_model_spec
|
77 |
+
)
|
78 |
+
if 'gps_score' not in df_with_gps.columns or df_with_gps['gps_score'].isnull().all():
|
79 |
+
logger.error("GPS estimation failed or resulted in all NaNs.")
|
80 |
+
return {
|
81 |
+
"error": "GPS estimation failed.",
|
82 |
+
"diagnostics": gps_estimation_diagnostics,
|
83 |
+
"method_details": "GPS (Failed)",
|
84 |
+
"parameters_used": kwargs
|
85 |
+
}
|
86 |
+
|
87 |
+
# Drop rows where GPS or outcome or necessary modeling variables are NaN before proceeding
|
88 |
+
modeling_cols = [outcome, treatment, 'gps_score'] + covariates
|
89 |
+
df_with_gps.dropna(subset=modeling_cols, inplace=True)
|
90 |
+
if df_with_gps.empty:
|
91 |
+
logger.error("DataFrame is empty after GPS estimation and NaN removal.")
|
92 |
+
return {"error": "No data available after GPS estimation and NaN removal.", "method_details": "GPS (Failed)", "parameters_used": kwargs}
|
93 |
+
|
94 |
+
|
95 |
+
# 3. Assess GPS Balance (call diagnostics.assess_gps_balance)
|
96 |
+
balance_diagnostics = assess_gps_balance(
|
97 |
+
df_with_gps, treatment, covariates, 'gps_score' # kwargs for assess_gps_balance can be passed if needed
|
98 |
+
)
|
99 |
+
|
100 |
+
# 4. Estimate Outcome Model
|
101 |
+
fitted_outcome_model = _estimate_outcome_model(
|
102 |
+
df_with_gps, outcome, treatment, 'gps_score', outcome_model_spec
|
103 |
+
)
|
104 |
+
|
105 |
+
# 5. Generate Dose-Response Function
|
106 |
+
adrf_results = _generate_dose_response_function(
|
107 |
+
df_with_gps, fitted_outcome_model, treatment, 'gps_score', outcome_model_spec, t_values_for_adrf
|
108 |
+
)
|
109 |
+
adrf_curve_data = {"t_levels": t_values_for_adrf, "expected_outcomes": adrf_results}
|
110 |
+
|
111 |
+
# 6. Calculate specific contrasts if requested (Placeholder)
|
112 |
+
specific_contrasts = {"info": "Specific contrasts not implemented in this version."}
|
113 |
+
|
114 |
+
# 7. Perform bootstrapping for SEs if requested (Placeholder for now)
|
115 |
+
standard_error_info = {"info": "Bootstrap SEs not implemented in this version."}
|
116 |
+
confidence_interval_info = {"info": "Bootstrap CIs not implemented in this version."}
|
117 |
+
if n_bootstraps > 0:
|
118 |
+
logger.info(f"Bootstrapping with {n_bootstraps} replications (placeholder).")
|
119 |
+
# Actual bootstrapping logic would go here.
|
120 |
+
# For now, we'll just note that it's not implemented.
|
121 |
+
|
122 |
+
logger.info("GPS estimation steps completed.")
|
123 |
+
|
124 |
+
# Consolidate diagnostics
|
125 |
+
all_diagnostics = {
|
126 |
+
"gps_estimation_diagnostics": gps_estimation_diagnostics,
|
127 |
+
"balance_check": balance_diagnostics, # Now using the actual balance check results
|
128 |
+
"outcome_model_summary": str(fitted_outcome_model.summary()) if fitted_outcome_model else "Outcome model not fitted.",
|
129 |
+
"warnings": [], # Populate with any warnings during the process
|
130 |
+
"summary": "GPS estimation complete."
|
131 |
+
}
|
132 |
+
|
133 |
+
return {
|
134 |
+
"effect_estimate": adrf_curve_data, # The ADRF is the primary "effect"
|
135 |
+
"standard_error_info": standard_error_info, # Placeholder
|
136 |
+
"confidence_interval_info": confidence_interval_info, # Placeholder
|
137 |
+
"adrf_curve": adrf_curve_data,
|
138 |
+
"specific_contrasts": specific_contrasts, # Placeholder
|
139 |
+
"diagnostics": all_diagnostics,
|
140 |
+
"method_details": f"Generalized Propensity Score (GPS) with {gps_model_spec.get('type', 'N/A')} GPS model and {outcome_model_spec.get('type', 'N/A')} outcome model.",
|
141 |
+
"parameters_used": {
|
142 |
+
"treatment_var": treatment,
|
143 |
+
"outcome_var": outcome,
|
144 |
+
"covariate_vars": covariates,
|
145 |
+
"gps_model_spec": gps_model_spec,
|
146 |
+
"outcome_model_spec": outcome_model_spec,
|
147 |
+
"t_values_for_adrf": t_values_for_adrf,
|
148 |
+
"n_bootstraps": n_bootstraps,
|
149 |
+
**kwargs
|
150 |
+
}
|
151 |
+
}
|
152 |
+
except Exception as e:
|
153 |
+
logger.error(f"Error during GPS estimation pipeline: {e}", exc_info=True)
|
154 |
+
return {
|
155 |
+
"error": f"Pipeline failed: {str(e)}",
|
156 |
+
"method_details": "GPS (Failed)",
|
157 |
+
"diagnostics": {"error": f"Pipeline failed during GPS estimation: {str(e)}"}, # Add diagnostics here too
|
158 |
+
"parameters_used": kwargs
|
159 |
+
}
|
160 |
+
|
161 |
+
|
162 |
+
# Placeholder for internal helper functions
|
163 |
+
def _estimate_gps_values(
|
164 |
+
df: pd.DataFrame,
|
165 |
+
treatment: str,
|
166 |
+
covariates: List[str],
|
167 |
+
gps_model_spec: Dict
|
168 |
+
) -> tuple[pd.DataFrame, Dict]:
|
169 |
+
"""
|
170 |
+
Estimates Generalized Propensity Scores.
|
171 |
+
Assumes T | X ~ N(X*beta, sigma^2), so GPS is the conditional density.
|
172 |
+
"""
|
173 |
+
logger.info(f"Estimating GPS for treatment '{treatment}' using covariates: {covariates}")
|
174 |
+
diagnostics = {}
|
175 |
+
|
176 |
+
if not covariates:
|
177 |
+
logger.error("No covariates provided for GPS estimation.")
|
178 |
+
diagnostics["error"] = "No covariates provided."
|
179 |
+
df['gps_score'] = np.nan # Ensure gps_score column is added
|
180 |
+
return df, diagnostics
|
181 |
+
|
182 |
+
X_df = df[covariates]
|
183 |
+
T_series = df[treatment]
|
184 |
+
|
185 |
+
# Handle potential NaN values in covariates or treatment before modeling
|
186 |
+
valid_indices = X_df.dropna().index.intersection(T_series.dropna().index)
|
187 |
+
if len(valid_indices) < len(df):
|
188 |
+
logger.warning(f"Dropped {len(df) - len(valid_indices)} rows due to NaNs in treatment/covariates before GPS estimation.")
|
189 |
+
diagnostics["pre_estimation_nan_rows_dropped"] = len(df) - len(valid_indices)
|
190 |
+
|
191 |
+
X = X_df.loc[valid_indices]
|
192 |
+
T = T_series.loc[valid_indices]
|
193 |
+
|
194 |
+
if X.empty or T.empty:
|
195 |
+
logger.error("Covariate or treatment data is empty after NaN handling.")
|
196 |
+
diagnostics["error"] = "Covariate or treatment data is empty after NaN handling."
|
197 |
+
return df, diagnostics
|
198 |
+
|
199 |
+
X_sm = sm.add_constant(X, has_constant='add')
|
200 |
+
|
201 |
+
try:
|
202 |
+
if gps_model_spec.get("type") == 'linear':
|
203 |
+
model = sm.OLS(T, X_sm).fit()
|
204 |
+
t_hat = model.predict(X_sm)
|
205 |
+
residuals = T - t_hat
|
206 |
+
# MSE: sum of squared residuals / (n - k) where k is number of regressors (including const)
|
207 |
+
if len(T) <= X_sm.shape[1]:
|
208 |
+
logger.error("Not enough degrees of freedom to estimate sigma_sq_hat.")
|
209 |
+
diagnostics["error"] = "Not enough degrees of freedom for GPS variance."
|
210 |
+
df['gps_score'] = np.nan
|
211 |
+
return df, diagnostics
|
212 |
+
|
213 |
+
sigma_sq_hat = np.sum(residuals**2) / (len(T) - X_sm.shape[1])
|
214 |
+
|
215 |
+
if sigma_sq_hat <= 1e-9: # Check for effectively zero or very small variance
|
216 |
+
logger.warning(f"Estimated residual variance (sigma_sq_hat) is very close to zero ({sigma_sq_hat}). GPS will be set to NaN.")
|
217 |
+
diagnostics["warning_sigma_sq_hat_near_zero"] = sigma_sq_hat
|
218 |
+
df['gps_score'] = np.nan # Set GPS to NaN as density is ill-defined
|
219 |
+
if sigma_sq_hat == 0: # if it is exactly zero, add specific error
|
220 |
+
diagnostics["error_sigma_sq_hat_is_zero"] = "Residual variance is exactly zero."
|
221 |
+
return df, diagnostics
|
222 |
+
|
223 |
+
|
224 |
+
# Calculate GPS: (1 / sqrt(2*pi*sigma_hat^2)) * exp(-(T_i - T_hat_i)^2 / (2*sigma_hat^2))
|
225 |
+
# Ensure calculation is done on the original T values (T_series.loc[valid_indices])
|
226 |
+
# and corresponding t_hat for those valid_indices
|
227 |
+
gps_values_calculated = (1 / np.sqrt(2 * np.pi * sigma_sq_hat)) * np.exp(-((T - t_hat)**2) / (2 * sigma_sq_hat))
|
228 |
+
|
229 |
+
# Assign back to the original DataFrame using .loc to ensure alignment
|
230 |
+
df['gps_score'] = np.nan # Initialize column
|
231 |
+
df.loc[valid_indices, 'gps_score'] = gps_values_calculated
|
232 |
+
|
233 |
+
diagnostics["gps_model_type"] = "linear_ols"
|
234 |
+
diagnostics["gps_model_rsquared"] = model.rsquared
|
235 |
+
diagnostics["gps_residual_variance_mse"] = sigma_sq_hat
|
236 |
+
diagnostics["num_observations_for_gps_model"] = len(T)
|
237 |
+
|
238 |
+
else:
|
239 |
+
logger.error(f"GPS model type '{gps_model_spec.get('type')}' not implemented.")
|
240 |
+
diagnostics["error"] = f"GPS model type '{gps_model_spec.get('type')}' not implemented."
|
241 |
+
df['gps_score'] = np.nan
|
242 |
+
|
243 |
+
except Exception as e:
|
244 |
+
logger.error(f"Error during GPS model estimation: {e}", exc_info=True)
|
245 |
+
diagnostics["error"] = f"Exception during GPS estimation: {str(e)}"
|
246 |
+
df['gps_score'] = np.nan
|
247 |
+
|
248 |
+
# Ensure the original df is not modified if no valid indices for GPS estimation
|
249 |
+
if 'gps_score' not in df.columns:
|
250 |
+
df['gps_score'] = np.nan
|
251 |
+
|
252 |
+
return df, diagnostics
|
253 |
+
|
254 |
+
def _estimate_outcome_model(
|
255 |
+
df_with_gps: pd.DataFrame,
|
256 |
+
outcome: str,
|
257 |
+
treatment: str,
|
258 |
+
gps_col_name: str,
|
259 |
+
outcome_model_spec: Dict
|
260 |
+
) -> Any: # Returns a fitted statsmodels model
|
261 |
+
"""
|
262 |
+
Estimates the outcome model Y ~ f(T, GPS).
|
263 |
+
"""
|
264 |
+
logger.info(f"Estimating outcome model for '{outcome}' using T='{treatment}', GPS='{gps_col_name}'")
|
265 |
+
|
266 |
+
Y = df_with_gps[outcome]
|
267 |
+
T_val = pd.Series(df_with_gps[treatment].values, index=df_with_gps.index)
|
268 |
+
GPS_val = pd.Series(df_with_gps[gps_col_name].values, index=df_with_gps.index)
|
269 |
+
|
270 |
+
X_outcome_dict = {'intercept': np.ones(len(df_with_gps))}
|
271 |
+
|
272 |
+
model_type = outcome_model_spec.get("type", "polynomial")
|
273 |
+
degree = outcome_model_spec.get("degree", 2)
|
274 |
+
interaction = outcome_model_spec.get("interaction", True)
|
275 |
+
|
276 |
+
if model_type == "polynomial":
|
277 |
+
X_outcome_dict['T'] = T_val
|
278 |
+
X_outcome_dict['GPS'] = GPS_val
|
279 |
+
if degree >= 2:
|
280 |
+
X_outcome_dict['T_sq'] = T_val**2
|
281 |
+
X_outcome_dict['GPS_sq'] = GPS_val**2
|
282 |
+
if degree >=3: # Example for higher order, can be made more general
|
283 |
+
X_outcome_dict['T_cub'] = T_val**3
|
284 |
+
X_outcome_dict['GPS_cub'] = GPS_val**3
|
285 |
+
if interaction:
|
286 |
+
X_outcome_dict['T_x_GPS'] = T_val * GPS_val
|
287 |
+
if degree >=2: # Interaction with squared terms if degree allows
|
288 |
+
X_outcome_dict['T_sq_x_GPS'] = (T_val**2) * GPS_val
|
289 |
+
X_outcome_dict['T_x_GPS_sq'] = T_val * (GPS_val**2)
|
290 |
+
|
291 |
+
# Add more model types as needed (e.g., splines)
|
292 |
+
else:
|
293 |
+
logger.warning(f"Outcome model type '{model_type}' not fully recognized. Defaulting to T + GPS.")
|
294 |
+
X_outcome_dict['T'] = T_val
|
295 |
+
X_outcome_dict['GPS'] = GPS_val
|
296 |
+
# Fallback to linear if spec is unknown or simple
|
297 |
+
|
298 |
+
X_outcome_df = pd.DataFrame(X_outcome_dict, index=df_with_gps.index)
|
299 |
+
|
300 |
+
# Drop rows with NaNs that might have been introduced by transformations if T or GPS were NaN
|
301 |
+
# (though earlier dropna should handle most of this for input T/GPS)
|
302 |
+
valid_outcome_model_indices = Y.dropna().index.intersection(X_outcome_df.dropna().index)
|
303 |
+
if len(valid_outcome_model_indices) < len(df_with_gps):
|
304 |
+
logger.warning(f"Dropped {len(df_with_gps) - len(valid_outcome_model_indices)} rows due to NaNs before outcome model fitting.")
|
305 |
+
|
306 |
+
Y_fit = Y.loc[valid_outcome_model_indices]
|
307 |
+
X_outcome_df_fit = X_outcome_df.loc[valid_outcome_model_indices]
|
308 |
+
|
309 |
+
if Y_fit.empty or X_outcome_df_fit.empty:
|
310 |
+
logger.error("Not enough data to fit outcome model after NaN handling.")
|
311 |
+
raise ValueError("Empty data for outcome model fitting.")
|
312 |
+
|
313 |
+
try:
|
314 |
+
model = sm.OLS(Y_fit, X_outcome_df_fit).fit()
|
315 |
+
logger.info("Outcome model estimated successfully.")
|
316 |
+
return model
|
317 |
+
except Exception as e:
|
318 |
+
logger.error(f"Error during outcome model estimation: {e}", exc_info=True)
|
319 |
+
raise # Re-raise the exception to be caught by the main try-except block
|
320 |
+
|
321 |
+
def _generate_dose_response_function(
|
322 |
+
df_with_gps: pd.DataFrame,
|
323 |
+
fitted_outcome_model: Any,
|
324 |
+
treatment: str,
|
325 |
+
gps_col_name: str,
|
326 |
+
outcome_model_spec: Dict, # To know how to construct X_pred features
|
327 |
+
t_values_to_evaluate: List[float]
|
328 |
+
) -> List[float]:
|
329 |
+
"""
|
330 |
+
Calculates the Average Dose-Response Function (ADRF).
|
331 |
+
E[Y(t)] = integral over E[Y | T=t, GPS=g] * f(g) dg
|
332 |
+
~= (1/N) * sum_i E[Y | T=t, GPS=g_i] (using observed GPS values)
|
333 |
+
"""
|
334 |
+
logger.info(f"Calculating ADRF for treatment levels: {t_values_to_evaluate}")
|
335 |
+
adrf_estimates = []
|
336 |
+
|
337 |
+
if not t_values_to_evaluate: # Handle empty list case
|
338 |
+
logger.warning("t_values_to_evaluate is empty. ADRF calculation will be skipped.")
|
339 |
+
return []
|
340 |
+
|
341 |
+
model_exog_names = fitted_outcome_model.model.exog_names
|
342 |
+
|
343 |
+
# Original GPS values from the dataframe
|
344 |
+
original_gps_values = pd.Series(df_with_gps[gps_col_name].values, index=df_with_gps.index)
|
345 |
+
|
346 |
+
for t_level in t_values_to_evaluate:
|
347 |
+
# Create a new DataFrame for prediction at this t_level
|
348 |
+
# Each row corresponds to an original observation's GPS, but with T set to t_level
|
349 |
+
X_pred_dict = {'intercept': np.ones(len(df_with_gps))}
|
350 |
+
|
351 |
+
# Reconstruct features based on outcome_model_spec and model_exog_names
|
352 |
+
# This mirrors the construction in _estimate_outcome_model
|
353 |
+
degree = outcome_model_spec.get("degree", 2)
|
354 |
+
interaction = outcome_model_spec.get("interaction", True)
|
355 |
+
|
356 |
+
if 'T' in model_exog_names: X_pred_dict['T'] = t_level
|
357 |
+
if 'GPS' in model_exog_names: X_pred_dict['GPS'] = original_gps_values
|
358 |
+
|
359 |
+
if 'T_sq' in model_exog_names: X_pred_dict['T_sq'] = t_level**2
|
360 |
+
if 'GPS_sq' in model_exog_names: X_pred_dict['GPS_sq'] = original_gps_values**2
|
361 |
+
|
362 |
+
if 'T_cub' in model_exog_names: X_pred_dict['T_cub'] = t_level**3 # Example
|
363 |
+
if 'GPS_cub' in model_exog_names: X_pred_dict['GPS_cub'] = original_gps_values**3 # Example
|
364 |
+
|
365 |
+
if 'T_x_GPS' in model_exog_names and interaction:
|
366 |
+
X_pred_dict['T_x_GPS'] = t_level * original_gps_values
|
367 |
+
if 'T_sq_x_GPS' in model_exog_names and interaction and degree >=2:
|
368 |
+
X_pred_dict['T_sq_x_GPS'] = (t_level**2) * original_gps_values
|
369 |
+
if 'T_x_GPS_sq' in model_exog_names and interaction and degree >=2:
|
370 |
+
X_pred_dict['T_x_GPS_sq'] = t_level * (original_gps_values**2)
|
371 |
+
|
372 |
+
X_pred_df = pd.DataFrame(X_pred_dict, index=df_with_gps.index)
|
373 |
+
|
374 |
+
# Ensure all required columns are present and in the correct order
|
375 |
+
# Drop any rows that might have NaNs if original_gps_values had NaNs (though they should be filtered before this)
|
376 |
+
X_pred_df_fit = X_pred_df[model_exog_names].dropna()
|
377 |
+
|
378 |
+
if X_pred_df_fit.empty:
|
379 |
+
logger.warning(f"Prediction data for t_level={t_level} is empty after NaN drop. Assigning NaN to ADRF point.")
|
380 |
+
adrf_estimates.append(np.nan)
|
381 |
+
continue
|
382 |
+
|
383 |
+
predicted_outcomes_at_t = fitted_outcome_model.predict(X_pred_df_fit)
|
384 |
+
adrf_estimates.append(np.mean(predicted_outcomes_at_t))
|
385 |
+
|
386 |
+
return adrf_estimates
|
auto_causal/methods/generalized_propensity_score/llm_assist.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM-assisted components for the Generalized Propensity Score (GPS) method.
|
3 |
+
|
4 |
+
These functions help in suggesting model specifications or parameters
|
5 |
+
by leveraging an LLM, providing intelligent defaults when not specified by the user.
|
6 |
+
"""
|
7 |
+
from typing import Dict, List, Any, Optional
|
8 |
+
import pandas as pd
|
9 |
+
import logging
|
10 |
+
from auto_causal.utils.llm_helpers import call_llm_with_json_output # Hypothetical import
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
def suggest_treatment_model_spec(
|
15 |
+
df: pd.DataFrame,
|
16 |
+
treatment_var: str,
|
17 |
+
covariate_vars: List[str],
|
18 |
+
query: Optional[str] = None,
|
19 |
+
llm_client: Optional[Any] = None
|
20 |
+
) -> Dict[str, Any]:
|
21 |
+
"""
|
22 |
+
Suggests a model specification for the treatment mechanism (T ~ X) in GPS.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
df: The input DataFrame.
|
26 |
+
treatment_var: The name of the continuous treatment variable.
|
27 |
+
covariate_vars: A list of covariate names.
|
28 |
+
query: Optional user query for context.
|
29 |
+
llm_client: Optional LLM client for making a call.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
A dictionary representing the suggested model specification.
|
33 |
+
E.g., {"type": "linear", "formula": "T ~ X1 + X2"} or
|
34 |
+
{"type": "random_forest", "params": {...}}
|
35 |
+
"""
|
36 |
+
logger.info(f"Suggesting treatment model spec for: {treatment_var}")
|
37 |
+
|
38 |
+
# Example of constructing a more detailed prompt for an LLM
|
39 |
+
prompt_parts = [
|
40 |
+
f"You are an expert econometrician. The user wants to estimate a Generalized Propensity Score (GPS) for a continuous treatment variable '{treatment_var}'.",
|
41 |
+
f"The available covariates are: {covariate_vars}.",
|
42 |
+
f"The user's research query is: '{query if query else 'Not specified'}'.",
|
43 |
+
"Based on this information and general best practices for GPS estimation:",
|
44 |
+
"1. Suggest a suitable model type for estimating the treatment (T) given covariates (X). Common choices include 'linear' (OLS), or flexible models like 'random_forest' or 'gradient_boosting' if non-linearities are suspected.",
|
45 |
+
"2. If suggesting a regression model like OLS, provide a Patsy-style formula string (e.g., 'treatment ~ cov1 + cov2 + cov1*cov2').",
|
46 |
+
"3. If suggesting a machine learning model, list key hyperparameters and reasonable starting values (e.g., n_estimators, max_depth).",
|
47 |
+
"Return your suggestion as a JSON object with the following structure:",
|
48 |
+
'''
|
49 |
+
{
|
50 |
+
"model_type": "<e.g., linear, random_forest>",
|
51 |
+
"formula": "<Patsy formula if model_type is linear/glm, else null>",
|
52 |
+
"parameters": { // if applicable for ML models
|
53 |
+
"<param1_name>": "<param1_value>",
|
54 |
+
"<param2_name>": "<param2_value>"
|
55 |
+
},
|
56 |
+
"reasoning": "<Brief justification for your suggestion>"
|
57 |
+
}
|
58 |
+
'''
|
59 |
+
]
|
60 |
+
full_prompt = "\n".join(prompt_parts)
|
61 |
+
|
62 |
+
if llm_client:
|
63 |
+
logger.info("LLM client provided. Sending constructed prompt (actual call is hypothetical).")
|
64 |
+
logger.debug(f"LLM Prompt for treatment model spec:\n{full_prompt}")
|
65 |
+
# In a real implementation:
|
66 |
+
# response_json = call_llm_with_json_output(llm_client, full_prompt)
|
67 |
+
# if response_json and isinstance(response_json, dict):
|
68 |
+
# return response_json
|
69 |
+
# else:
|
70 |
+
# logger.warning("LLM did not return a valid JSON dict for treatment model spec.")
|
71 |
+
pass # Pass for now as it's a hypothetical call
|
72 |
+
|
73 |
+
# Default suggestion if no LLM or LLM fails
|
74 |
+
return {
|
75 |
+
"model_type": "linear",
|
76 |
+
"formula": f"{treatment_var} ~ {' + '.join(covariate_vars) if covariate_vars else '1'}",
|
77 |
+
"parameters": None,
|
78 |
+
"reasoning": "Defaulting to a linear model for T ~ X. Consider a more flexible model if non-linearities are expected.",
|
79 |
+
"comment": "This is a default suggestion."
|
80 |
+
}
|
81 |
+
|
82 |
+
def suggest_outcome_model_spec(
|
83 |
+
df: pd.DataFrame,
|
84 |
+
outcome_var: str,
|
85 |
+
treatment_var: str,
|
86 |
+
gps_col_name: str,
|
87 |
+
query: Optional[str] = None,
|
88 |
+
llm_client: Optional[Any] = None
|
89 |
+
) -> Dict[str, Any]:
|
90 |
+
"""
|
91 |
+
Suggests a model specification for the outcome mechanism (Y ~ T, GPS) in GPS.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
df: The input DataFrame.
|
95 |
+
outcome_var: The name of the outcome variable.
|
96 |
+
treatment_var: The name of the continuous treatment variable.
|
97 |
+
gps_col_name: The name of the GPS column.
|
98 |
+
query: Optional user query for context.
|
99 |
+
llm_client: Optional LLM client for making a call.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
A dictionary representing the suggested model specification.
|
103 |
+
E.g., {"type": "polynomial", "degree": 2, "interaction": True,
|
104 |
+
"formula": "Y ~ T + T^2 + GPS + GPS^2 + T*GPS"}
|
105 |
+
"""
|
106 |
+
logger.info(f"Suggesting outcome model spec for: {outcome_var}")
|
107 |
+
|
108 |
+
prompt_parts = [
|
109 |
+
f"You are an expert econometrician. For a Generalized Propensity Score (GPS) analysis, the user needs to model the outcome '{outcome_var}' conditional on the continuous treatment '{treatment_var}' and the estimated GPS (column name '{gps_col_name}').",
|
110 |
+
"The goal is to flexibly capture the relationship E[Y | T, GPS]. A common approach is to use a polynomial specification for T and GPS, including interaction terms.",
|
111 |
+
f"The user's research query is: '{query if query else 'Not specified'}'.",
|
112 |
+
"Suggest a specification for this outcome model. Consider:",
|
113 |
+
"1. The functional form for T (e.g., linear, quadratic, cubic).",
|
114 |
+
"2. The functional form for GPS (e.g., linear, quadratic, cubic).",
|
115 |
+
"3. Whether to include interaction terms between T and GPS (e.g., T*GPS, T^2*GPS, T*GPS^2).",
|
116 |
+
"Return your suggestion as a JSON object with the following structure:",
|
117 |
+
'''
|
118 |
+
{
|
119 |
+
"model_type": "polynomial", // Or other types like "splines"
|
120 |
+
"treatment_terms": ["T", "T_sq"], // e.g., ["T"] for linear, ["T", "T_sq"] for quadratic
|
121 |
+
"gps_terms": ["GPS", "GPS_sq"], // e.g., ["GPS"] for linear, ["GPS", "GPS_sq"] for quadratic
|
122 |
+
"interaction_terms": ["T_x_GPS", "T_sq_x_GPS", "T_x_GPS_sq"], // Interactions to include, or empty list
|
123 |
+
"reasoning": "<Brief justification for your suggestion>"
|
124 |
+
}
|
125 |
+
'''
|
126 |
+
]
|
127 |
+
full_prompt = "\n".join(prompt_parts)
|
128 |
+
|
129 |
+
if llm_client:
|
130 |
+
logger.info("LLM client provided. Sending constructed prompt for outcome model (hypothetical call).")
|
131 |
+
logger.debug(f"LLM Prompt for outcome model spec:\n{full_prompt}")
|
132 |
+
# In a real implementation:
|
133 |
+
# response_json = call_llm_with_json_output(llm_client, full_prompt)
|
134 |
+
# if response_json and isinstance(response_json, dict):
|
135 |
+
# # Basic validation of expected keys for outcome model could go here
|
136 |
+
# return response_json
|
137 |
+
# else:
|
138 |
+
# logger.warning("LLM did not return a valid JSON dict for outcome model spec.")
|
139 |
+
pass # Pass for now
|
140 |
+
|
141 |
+
# Default suggestion
|
142 |
+
return {
|
143 |
+
"model_type": "polynomial",
|
144 |
+
"treatment_terms": ["T", "T_sq"],
|
145 |
+
"gps_terms": ["GPS", "GPS_sq"],
|
146 |
+
"interaction_terms": ["T_x_GPS"],
|
147 |
+
"reasoning": "Defaulting to a quadratic specification for T and GPS with a simple T*GPS interaction. This is a common starting point.",
|
148 |
+
"comment": "This is a default suggestion."
|
149 |
+
}
|
150 |
+
|
151 |
+
def suggest_dose_response_t_values(
|
152 |
+
df: pd.DataFrame,
|
153 |
+
treatment_var: str,
|
154 |
+
num_points: int = 20,
|
155 |
+
llm_client: Optional[Any] = None
|
156 |
+
) -> List[float]:
|
157 |
+
"""
|
158 |
+
Suggests a relevant range and number of points for estimating the ADRF.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
df: The input DataFrame.
|
162 |
+
treatment_var: The name of the continuous treatment variable.
|
163 |
+
num_points: Desired number of points for the ADRF curve.
|
164 |
+
llm_client: Optional LLM client for making a call.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
A list of treatment values at which to evaluate the ADRF.
|
168 |
+
"""
|
169 |
+
logger.info(f"Suggesting dose response t-values for: {treatment_var}")
|
170 |
+
|
171 |
+
prompt_parts = [
|
172 |
+
f"For a Generalized Propensity Score (GPS) analysis with continuous treatment '{treatment_var}', the user needs to estimate an Average Dose-Response Function (ADRF).",
|
173 |
+
f"The observed range of '{treatment_var}' is from {df[treatment_var].min():.2f} to {df[treatment_var].max():.2f}.",
|
174 |
+
f"The user desires approximately {num_points} points for the ADRF curve.",
|
175 |
+
f"The user's research query is: '{query if query else 'Not specified'}'.",
|
176 |
+
"Suggest a list of specific treatment values (t_values) at which to evaluate the ADRF. Consider:",
|
177 |
+
"1. Covering the observed range of the treatment.",
|
178 |
+
"2. Potentially including specific points of policy interest if deducible from the query (though this is advanced).",
|
179 |
+
"3. Ensuring a reasonable distribution of points (e.g., equally spaced, or based on quantiles).",
|
180 |
+
"Return your suggestion as a JSON object with a single key 't_values' holding a list of floats:",
|
181 |
+
'''
|
182 |
+
{
|
183 |
+
"t_values": [<float>, <float>, ..., <float>],
|
184 |
+
"reasoning": "<Brief justification for the choice/distribution of these t_values>"
|
185 |
+
}
|
186 |
+
'''
|
187 |
+
]
|
188 |
+
full_prompt = "\n".join(prompt_parts)
|
189 |
+
|
190 |
+
if llm_client:
|
191 |
+
logger.info("LLM client provided. Sending prompt for t-values (hypothetical call).")
|
192 |
+
logger.debug(f"LLM Prompt for t-values:\n{full_prompt}")
|
193 |
+
# In a real implementation:
|
194 |
+
# response_json = call_llm_with_json_output(llm_client, full_prompt)
|
195 |
+
# if response_json and isinstance(response_json, dict) and 't_values' in response_json and isinstance(response_json['t_values'], list):
|
196 |
+
# return response_json['t_values'] # Assuming it returns the list directly based on current function signature
|
197 |
+
# else:
|
198 |
+
# logger.warning("LLM did not return a valid JSON with 't_values' list for ADRF points.")
|
199 |
+
pass # Pass for now
|
200 |
+
|
201 |
+
# Default: Linearly spaced points
|
202 |
+
min_t = df[treatment_var].min()
|
203 |
+
max_t = df[treatment_var].max()
|
204 |
+
if pd.isna(min_t) or pd.isna(max_t) or min_t == max_t:
|
205 |
+
logger.warning(f"Could not determine a valid range for treatment '{treatment_var}'. Returning empty list.")
|
206 |
+
return []
|
207 |
+
|
208 |
+
return list(pd.Series.linspace(min_t, max_t, num_points))
|
auto_causal/methods/instrumental_variable/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .estimator import estimate_effect
|
auto_causal/methods/instrumental_variable/diagnostics.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Placeholder for IV-specific diagnostic functions
|
2 |
+
import pandas as pd
|
3 |
+
import statsmodels.api as sm
|
4 |
+
from statsmodels.regression.linear_model import OLS
|
5 |
+
# from statsmodels.sandbox.regression.gmm import IV2SLSResults # Removed problematic import
|
6 |
+
from typing import Dict, Any, List, Tuple, Optional
|
7 |
+
import logging # Import logging
|
8 |
+
import numpy as np # Import numpy for np.zeros
|
9 |
+
|
10 |
+
# Configure logger
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
def calculate_first_stage_f_statistic(df: pd.DataFrame, treatment: str, instruments: List[str], covariates: List[str]) -> Tuple[Optional[float], Optional[float]]:
|
14 |
+
"""
|
15 |
+
Calculates the F-statistic for instrument relevance in the first stage regression.
|
16 |
+
|
17 |
+
Regresses treatment ~ instruments + covariates.
|
18 |
+
Tests the joint significance of the instrument coefficients.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
df: Input DataFrame.
|
22 |
+
treatment: Name of the treatment variable.
|
23 |
+
instruments: List of instrument variable names.
|
24 |
+
covariates: List of covariate names.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
A tuple containing (F-statistic, p-value). Returns (None, None) on error.
|
28 |
+
"""
|
29 |
+
logger.info("Diagnostics: Calculating First-Stage F-statistic...")
|
30 |
+
try:
|
31 |
+
df_copy = df.copy()
|
32 |
+
df_copy['intercept'] = 1
|
33 |
+
exog_vars = ['intercept'] + covariates
|
34 |
+
all_first_stage_exog = list(dict.fromkeys(exog_vars + instruments)) # Ensure unique columns
|
35 |
+
|
36 |
+
endog = df_copy[treatment]
|
37 |
+
exog = df_copy[all_first_stage_exog]
|
38 |
+
|
39 |
+
# Check for perfect multicollinearity before fitting
|
40 |
+
if exog.shape[1] > 1:
|
41 |
+
corr_matrix = exog.corr()
|
42 |
+
# Check if correlation matrix calculation failed (e.g., constant columns) or high correlation
|
43 |
+
if corr_matrix.isnull().values.any() or (corr_matrix.abs() > 0.9999).sum().sum() > exog.shape[1]: # Check off-diagonal elements
|
44 |
+
logger.warning("High multicollinearity or constant column detected in first stage exogenous variables.")
|
45 |
+
# Note: statsmodels OLS might handle perfect collinearity by dropping columns, but F-test might be unreliable.
|
46 |
+
|
47 |
+
first_stage_model = OLS(endog, exog).fit()
|
48 |
+
|
49 |
+
# Construct the restriction matrix (R) to test H0: instrument coeffs = 0
|
50 |
+
num_instruments = len(instruments)
|
51 |
+
if num_instruments == 0:
|
52 |
+
logger.warning("No instruments provided for F-statistic calculation.")
|
53 |
+
return None, None
|
54 |
+
num_exog_total = len(all_first_stage_exog)
|
55 |
+
|
56 |
+
# Ensure instruments are actually in the fitted model's exog names (in case statsmodels dropped some)
|
57 |
+
fitted_exog_names = first_stage_model.model.exog_names
|
58 |
+
valid_instruments = [inst for inst in instruments if inst in fitted_exog_names]
|
59 |
+
if not valid_instruments:
|
60 |
+
logger.error("None of the provided instruments were included in the first-stage regression model (possibly due to collinearity).")
|
61 |
+
return None, None
|
62 |
+
if len(valid_instruments) < len(instruments):
|
63 |
+
logger.warning(f"Instruments dropped by OLS: {set(instruments) - set(valid_instruments)}")
|
64 |
+
|
65 |
+
instrument_indices = [fitted_exog_names.index(inst) for inst in valid_instruments]
|
66 |
+
|
67 |
+
# Need to adjust R matrix size based on fitted model's exog
|
68 |
+
R = np.zeros((len(valid_instruments), len(fitted_exog_names)))
|
69 |
+
for i, idx in enumerate(instrument_indices):
|
70 |
+
R[i, idx] = 1
|
71 |
+
|
72 |
+
# Perform F-test
|
73 |
+
f_test_result = first_stage_model.f_test(R)
|
74 |
+
|
75 |
+
f_statistic = float(f_test_result.fvalue)
|
76 |
+
p_value = float(f_test_result.pvalue)
|
77 |
+
|
78 |
+
logger.info(f" F-statistic: {f_statistic:.4f}, p-value: {p_value:.4f}")
|
79 |
+
return f_statistic, p_value
|
80 |
+
|
81 |
+
except Exception as e:
|
82 |
+
logger.error(f"Error calculating first-stage F-statistic: {e}", exc_info=True)
|
83 |
+
return None, None
|
84 |
+
|
85 |
+
def run_overidentification_test(sm_results: Optional[Any], df: pd.DataFrame, treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> Tuple[Optional[float], Optional[float], Optional[str]]:
|
86 |
+
"""
|
87 |
+
Runs an overidentification test (Sargan-Hansen) if applicable.
|
88 |
+
|
89 |
+
This test is only valid if the number of instruments exceeds the number
|
90 |
+
of endogenous regressors (typically 1, the treatment variable).
|
91 |
+
|
92 |
+
Requires results from a statsmodels IV estimation.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
sm_results: The fitted results object from statsmodels IV2SLS.fit().
|
96 |
+
df: Input DataFrame.
|
97 |
+
treatment: Name of the treatment variable.
|
98 |
+
outcome: Name of the outcome variable.
|
99 |
+
instruments: List of instrument variable names.
|
100 |
+
covariates: List of covariate names.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
Tuple: (test_statistic, p_value, status_message) or (None, None, error_message)
|
104 |
+
"""
|
105 |
+
logger.info("Diagnostics: Running Overidentification Test...")
|
106 |
+
num_instruments = len(instruments)
|
107 |
+
num_endog = 1 # Assuming only one treatment variable is endogenous
|
108 |
+
|
109 |
+
if num_instruments <= num_endog:
|
110 |
+
logger.info(" Over-ID test not applicable (model is exactly identified or underidentified).")
|
111 |
+
return None, None, "Test not applicable (Need more instruments than endogenous regressors)"
|
112 |
+
|
113 |
+
if sm_results is None or not hasattr(sm_results, 'resid'):
|
114 |
+
logger.warning(" Over-ID test requires valid statsmodels results object with residuals.")
|
115 |
+
return None, None, "Statsmodels results object not available or invalid for test."
|
116 |
+
|
117 |
+
try:
|
118 |
+
# Statsmodels IV2SLSResults does not seem to have a direct method for this test (as of common versions).
|
119 |
+
# We need to calculate it manually using residuals and instruments.
|
120 |
+
# Formula: N * R^2 from regressing residuals (u_hat) on all exogenous variables (instruments + covariates).
|
121 |
+
# Degrees of freedom = num_instruments - num_endogenous_vars
|
122 |
+
|
123 |
+
residuals = sm_results.resid
|
124 |
+
df_copy = df.copy()
|
125 |
+
df_copy['intercept'] = 1
|
126 |
+
exog_vars = ['intercept'] + covariates
|
127 |
+
all_exog_instruments = list(dict.fromkeys(exog_vars + instruments))
|
128 |
+
|
129 |
+
# Ensure columns exist in the dataframe before selecting
|
130 |
+
missing_cols = [col for col in all_exog_instruments if col not in df_copy.columns]
|
131 |
+
if missing_cols:
|
132 |
+
raise ValueError(f"Missing columns required for Over-ID test: {missing_cols}")
|
133 |
+
|
134 |
+
exog_for_test = df_copy[all_exog_instruments]
|
135 |
+
|
136 |
+
# Check shapes match after potential NA handling in main estimator
|
137 |
+
if len(residuals) != exog_for_test.shape[0]:
|
138 |
+
# Attempt to align based on index if lengths differ (might happen if NAs were dropped)
|
139 |
+
logger.warning(f"Residual length ({len(residuals)}) differs from exog_for_test rows ({exog_for_test.shape[0]}). Trying to align indices.")
|
140 |
+
common_index = residuals.index.intersection(exog_for_test.index)
|
141 |
+
if len(common_index) == 0:
|
142 |
+
raise ValueError("Cannot align residuals and exogenous variables for Over-ID test after NA handling.")
|
143 |
+
residuals = residuals.loc[common_index]
|
144 |
+
exog_for_test = exog_for_test.loc[common_index]
|
145 |
+
logger.warning(f"Aligned to {len(common_index)} common observations.")
|
146 |
+
|
147 |
+
|
148 |
+
# Regress residuals on all exogenous instruments
|
149 |
+
aux_model = OLS(residuals, exog_for_test).fit()
|
150 |
+
r_squared = aux_model.rsquared
|
151 |
+
n_obs = len(residuals) # Use length of residuals after potential alignment
|
152 |
+
|
153 |
+
test_statistic = n_obs * r_squared
|
154 |
+
|
155 |
+
# Calculate p-value from Chi-squared distribution
|
156 |
+
from scipy.stats import chi2
|
157 |
+
degrees_of_freedom = num_instruments - num_endog
|
158 |
+
if degrees_of_freedom < 0:
|
159 |
+
# This shouldn't happen if the initial check passed, but as a safeguard
|
160 |
+
raise ValueError("Degrees of freedom for Sargan test are negative.")
|
161 |
+
elif degrees_of_freedom == 0:
|
162 |
+
# R-squared should be 0 if exactly identified, but handle edge case
|
163 |
+
p_value = 1.0 if np.isclose(test_statistic, 0) else 0.0
|
164 |
+
else:
|
165 |
+
p_value = chi2.sf(test_statistic, degrees_of_freedom)
|
166 |
+
|
167 |
+
logger.info(f" Sargan Test Statistic: {test_statistic:.4f}, p-value: {p_value:.4f}, df: {degrees_of_freedom}")
|
168 |
+
return test_statistic, p_value, "Test successful"
|
169 |
+
|
170 |
+
except Exception as e:
|
171 |
+
logger.error(f"Error running overidentification test: {e}", exc_info=True)
|
172 |
+
return None, None, f"Error during test: {e}"
|
173 |
+
|
174 |
+
def run_iv_diagnostics(df: pd.DataFrame, treatment: str, outcome: str, instruments: List[str], covariates: List[str], sm_results: Optional[Any] = None, dw_results: Optional[Any] = None) -> Dict[str, Any]:
|
175 |
+
"""
|
176 |
+
Runs standard IV diagnostic checks.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
df: Input DataFrame.
|
180 |
+
treatment: Name of the treatment variable.
|
181 |
+
outcome: Name of the outcome variable.
|
182 |
+
instruments: List of instrument variable names.
|
183 |
+
covariates: List of covariate names.
|
184 |
+
sm_results: Optional fitted results object from statsmodels IV2SLS.fit().
|
185 |
+
dw_results: Optional results object from DoWhy (structure may vary).
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
Dictionary containing diagnostic results.
|
189 |
+
"""
|
190 |
+
diagnostics = {}
|
191 |
+
|
192 |
+
# 1. Instrument Relevance / Weak Instrument Test (First-Stage F-statistic)
|
193 |
+
f_stat, f_p_val = calculate_first_stage_f_statistic(df, treatment, instruments, covariates)
|
194 |
+
diagnostics['first_stage_f_statistic'] = f_stat
|
195 |
+
diagnostics['first_stage_p_value'] = f_p_val
|
196 |
+
diagnostics['is_instrument_weak'] = (f_stat < 10) if f_stat is not None else None # Common rule of thumb
|
197 |
+
if f_stat is None:
|
198 |
+
diagnostics['weak_instrument_test_status'] = "Error during calculation"
|
199 |
+
elif diagnostics['is_instrument_weak']:
|
200 |
+
diagnostics['weak_instrument_test_status'] = "Warning: Instrument(s) may be weak (F < 10)"
|
201 |
+
else:
|
202 |
+
diagnostics['weak_instrument_test_status'] = "Instrument(s) appear sufficiently strong (F >= 10)"
|
203 |
+
|
204 |
+
|
205 |
+
# 2. Overidentification Test (e.g., Sargan-Hansen)
|
206 |
+
overid_stat, overid_p_val, overid_status = run_overidentification_test(sm_results, df, treatment, outcome, instruments, covariates)
|
207 |
+
diagnostics['overid_test_statistic'] = overid_stat
|
208 |
+
diagnostics['overid_test_p_value'] = overid_p_val
|
209 |
+
diagnostics['overid_test_status'] = overid_status
|
210 |
+
diagnostics['overid_test_applicable'] = not ("not applicable" in overid_status.lower() if overid_status else True)
|
211 |
+
|
212 |
+
# 3. Exogeneity/Exclusion Restriction (Conceptual Check)
|
213 |
+
diagnostics['exclusion_restriction_assumption'] = "Assumed based on graph/input; cannot be statistically tested directly. Qualitative LLM check recommended."
|
214 |
+
|
215 |
+
# Potential future additions:
|
216 |
+
# - Endogeneity tests (e.g., Hausman test - requires comparing OLS and IV estimates)
|
217 |
+
|
218 |
+
return diagnostics
|
auto_causal/methods/instrumental_variable/estimator.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import statsmodels.api as sm
|
3 |
+
from statsmodels.sandbox.regression.gmm import IV2SLS
|
4 |
+
from dowhy import CausalModel # Primary path
|
5 |
+
from typing import Dict, Any, List, Union, Optional
|
6 |
+
import logging
|
7 |
+
from langchain.chat_models.base import BaseChatModel
|
8 |
+
|
9 |
+
from .diagnostics import run_iv_diagnostics
|
10 |
+
from .llm_assist import identify_instrument_variable, validate_instrument_assumptions_qualitative, interpret_iv_results
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
def build_iv_graph_gml(treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> str:
|
15 |
+
"""
|
16 |
+
Constructs a GML string representing the causal graph for IV.
|
17 |
+
|
18 |
+
Assumptions:
|
19 |
+
- Instruments cause Treatment
|
20 |
+
- Covariates cause Treatment and Outcome
|
21 |
+
- Treatment causes Outcome
|
22 |
+
- Instruments do NOT directly cause Outcome (Exclusion)
|
23 |
+
- Instruments are NOT caused by Covariates (can be relaxed if needed)
|
24 |
+
- Unobserved Confounder (U) affects Treatment and Outcome
|
25 |
+
|
26 |
+
Args:
|
27 |
+
treatment: Name of the treatment variable.
|
28 |
+
outcome: Name of the outcome variable.
|
29 |
+
instruments: List of instrument variable names.
|
30 |
+
covariates: List of covariate names.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
A GML graph string.
|
34 |
+
"""
|
35 |
+
nodes = []
|
36 |
+
edges = []
|
37 |
+
|
38 |
+
# Define nodes - ensure no duplicates if a variable is both instrument and covariate (SHOULD NOT HAPPEN)
|
39 |
+
# Use a set to ensure unique variable names
|
40 |
+
all_vars_set = set([treatment, outcome] + instruments + covariates + ['U'])
|
41 |
+
all_vars = list(all_vars_set)
|
42 |
+
|
43 |
+
for var in all_vars:
|
44 |
+
nodes.append(f'node [ id "{var}" label "{var}" ]')
|
45 |
+
|
46 |
+
# Define edges
|
47 |
+
# Instruments -> Treatment
|
48 |
+
for inst in instruments:
|
49 |
+
edges.append(f'edge [ source "{inst}" target "{treatment}" ]')
|
50 |
+
|
51 |
+
# Covariates -> Treatment
|
52 |
+
for cov in covariates:
|
53 |
+
# Ensure we don't add self-loops or duplicate edges if cov == treatment (shouldn't happen)
|
54 |
+
if cov != treatment:
|
55 |
+
edges.append(f'edge [ source "{cov}" target "{treatment}" ]')
|
56 |
+
|
57 |
+
# Covariates -> Outcome
|
58 |
+
for cov in covariates:
|
59 |
+
if cov != outcome:
|
60 |
+
edges.append(f'edge [ source "{cov}" target "{outcome}" ]')
|
61 |
+
|
62 |
+
# Treatment -> Outcome
|
63 |
+
edges.append(f'edge [ source "{treatment}" target "{outcome}" ]')
|
64 |
+
|
65 |
+
# Unobserved Confounder -> Treatment and Outcome
|
66 |
+
edges.append(f'edge [ source "U" target "{treatment}" ]')
|
67 |
+
edges.append(f'edge [ source "U" target "{outcome}" ]')
|
68 |
+
|
69 |
+
# Core IV Assumption: Instruments are NOT caused by U (implicitly handled by not adding edge)
|
70 |
+
# Core IV Assumption: Instruments do NOT directly cause Outcome (handled by not adding edge)
|
71 |
+
|
72 |
+
# Format nodes and edges with indentation before inserting into f-string
|
73 |
+
formatted_nodes = '\n '.join(nodes)
|
74 |
+
formatted_edges = '\n '.join(edges)
|
75 |
+
|
76 |
+
gml_string = f"""
|
77 |
+
graph [
|
78 |
+
directed 1
|
79 |
+
{formatted_nodes}
|
80 |
+
{formatted_edges}
|
81 |
+
]
|
82 |
+
"""
|
83 |
+
# Convert print to logger
|
84 |
+
logger.debug("\n--- Generated GML Graph ---")
|
85 |
+
logger.debug(gml_string)
|
86 |
+
logger.debug("-------------------------\n")
|
87 |
+
return gml_string
|
88 |
+
|
89 |
+
def format_iv_results(estimate: Optional[float], raw_results: Dict, diagnostics: Dict, treatment: str, outcome: str, instrument: List[str], method_used: str, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
|
90 |
+
"""
|
91 |
+
Formats the results from IV estimation into a standardized dictionary.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
estimate: The point estimate of the causal effect.
|
95 |
+
raw_results: Dictionary containing raw outputs from DoWhy/statsmodels.
|
96 |
+
diagnostics: Dictionary containing diagnostic results.
|
97 |
+
treatment: Name of the treatment variable.
|
98 |
+
outcome: Name of the outcome variable.
|
99 |
+
instrument: List of instrument variable names.
|
100 |
+
method_used: 'dowhy' or 'statsmodels'.
|
101 |
+
llm: Optional LLM instance for interpretation.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Standardized results dictionary.
|
105 |
+
"""
|
106 |
+
formatted = {
|
107 |
+
"effect_estimate": estimate,
|
108 |
+
"treatment_variable": treatment,
|
109 |
+
"outcome_variable": outcome,
|
110 |
+
"instrument_variables": instrument,
|
111 |
+
"method_used": method_used,
|
112 |
+
"diagnostics": diagnostics,
|
113 |
+
"raw_results": {k: str(v) for k, v in raw_results.items() if "object" not in k}, # Avoid serializing large objects
|
114 |
+
"confidence_interval": None,
|
115 |
+
"standard_error": None,
|
116 |
+
"p_value": None,
|
117 |
+
"interpretation": "Placeholder"
|
118 |
+
}
|
119 |
+
|
120 |
+
# Extract details from statsmodels results if available
|
121 |
+
sm_results = raw_results.get('statsmodels_results_object')
|
122 |
+
if method_used == 'statsmodels' and sm_results:
|
123 |
+
try:
|
124 |
+
# Use .bse for standard error in statsmodels results
|
125 |
+
formatted["standard_error"] = float(sm_results.bse[treatment])
|
126 |
+
formatted["p_value"] = float(sm_results.pvalues[treatment])
|
127 |
+
conf_int = sm_results.conf_int().loc[treatment].tolist()
|
128 |
+
formatted["confidence_interval"] = [float(ci) for ci in conf_int]
|
129 |
+
except AttributeError as e:
|
130 |
+
logger.warning(f"Could not extract all details from statsmodels results object (likely missing attribute): {e}")
|
131 |
+
except Exception as e:
|
132 |
+
logger.warning(f"Error extracting details from statsmodels results: {e}")
|
133 |
+
|
134 |
+
# Extract details from DoWhy results if available
|
135 |
+
# Note: DoWhy's CausalEstimate object structure needs inspection
|
136 |
+
dw_results = raw_results.get('dowhy_results_object')
|
137 |
+
if method_used == 'dowhy' and dw_results:
|
138 |
+
try:
|
139 |
+
# Attempt common attributes, may need adjustment based on DoWhy version/output
|
140 |
+
if hasattr(dw_results, 'stderr'):
|
141 |
+
formatted["standard_error"] = float(dw_results.stderr)
|
142 |
+
if hasattr(dw_results, 'p_value'):
|
143 |
+
formatted["p_value"] = float(dw_results.p_value)
|
144 |
+
if hasattr(dw_results, 'conf_intervals'):
|
145 |
+
# Assuming it's stored similarly to statsmodels, might need adjustment
|
146 |
+
ci = dw_results.conf_intervals().loc[treatment].tolist() # Fictional attribute/method - check DoWhy docs!
|
147 |
+
formatted["confidence_interval"] = [float(c) for c in ci]
|
148 |
+
elif hasattr(dw_results, 'get_confidence_intervals'):
|
149 |
+
ci = dw_results.get_confidence_intervals() # Check DoWhy docs for format
|
150 |
+
# Check format of ci before converting
|
151 |
+
if isinstance(ci, (list, tuple)) and len(ci) == 2:
|
152 |
+
formatted["confidence_interval"] = [float(c) for c in ci] # Adapt parsing
|
153 |
+
else:
|
154 |
+
logger.warning(f"Could not parse confidence intervals from DoWhy object: {ci}")
|
155 |
+
|
156 |
+
except Exception as e:
|
157 |
+
logger.warning(f"Could not extract all details from DoWhy results: {e}. Structure might be different.", exc_info=True)
|
158 |
+
# Avoid printing dir in production code, use logger.debug if needed for dev
|
159 |
+
# logger.debug(f"DoWhy result object dir(): {dir(dw_results)}")
|
160 |
+
|
161 |
+
# Generate LLM interpretation - pass llm object
|
162 |
+
if estimate is not None:
|
163 |
+
formatted["interpretation"] = interpret_iv_results(formatted, diagnostics, llm=llm)
|
164 |
+
else:
|
165 |
+
formatted["interpretation"] = "Estimation failed, cannot interpret results."
|
166 |
+
|
167 |
+
|
168 |
+
return formatted
|
169 |
+
|
170 |
+
def estimate_effect(
|
171 |
+
df: pd.DataFrame,
|
172 |
+
treatment: str,
|
173 |
+
outcome: str,
|
174 |
+
covariates: List[str],
|
175 |
+
query: Optional[str] = None,
|
176 |
+
dataset_description: Optional[str] = None,
|
177 |
+
llm: Optional[BaseChatModel] = None,
|
178 |
+
**kwargs
|
179 |
+
) -> Dict[str, Any]:
|
180 |
+
|
181 |
+
instrument = kwargs.get('instrument_variable')
|
182 |
+
if not instrument:
|
183 |
+
return {"error": "Instrument variable ('instrument_variable') not found in kwargs.", "method_used": "none", "diagnostics": {}}
|
184 |
+
|
185 |
+
instrument_list = [instrument] if isinstance(instrument, str) else instrument
|
186 |
+
valid_instruments = [inst for inst in instrument_list if isinstance(inst, str)]
|
187 |
+
clean_covariates = [cov for cov in covariates if cov not in valid_instruments]
|
188 |
+
|
189 |
+
logger.info(f"\n--- Starting Instrumental Variable Estimation ---")
|
190 |
+
logger.info(f"Treatment: {treatment}, Outcome: {outcome}, Instrument(s): {valid_instruments}, Original Covariates: {covariates}, Cleaned Covariates: {clean_covariates}")
|
191 |
+
results = {}
|
192 |
+
method_used = "none"
|
193 |
+
sm_results_obj = None
|
194 |
+
dw_results_obj = None
|
195 |
+
identified_estimand = None # Initialize
|
196 |
+
model = None # Initialize
|
197 |
+
refutation_results = {} # Initialize
|
198 |
+
|
199 |
+
# --- Input Validation ---
|
200 |
+
required_cols = [treatment, outcome] + valid_instruments + clean_covariates
|
201 |
+
missing_cols = [col for col in required_cols if col not in df.columns]
|
202 |
+
if missing_cols:
|
203 |
+
return {"error": f"Missing required columns in DataFrame: {missing_cols}", "method_used": method_used, "diagnostics": {}}
|
204 |
+
if not valid_instruments:
|
205 |
+
return {"error": "Instrument variable(s) must be provided and valid.", "method_used": method_used, "diagnostics": {}}
|
206 |
+
|
207 |
+
# --- LLM Pre-Checks ---
|
208 |
+
if query and llm:
|
209 |
+
qualitative_check = validate_instrument_assumptions_qualitative(treatment, outcome, valid_instruments, clean_covariates, query, llm=llm)
|
210 |
+
results['llm_assumption_check'] = qualitative_check
|
211 |
+
logger.info(f"LLM Qualitative Assumption Check: {qualitative_check}")
|
212 |
+
|
213 |
+
# --- Build Graph and Instantiate CausalModel (Do this before estimation attempts) ---
|
214 |
+
# This allows using identify_effect and refute_estimate even if DoWhy estimation fails
|
215 |
+
try:
|
216 |
+
graph = build_iv_graph_gml(treatment, outcome, valid_instruments, clean_covariates)
|
217 |
+
if not graph:
|
218 |
+
raise ValueError("Failed to build GML graph for DoWhy.")
|
219 |
+
|
220 |
+
model = CausalModel(data=df, treatment=treatment, outcome=outcome, graph=graph)
|
221 |
+
|
222 |
+
# Identify Effect (essential for refutation later)
|
223 |
+
identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
|
224 |
+
logger.debug("\nDoWhy Identified Estimand:")
|
225 |
+
logger.debug(identified_estimand)
|
226 |
+
if not identified_estimand:
|
227 |
+
raise ValueError("DoWhy could not identify a valid estimand.")
|
228 |
+
|
229 |
+
except Exception as model_init_e:
|
230 |
+
logger.error(f"Failed to initialize CausalModel or identify effect: {model_init_e}", exc_info=True)
|
231 |
+
# Cannot proceed without model/estimand for DoWhy or refutation
|
232 |
+
results['error'] = f"Failed to initialize CausalModel: {model_init_e}"
|
233 |
+
# Attempt statsmodels anyway? Or return error? Let's try statsmodels.
|
234 |
+
pass # Allow falling through to statsmodels if desired
|
235 |
+
|
236 |
+
# --- Primary Path: DoWhy Estimation ---
|
237 |
+
if model and identified_estimand and not kwargs.get('force_statsmodels', False):
|
238 |
+
logger.info("\nAttempting estimation with DoWhy...")
|
239 |
+
try:
|
240 |
+
dw_results_obj = model.estimate_effect(
|
241 |
+
identified_estimand,
|
242 |
+
method_name="iv.instrumental_variable",
|
243 |
+
method_params={'iv_instrument_name': valid_instruments}
|
244 |
+
)
|
245 |
+
logger.debug("\nDoWhy Estimation Result:")
|
246 |
+
logger.debug(dw_results_obj)
|
247 |
+
results['dowhy_estimate'] = dw_results_obj.value
|
248 |
+
results['dowhy_results_object'] = dw_results_obj
|
249 |
+
method_used = 'dowhy'
|
250 |
+
logger.info("DoWhy estimation successful.")
|
251 |
+
except Exception as e:
|
252 |
+
logger.error(f"DoWhy IV estimation failed: {e}", exc_info=True)
|
253 |
+
results['dowhy_error'] = str(e)
|
254 |
+
if not kwargs.get('allow_fallback', True):
|
255 |
+
logger.warning("Fallback to statsmodels disabled. Estimation failed.")
|
256 |
+
method_used = "dowhy_failed"
|
257 |
+
# Still run diagnostics and format output
|
258 |
+
else:
|
259 |
+
logger.info("Proceeding to statsmodels fallback.")
|
260 |
+
elif not model or not identified_estimand:
|
261 |
+
logger.warning("Skipping DoWhy estimation due to CausalModel initialization/identification failure.")
|
262 |
+
# Ensure we proceed to statsmodels if fallback is allowed
|
263 |
+
if not kwargs.get('allow_fallback', True):
|
264 |
+
logger.error("Cannot estimate effect: CausalModel failed and fallback disabled.")
|
265 |
+
method_used = "dowhy_failed"
|
266 |
+
else:
|
267 |
+
logger.info("Proceeding to statsmodels fallback.")
|
268 |
+
|
269 |
+
# --- Fallback Path: statsmodels IV2SLS ---
|
270 |
+
if method_used not in ['dowhy', 'dowhy_failed']:
|
271 |
+
logger.info("\nAttempting estimation with statsmodels IV2SLS...")
|
272 |
+
try:
|
273 |
+
df_copy = df.copy().dropna(subset=required_cols)
|
274 |
+
if df_copy.empty:
|
275 |
+
raise ValueError("DataFrame becomes empty after dropping NAs in required columns.")
|
276 |
+
df_copy['intercept'] = 1
|
277 |
+
exog_regressors = ['intercept'] + clean_covariates
|
278 |
+
endog_var = treatment
|
279 |
+
all_instruments_for_sm = list(dict.fromkeys(exog_regressors + valid_instruments))
|
280 |
+
endog_data = df_copy[outcome]
|
281 |
+
exog_data_sm_cols = list(dict.fromkeys(exog_regressors + [endog_var]))
|
282 |
+
exog_data_sm = df_copy[exog_data_sm_cols]
|
283 |
+
instrument_data_sm = df_copy[all_instruments_for_sm]
|
284 |
+
num_endog = 1
|
285 |
+
num_external_iv = len(valid_instruments)
|
286 |
+
if num_endog > num_external_iv:
|
287 |
+
raise ValueError(f"Model underidentified: More endogenous regressors ({num_endog}) than unique external instruments ({num_external_iv}).")
|
288 |
+
iv_model = IV2SLS(endog=endog_data, exog=exog_data_sm, instrument=instrument_data_sm)
|
289 |
+
sm_results_obj = iv_model.fit()
|
290 |
+
logger.info("\nStatsmodels Estimation Summary:")
|
291 |
+
logger.info(f" Estimate for {treatment}: {sm_results_obj.params[treatment]}")
|
292 |
+
logger.info(f" Std Error: {sm_results_obj.bse[treatment]}")
|
293 |
+
logger.info(f" P-value: {sm_results_obj.pvalues[treatment]}")
|
294 |
+
results['statsmodels_estimate'] = sm_results_obj.params[treatment]
|
295 |
+
results['statsmodels_results_object'] = sm_results_obj
|
296 |
+
method_used = 'statsmodels'
|
297 |
+
logger.info("Statsmodels estimation successful.")
|
298 |
+
except Exception as sm_e:
|
299 |
+
logger.error(f"Statsmodels IV estimation also failed: {sm_e}", exc_info=True)
|
300 |
+
results['statsmodels_error'] = str(sm_e)
|
301 |
+
method_used = 'statsmodels_failed' if method_used == "none" else "dowhy_failed_sm_failed"
|
302 |
+
|
303 |
+
# --- Diagnostics ---
|
304 |
+
logger.info("\nRunning diagnostics...")
|
305 |
+
diagnostics = run_iv_diagnostics(df, treatment, outcome, valid_instruments, clean_covariates, sm_results_obj, dw_results_obj)
|
306 |
+
results['diagnostics'] = diagnostics
|
307 |
+
|
308 |
+
# --- Refutation Step ---
|
309 |
+
final_estimate_value = results.get('dowhy_estimate') if method_used == 'dowhy' else results.get('statsmodels_estimate')
|
310 |
+
|
311 |
+
# Only run permute refuter if estimate is valid AND came from DoWhy
|
312 |
+
if method_used == 'dowhy' and dw_results_obj and final_estimate_value is not None:
|
313 |
+
logger.info("\nRunning refutation test (Placebo Treatment - Permute - requires DoWhy estimate object)...")
|
314 |
+
try:
|
315 |
+
# Pass the actual DoWhy estimate object
|
316 |
+
refuter_result = model.refute_estimate(
|
317 |
+
identified_estimand,
|
318 |
+
dw_results_obj, # Pass the original DoWhy result object
|
319 |
+
method_name="placebo_treatment_refuter",
|
320 |
+
placebo_type="permute" # Necessary for IV according to docs/examples
|
321 |
+
)
|
322 |
+
logger.info("Refutation test completed.")
|
323 |
+
logger.debug(f"Refuter Result:\n{refuter_result}")
|
324 |
+
# Store relevant info from refuter_result (check its structure)
|
325 |
+
refutation_results = {
|
326 |
+
"refuter": "placebo_treatment_refuter",
|
327 |
+
"new_effect": getattr(refuter_result, 'new_effect', 'N/A'),
|
328 |
+
"p_value": getattr(refuter_result, 'refutation_result', {}).get('p_value', 'N/A') if hasattr(refuter_result, 'refutation_result') else 'N/A',
|
329 |
+
# Passed if p-value > 0.05 (or not statistically significant)
|
330 |
+
"passed": getattr(refuter_result, 'refutation_result', {}).get('is_statistically_significant', None) == False if hasattr(refuter_result, 'refutation_result') else None
|
331 |
+
}
|
332 |
+
except Exception as refute_e:
|
333 |
+
logger.error(f"Refutation test failed: {refute_e}", exc_info=True)
|
334 |
+
refutation_results = {"error": f"Refutation failed: {refute_e}"}
|
335 |
+
|
336 |
+
elif final_estimate_value is not None and method_used == 'statsmodels':
|
337 |
+
logger.warning("Skipping placebo permutation refuter: Estimate was generated by statsmodels, not DoWhy's IV estimator.")
|
338 |
+
refutation_results = {"status": "skipped_wrong_estimator_for_permute"}
|
339 |
+
|
340 |
+
elif final_estimate_value is None:
|
341 |
+
logger.warning("Skipping refutation test because estimation failed.")
|
342 |
+
refutation_results = {"status": "skipped_due_to_failed_estimation"}
|
343 |
+
|
344 |
+
else: # Model or estimand failed earlier, or unknown method_used
|
345 |
+
logger.warning(f"Skipping refutation test due to earlier failure (method_used: {method_used}).")
|
346 |
+
refutation_results = {"status": "skipped_due_to_model_failure_or_unknown"}
|
347 |
+
|
348 |
+
results['refutation_results'] = refutation_results # Add to main results
|
349 |
+
|
350 |
+
# --- Formatting Results ---
|
351 |
+
if final_estimate_value is None and method_used not in ['dowhy', 'statsmodels']:
|
352 |
+
logger.error("ERROR: Both estimation methods failed.")
|
353 |
+
# Ensure error key exists if not set earlier
|
354 |
+
if 'error' not in results:
|
355 |
+
results['error'] = "Both DoWhy and statsmodels IV estimation failed."
|
356 |
+
|
357 |
+
logger.info("\n--- Formatting Final Results ---")
|
358 |
+
formatted_results = format_iv_results(
|
359 |
+
final_estimate_value, # Pass the numeric value
|
360 |
+
results, # Pass the dict containing estimate objects and refutation results
|
361 |
+
diagnostics,
|
362 |
+
treatment,
|
363 |
+
outcome,
|
364 |
+
valid_instruments,
|
365 |
+
method_used,
|
366 |
+
llm=llm
|
367 |
+
)
|
368 |
+
|
369 |
+
logger.info("--- Instrumental Variable Estimation Complete ---\n")
|
370 |
+
return formatted_results
|
auto_causal/methods/instrumental_variable/llm_assist.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM assistance functions for Instrumental Variable (IV) analysis.
|
3 |
+
|
4 |
+
This module provides functions for LLM-based assistance in instrumental variable analysis,
|
5 |
+
including identifying potential instruments, validating IV assumptions, and interpreting results.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import List, Dict, Any, Optional
|
9 |
+
import logging
|
10 |
+
|
11 |
+
# Imported for type hinting
|
12 |
+
from langchain.chat_models.base import BaseChatModel
|
13 |
+
|
14 |
+
# Import shared LLM helpers
|
15 |
+
from auto_causal.utils.llm_helpers import call_llm_with_json_output
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
def identify_instrument_variable(
|
20 |
+
df_cols: List[str],
|
21 |
+
query: str,
|
22 |
+
llm: Optional[BaseChatModel] = None
|
23 |
+
) -> List[str]:
|
24 |
+
"""
|
25 |
+
Use LLM to identify potential instrumental variables from available columns.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
df_cols: List of column names from the dataset
|
29 |
+
query: User's causal query text
|
30 |
+
llm: Optional LLM model instance
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
List of column names identified as potential instruments
|
34 |
+
"""
|
35 |
+
if llm is None:
|
36 |
+
logger.warning("No LLM provided for instrument identification")
|
37 |
+
return []
|
38 |
+
|
39 |
+
prompt = f"""
|
40 |
+
You are assisting with an instrumental variable analysis.
|
41 |
+
|
42 |
+
Available columns in the dataset: {df_cols}
|
43 |
+
User query: {query}
|
44 |
+
|
45 |
+
Identify potential instrumental variable(s) from the available columns based on the query.
|
46 |
+
The treatment and outcome should NOT be included as instruments.
|
47 |
+
|
48 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
49 |
+
{{
|
50 |
+
"potential_instruments": ["column_name1", "column_name2", ...]
|
51 |
+
}}
|
52 |
+
"""
|
53 |
+
|
54 |
+
response = call_llm_with_json_output(llm, prompt)
|
55 |
+
|
56 |
+
if response and "potential_instruments" in response and isinstance(response["potential_instruments"], list):
|
57 |
+
# Basic validation: ensure items are strings (column names)
|
58 |
+
valid_instruments = [item for item in response["potential_instruments"] if isinstance(item, str)]
|
59 |
+
if len(valid_instruments) != len(response["potential_instruments"]):
|
60 |
+
logger.warning("LLM returned non-string items in potential_instruments list.")
|
61 |
+
return valid_instruments
|
62 |
+
|
63 |
+
logger.warning(f"Failed to get valid instrument recommendations from LLM. Response: {response}")
|
64 |
+
return []
|
65 |
+
|
66 |
+
def validate_instrument_assumptions_qualitative(
|
67 |
+
treatment: str,
|
68 |
+
outcome: str,
|
69 |
+
instrument: List[str],
|
70 |
+
covariates: List[str],
|
71 |
+
query: str,
|
72 |
+
llm: Optional[BaseChatModel] = None
|
73 |
+
) -> Dict[str, str]:
|
74 |
+
"""
|
75 |
+
Use LLM to provide qualitative assessment of IV assumptions.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
treatment: Treatment variable name
|
79 |
+
outcome: Outcome variable name
|
80 |
+
instrument: List of instrumental variable names
|
81 |
+
covariates: List of covariate variable names
|
82 |
+
query: User's causal query text
|
83 |
+
llm: Optional LLM model instance
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
Dictionary with qualitative assessments of exclusion and exogeneity assumptions
|
87 |
+
"""
|
88 |
+
default_fail = {
|
89 |
+
"exclusion_assessment": "LLM Check Failed",
|
90 |
+
"exogeneity_assessment": "LLM Check Failed"
|
91 |
+
}
|
92 |
+
|
93 |
+
if llm is None:
|
94 |
+
return {
|
95 |
+
"exclusion_assessment": "LLM Not Provided",
|
96 |
+
"exogeneity_assessment": "LLM Not Provided"
|
97 |
+
}
|
98 |
+
|
99 |
+
prompt = f"""
|
100 |
+
You are assisting with assessing the validity of instrumental variable assumptions.
|
101 |
+
|
102 |
+
Treatment variable: {treatment}
|
103 |
+
Outcome variable: {outcome}
|
104 |
+
Instrumental variable(s): {instrument}
|
105 |
+
Covariates: {covariates}
|
106 |
+
User query: {query}
|
107 |
+
|
108 |
+
Assess the core Instrumental Variable (IV) assumptions based *only* on the provided variable names and query context:
|
109 |
+
1. Exclusion restriction: Plausibility that the instrument(s) affect the outcome ONLY through the treatment.
|
110 |
+
2. Exogeneity (also called Independence): Plausibility that the instrument(s) are not correlated with unobserved confounders that also affect the outcome.
|
111 |
+
|
112 |
+
Provide a brief, qualitative assessment (e.g., 'Plausible', 'Unlikely', 'Requires Domain Knowledge', 'Potentially Violated').
|
113 |
+
|
114 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
115 |
+
{{
|
116 |
+
"exclusion_assessment": "<brief assessment of exclusion restriction>",
|
117 |
+
"exogeneity_assessment": "<brief assessment of exogeneity assumption>"
|
118 |
+
}}
|
119 |
+
"""
|
120 |
+
|
121 |
+
response = call_llm_with_json_output(llm, prompt)
|
122 |
+
|
123 |
+
if response and isinstance(response, dict) and \
|
124 |
+
"exclusion_assessment" in response and isinstance(response["exclusion_assessment"], str) and \
|
125 |
+
"exogeneity_assessment" in response and isinstance(response["exogeneity_assessment"], str):
|
126 |
+
return response
|
127 |
+
|
128 |
+
logger.warning(f"Failed to get valid assumption assessment from LLM. Response: {response}")
|
129 |
+
return default_fail
|
130 |
+
|
131 |
+
def interpret_iv_results(
|
132 |
+
results: Dict[str, Any],
|
133 |
+
diagnostics: Dict[str, Any],
|
134 |
+
llm: Optional[BaseChatModel] = None
|
135 |
+
) -> str:
|
136 |
+
"""
|
137 |
+
Use LLM to interpret IV results in natural language.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
results: Dictionary of estimation results (e.g., effect_estimate, p_value, confidence_interval)
|
141 |
+
diagnostics: Dictionary of diagnostic test results (e.g., first_stage_f_statistic, overid_test)
|
142 |
+
llm: Optional LLM model instance
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
String containing natural language interpretation of results
|
146 |
+
"""
|
147 |
+
if llm is None:
|
148 |
+
return "LLM was not available to provide interpretation. Please review the numeric results manually."
|
149 |
+
|
150 |
+
# Construct a concise summary of inputs for the prompt
|
151 |
+
results_summary = {}
|
152 |
+
|
153 |
+
effect = results.get('effect_estimate')
|
154 |
+
if effect is not None:
|
155 |
+
try:
|
156 |
+
results_summary['Effect Estimate'] = f"{float(effect):.3f}"
|
157 |
+
except (ValueError, TypeError):
|
158 |
+
results_summary['Effect Estimate'] = 'N/A (Invalid Format)'
|
159 |
+
else:
|
160 |
+
results_summary['Effect Estimate'] = 'N/A'
|
161 |
+
|
162 |
+
p_value = results.get('p_value')
|
163 |
+
if p_value is not None:
|
164 |
+
try:
|
165 |
+
results_summary['P-value'] = f"{float(p_value):.3f}"
|
166 |
+
except (ValueError, TypeError):
|
167 |
+
results_summary['P-value'] = 'N/A (Invalid Format)'
|
168 |
+
else:
|
169 |
+
results_summary['P-value'] = 'N/A'
|
170 |
+
|
171 |
+
ci = results.get('confidence_interval')
|
172 |
+
if ci is not None and isinstance(ci, (list, tuple)) and len(ci) == 2:
|
173 |
+
try:
|
174 |
+
results_summary['Confidence Interval'] = f"[{float(ci[0]):.3f}, {float(ci[1]):.3f}]"
|
175 |
+
except (ValueError, TypeError):
|
176 |
+
results_summary['Confidence Interval'] = 'N/A (Invalid Format)'
|
177 |
+
else:
|
178 |
+
# Handle cases where CI is None or not a 2-element list/tuple
|
179 |
+
results_summary['Confidence Interval'] = str(ci) if ci is not None else 'N/A'
|
180 |
+
|
181 |
+
if 'treatment_variable' in results:
|
182 |
+
results_summary['Treatment'] = results['treatment_variable']
|
183 |
+
if 'outcome_variable' in results:
|
184 |
+
results_summary['Outcome'] = results['outcome_variable']
|
185 |
+
|
186 |
+
diagnostics_summary = {}
|
187 |
+
f_stat = diagnostics.get('first_stage_f_statistic')
|
188 |
+
if f_stat is not None:
|
189 |
+
try:
|
190 |
+
diagnostics_summary['First-Stage F-statistic'] = f"{float(f_stat):.2f}"
|
191 |
+
except (ValueError, TypeError):
|
192 |
+
diagnostics_summary['First-Stage F-statistic'] = 'N/A (Invalid Format)'
|
193 |
+
else:
|
194 |
+
diagnostics_summary['First-Stage F-statistic'] = 'N/A'
|
195 |
+
|
196 |
+
if 'weak_instrument_test_status' in diagnostics:
|
197 |
+
diagnostics_summary['Weak Instrument Test'] = diagnostics['weak_instrument_test_status']
|
198 |
+
|
199 |
+
overid_p = diagnostics.get('overid_test_p_value')
|
200 |
+
if overid_p is not None:
|
201 |
+
try:
|
202 |
+
diagnostics_summary['Overidentification Test P-value'] = f"{float(overid_p):.3f}"
|
203 |
+
diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
|
204 |
+
except (ValueError, TypeError):
|
205 |
+
diagnostics_summary['Overidentification Test P-value'] = 'N/A (Invalid Format)'
|
206 |
+
diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
|
207 |
+
else:
|
208 |
+
# Explicitly state if not applicable or not available
|
209 |
+
if diagnostics.get('overid_test_applicable') == False:
|
210 |
+
diagnostics_summary['Overidentification Test'] = 'Not Applicable'
|
211 |
+
else:
|
212 |
+
diagnostics_summary['Overidentification Test P-value'] = 'N/A'
|
213 |
+
diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
|
214 |
+
|
215 |
+
prompt = f"""
|
216 |
+
You are assisting with interpreting instrumental variable (IV) analysis results.
|
217 |
+
|
218 |
+
Estimation results summary: {results_summary}
|
219 |
+
Diagnostic test results summary: {diagnostics_summary}
|
220 |
+
|
221 |
+
Explain these Instrumental Variable (IV) results in clear, concise language (2-4 sentences).
|
222 |
+
Focus on:
|
223 |
+
1. The estimated causal effect (magnitude, direction, statistical significance based on p-value < 0.05).
|
224 |
+
2. The strength of the instrument(s) (based on F-statistic, typically > 10 indicates strength).
|
225 |
+
3. Any implications from other diagnostic tests (e.g., overidentification test suggesting instrument validity issues if p < 0.05).
|
226 |
+
|
227 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
228 |
+
{{
|
229 |
+
"interpretation": "<your concise interpretation text>"
|
230 |
+
}}
|
231 |
+
"""
|
232 |
+
|
233 |
+
response = call_llm_with_json_output(llm, prompt)
|
234 |
+
|
235 |
+
if response and isinstance(response, dict) and \
|
236 |
+
"interpretation" in response and isinstance(response["interpretation"], str):
|
237 |
+
return response["interpretation"]
|
238 |
+
|
239 |
+
logger.warning(f"Failed to get valid interpretation from LLM. Response: {response}")
|
240 |
+
return "LLM interpretation could not be generated. Please review the numeric results manually."
|
auto_causal/methods/linear_regression/__init__.py
ADDED
File without changes
|
auto_causal/methods/linear_regression/diagnostics.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Diagnostic checks for Linear Regression models.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, Any
|
6 |
+
import statsmodels.api as sm
|
7 |
+
from statsmodels.stats.diagnostic import het_breuschpagan, normal_ad
|
8 |
+
from statsmodels.stats.stattools import jarque_bera
|
9 |
+
from statsmodels.regression.linear_model import RegressionResultsWrapper
|
10 |
+
import pandas as pd
|
11 |
+
import logging
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
def run_lr_diagnostics(results: RegressionResultsWrapper, X: pd.DataFrame) -> Dict[str, Any]:
|
16 |
+
"""
|
17 |
+
Runs diagnostic checks on a fitted OLS model.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
results: A fitted statsmodels OLS results object.
|
21 |
+
X: The design matrix (including constant) used for the regression.
|
22 |
+
Needed for heteroskedasticity tests.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
Dictionary containing diagnostic metrics.
|
26 |
+
"""
|
27 |
+
|
28 |
+
diagnostics = {}
|
29 |
+
|
30 |
+
try:
|
31 |
+
diagnostics['r_squared'] = results.rsquared
|
32 |
+
diagnostics['adj_r_squared'] = results.rsquared_adj
|
33 |
+
diagnostics['f_statistic'] = results.fvalue
|
34 |
+
diagnostics['f_p_value'] = results.f_pvalue
|
35 |
+
diagnostics['n_observations'] = int(results.nobs)
|
36 |
+
diagnostics['degrees_of_freedom_resid'] = int(results.df_resid)
|
37 |
+
|
38 |
+
# --- Normality of Residuals (Jarque-Bera) ---
|
39 |
+
try:
|
40 |
+
jb_value, jb_p_value, skew, kurtosis = jarque_bera(results.resid)
|
41 |
+
diagnostics['residuals_normality_jb_stat'] = jb_value
|
42 |
+
diagnostics['residuals_normality_jb_p_value'] = jb_p_value
|
43 |
+
diagnostics['residuals_skewness'] = skew
|
44 |
+
diagnostics['residuals_kurtosis'] = kurtosis
|
45 |
+
diagnostics['residuals_normality_status'] = "Normal" if jb_p_value > 0.05 else "Non-Normal"
|
46 |
+
except Exception as e:
|
47 |
+
logger.warning(f"Could not run Jarque-Bera test: {e}")
|
48 |
+
diagnostics['residuals_normality_status'] = "Test Failed"
|
49 |
+
|
50 |
+
# --- Homoscedasticity (Breusch-Pagan) ---
|
51 |
+
# Requires the design matrix X used in the model fitting
|
52 |
+
try:
|
53 |
+
lm_stat, lm_p_value, f_stat, f_p_value = het_breuschpagan(results.resid, X)
|
54 |
+
diagnostics['homoscedasticity_bp_lm_stat'] = lm_stat
|
55 |
+
diagnostics['homoscedasticity_bp_lm_p_value'] = lm_p_value
|
56 |
+
diagnostics['homoscedasticity_bp_f_stat'] = f_stat
|
57 |
+
diagnostics['homoscedasticity_bp_f_p_value'] = f_p_value
|
58 |
+
diagnostics['homoscedasticity_status'] = "Homoscedastic" if lm_p_value > 0.05 else "Heteroscedastic"
|
59 |
+
except Exception as e:
|
60 |
+
logger.warning(f"Could not run Breusch-Pagan test: {e}")
|
61 |
+
diagnostics['homoscedasticity_status'] = "Test Failed"
|
62 |
+
|
63 |
+
# --- Linearity (Basic check - often requires visual inspection) ---
|
64 |
+
# No standard quantitative test implemented here. Usually assessed via residual plots.
|
65 |
+
diagnostics['linearity_check'] = "Requires visual inspection (e.g., residual vs fitted plot)"
|
66 |
+
|
67 |
+
# --- Multicollinearity (Placeholder - requires VIF calculation) ---
|
68 |
+
# VIF requires iterating through predictors, more involved
|
69 |
+
diagnostics['multicollinearity_check'] = "Not Implemented (Requires VIF)"
|
70 |
+
|
71 |
+
return {"status": "Success", "details": diagnostics}
|
72 |
+
|
73 |
+
except Exception as e:
|
74 |
+
logger.error(f"Error running LR diagnostics: {e}")
|
75 |
+
return {"status": "Failed", "error": str(e), "details": diagnostics} # Return partial results if possible
|
76 |
+
|
auto_causal/methods/linear_regression/estimator.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Linear Regression Estimator for Causal Inference.
|
3 |
+
|
4 |
+
Uses Ordinary Least Squares (OLS) to estimate the treatment effect, potentially
|
5 |
+
adjusting for covariates.
|
6 |
+
"""
|
7 |
+
import pandas as pd
|
8 |
+
import statsmodels.api as sm
|
9 |
+
import statsmodels.formula.api as smf
|
10 |
+
from typing import Dict, Any, List, Optional, Union
|
11 |
+
import logging
|
12 |
+
from langchain.chat_models.base import BaseChatModel
|
13 |
+
import re
|
14 |
+
import json
|
15 |
+
from pydantic import BaseModel, ValidationError
|
16 |
+
from langchain_core.messages import HumanMessage
|
17 |
+
from langchain_core.exceptions import OutputParserException
|
18 |
+
|
19 |
+
|
20 |
+
from auto_causal.models import LLMIdentifiedRelevantParams
|
21 |
+
from auto_causal.prompts.regression_prompts import STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE
|
22 |
+
from auto_causal.config import get_llm_client
|
23 |
+
|
24 |
+
# Placeholder for potential future LLM assistance integration
|
25 |
+
# from .llm_assist import interpret_lr_results, suggest_lr_covariates
|
26 |
+
# Placeholder for potential future diagnostics integration
|
27 |
+
# from .diagnostics import run_lr_diagnostics
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
|
32 |
+
"""Helper to call LLM with structured output and handle errors."""
|
33 |
+
try:
|
34 |
+
messages = [HumanMessage(content=prompt)]
|
35 |
+
structured_llm = llm.with_structured_output(pydantic_model)
|
36 |
+
parsed_result = structured_llm.invoke(messages)
|
37 |
+
return parsed_result
|
38 |
+
except (OutputParserException, ValidationError) as e:
|
39 |
+
logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {e}")
|
40 |
+
except Exception as e:
|
41 |
+
logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {e}", exc_info=True)
|
42 |
+
return None
|
43 |
+
|
44 |
+
# Define module-level helper function
|
45 |
+
def _clean_variable_name_for_patsy_local(name: str) -> str:
|
46 |
+
if not isinstance(name, str):
|
47 |
+
name = str(name)
|
48 |
+
name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
49 |
+
if not re.match(r'^[a-zA-Z_]', name):
|
50 |
+
name = 'var_' + name
|
51 |
+
return name
|
52 |
+
|
53 |
+
|
54 |
+
def estimate_effect(
|
55 |
+
df: pd.DataFrame,
|
56 |
+
treatment: str,
|
57 |
+
outcome: str,
|
58 |
+
covariates: Optional[List[str]] = None,
|
59 |
+
query_str: Optional[str] = None, # For potential LLM use
|
60 |
+
llm: Optional[BaseChatModel] = None, # For potential LLM use
|
61 |
+
**kwargs # To capture any other potential arguments
|
62 |
+
) -> Dict[str, Any]:
|
63 |
+
"""
|
64 |
+
Estimates the causal effect using Linear Regression (OLS).
|
65 |
+
|
66 |
+
Args:
|
67 |
+
df: Input DataFrame.
|
68 |
+
treatment: Name of the treatment variable column.
|
69 |
+
outcome: Name of the outcome variable column.
|
70 |
+
covariates: Optional list of covariate names.
|
71 |
+
query_str: Optional user query for context (e.g., for LLM).
|
72 |
+
llm: Optional Language Model instance.
|
73 |
+
**kwargs: Additional keyword arguments.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
Dictionary containing estimation results:
|
77 |
+
- 'effect_estimate': The estimated coefficient for the treatment variable.
|
78 |
+
- 'p_value': The p-value associated with the treatment coefficient.
|
79 |
+
- 'confidence_interval': The 95% confidence interval for the effect.
|
80 |
+
- 'standard_error': The standard error of the treatment coefficient.
|
81 |
+
- 'formula': The regression formula used.
|
82 |
+
- 'model_summary': Summary object from statsmodels.
|
83 |
+
- 'diagnostics': Placeholder for diagnostic results.
|
84 |
+
- 'interpretation': Placeholder for LLM interpretation.
|
85 |
+
"""
|
86 |
+
if covariates is None:
|
87 |
+
covariates = []
|
88 |
+
|
89 |
+
# Retrieve additional args from kwargs
|
90 |
+
interaction_term_suggested = kwargs.get('interaction_term_suggested', False)
|
91 |
+
# interaction_variable_candidate is the *original* name from query_interpreter
|
92 |
+
interaction_variable_candidate_orig_name = kwargs.get('interaction_variable_candidate')
|
93 |
+
treatment_reference_level = kwargs.get('treatment_reference_level')
|
94 |
+
column_mappings = kwargs.get('column_mappings', {})
|
95 |
+
|
96 |
+
required_cols = [treatment, outcome] + covariates
|
97 |
+
# If interaction variable is suggested, ensure it (or its processed form) is in df for analysis
|
98 |
+
# This check is complex here as interaction_variable_candidate_orig_name needs mapping to processed column(s)
|
99 |
+
# We'll rely on df_analysis.dropna() and formula construction to handle missing interaction var columns later
|
100 |
+
|
101 |
+
missing_cols = [col for col in required_cols if col not in df.columns]
|
102 |
+
if missing_cols:
|
103 |
+
raise ValueError(f"Missing required columns: {missing_cols}")
|
104 |
+
|
105 |
+
# Prepare data for statsmodels (add constant, handle potential NaNs)
|
106 |
+
df_analysis = df[required_cols].dropna()
|
107 |
+
if df_analysis.empty:
|
108 |
+
raise ValueError("No data remaining after dropping NaNs for required columns.")
|
109 |
+
|
110 |
+
X = df_analysis[[treatment] + covariates]
|
111 |
+
X = sm.add_constant(X) # Add intercept
|
112 |
+
y = df_analysis[outcome]
|
113 |
+
|
114 |
+
# --- Formula Construction ---
|
115 |
+
outcome_col_name = outcome # Name in processed df
|
116 |
+
treatment_col_name = treatment # Name in processed df
|
117 |
+
processed_covariate_col_names = covariates # List of names in processed df
|
118 |
+
|
119 |
+
rhs_terms = []
|
120 |
+
|
121 |
+
# 1. Treatment Term
|
122 |
+
treatment_patsy_term = treatment_col_name # Default
|
123 |
+
original_treatment_info = column_mappings.get(treatment_col_name, {}) # Info from preprocess_data
|
124 |
+
|
125 |
+
is_binary_encoded = original_treatment_info.get('transformed_as') == 'label_encoded_binary'
|
126 |
+
is_still_categorical_in_df = df_analysis[treatment_col_name].dtype.name in ['object', 'category']
|
127 |
+
|
128 |
+
if is_still_categorical_in_df and not is_binary_encoded: # Covers multi-level and binary categoricals not yet numeric
|
129 |
+
if treatment_reference_level:
|
130 |
+
treatment_patsy_term = f"C({treatment_col_name}, Treatment(reference='{treatment_reference_level}'))"
|
131 |
+
logger.info(f"Treating '{treatment_col_name}' as multi-level categorical with reference '{treatment_reference_level}'.")
|
132 |
+
else:
|
133 |
+
# Default C() wrapping for categoricals if no specific reference is given.
|
134 |
+
# This applies to multi-level or binary categoricals that were not label_encoded to 0/1 by preprocess_data.
|
135 |
+
treatment_patsy_term = f"C({treatment_col_name})"
|
136 |
+
logger.info(f"Treating '{treatment_col_name}' as categorical (Patsy will pick reference).")
|
137 |
+
elif is_binary_encoded: # Was binary and explicitly label encoded to 0/1 by preprocess_data
|
138 |
+
# Even if it's now numeric 0/1, C() ensures Patsy treats it categorically for parameter naming consistency.
|
139 |
+
treatment_patsy_term = f"C({treatment_col_name})"
|
140 |
+
logger.info(f"Treating label-encoded binary '{treatment_col_name}' as categorical for Patsy.")
|
141 |
+
else: # Assumed to be already numeric (continuous or discrete numeric not needing C() for main effect)
|
142 |
+
# treatment_patsy_term remains treatment_col_name (default)
|
143 |
+
logger.info(f"Treating '{treatment_col_name}' as numeric for Patsy formula.")
|
144 |
+
|
145 |
+
rhs_terms.append(treatment_patsy_term)
|
146 |
+
|
147 |
+
# 2. Covariate Terms
|
148 |
+
for cov_col_name in processed_covariate_col_names:
|
149 |
+
if cov_col_name == treatment_col_name: # Should not happen if covariates list is clean
|
150 |
+
continue
|
151 |
+
# Assume covariates are already numeric/dummy. If one was object/category in df_analysis (unlikely), C() it.
|
152 |
+
if df_analysis[cov_col_name].dtype.name in ['object', 'category']:
|
153 |
+
rhs_terms.append(f"C({cov_col_name})")
|
154 |
+
else:
|
155 |
+
rhs_terms.append(cov_col_name)
|
156 |
+
|
157 |
+
# 3. Interaction Term (Simplified: interaction_variable_candidate_orig_name must map to a single column in df_analysis)
|
158 |
+
actual_interaction_term_added_to_formula = None
|
159 |
+
if interaction_term_suggested and interaction_variable_candidate_orig_name:
|
160 |
+
processed_interaction_col_name = None
|
161 |
+
interaction_var_info = column_mappings.get(interaction_variable_candidate_orig_name, {})
|
162 |
+
|
163 |
+
if interaction_var_info.get('transformed_as') == 'one_hot_encoded':
|
164 |
+
logger.warning(f"Interaction with one-hot encoded variable '{interaction_variable_candidate_orig_name}' is complex. Currently skipping this interaction for Linear Regression.")
|
165 |
+
elif interaction_var_info.get('new_column_name') and interaction_var_info['new_column_name'] in df_analysis.columns:
|
166 |
+
processed_interaction_col_name = interaction_var_info['new_column_name']
|
167 |
+
elif interaction_variable_candidate_orig_name in df_analysis.columns: # Was not in mappings, or mapping didn't change name (e.g. numeric)
|
168 |
+
processed_interaction_col_name = interaction_variable_candidate_orig_name
|
169 |
+
|
170 |
+
if processed_interaction_col_name:
|
171 |
+
interaction_var_patsy_term = processed_interaction_col_name
|
172 |
+
# If the processed interaction column itself is categorical (e.g. label encoded binary)
|
173 |
+
if df_analysis[processed_interaction_col_name].dtype.name in ['object', 'category', 'bool'] or \
|
174 |
+
interaction_var_info.get('original_dtype') in ['bool', 'category']:
|
175 |
+
interaction_var_patsy_term = f"C({processed_interaction_col_name})"
|
176 |
+
|
177 |
+
actual_interaction_term_added_to_formula = f"{treatment_patsy_term}:{interaction_var_patsy_term}"
|
178 |
+
rhs_terms.append(actual_interaction_term_added_to_formula)
|
179 |
+
logger.info(f"Adding interaction term to formula: {actual_interaction_term_added_to_formula}")
|
180 |
+
elif interaction_variable_candidate_orig_name: # Log if it was suggested but couldn't be mapped/found
|
181 |
+
logger.warning(f"Could not resolve interaction variable candidate '{interaction_variable_candidate_orig_name}' to a single usable column in processed data. Skipping interaction term.")
|
182 |
+
|
183 |
+
# Build the formula string for reporting and fitting
|
184 |
+
if not rhs_terms: # Should always have at least treatment
|
185 |
+
formula = f"{outcome_col_name} ~ 1"
|
186 |
+
else:
|
187 |
+
formula = f"{outcome_col_name} ~ {' + '.join(rhs_terms)}"
|
188 |
+
logger.info(f"Using formula for Linear Regression: {formula}")
|
189 |
+
|
190 |
+
try:
|
191 |
+
model = smf.ols(formula=formula, data=df_analysis)
|
192 |
+
results = model.fit()
|
193 |
+
logger.info("OLS model fitted successfully.")
|
194 |
+
logger.info(results.summary()) # Changed to debug level for less verbose default logging
|
195 |
+
|
196 |
+
# --- Result Extraction: LLM attempt first, then Regex fallback ---
|
197 |
+
effect_estimates_by_level = {}
|
198 |
+
all_params_extracted = False # Default to False
|
199 |
+
llm_extraction_successful = False
|
200 |
+
|
201 |
+
# Attempt LLM-based extraction if llm client and query are available
|
202 |
+
llm = get_llm_client()
|
203 |
+
if llm and query_str:
|
204 |
+
logger.info(f"Attempting LLM-based result extraction (informed by query: '{query_str[:50]}...').")
|
205 |
+
try:
|
206 |
+
param_names_list = results.params.index.tolist()
|
207 |
+
param_estimates_list = results.params.tolist()
|
208 |
+
param_p_values_list = results.pvalues.tolist()
|
209 |
+
param_std_errs_list = results.bse.tolist()
|
210 |
+
|
211 |
+
conf_int_df = results.conf_int(alpha=0.05)
|
212 |
+
param_conf_ints_low_list = []
|
213 |
+
param_conf_ints_high_list = []
|
214 |
+
|
215 |
+
if not conf_int_df.empty and len(conf_int_df.columns) == 2:
|
216 |
+
aligned_conf_int_df = conf_int_df.reindex(results.params.index)
|
217 |
+
param_conf_ints_low_list = aligned_conf_int_df.iloc[:, 0].fillna(float('nan')).tolist()
|
218 |
+
param_conf_ints_high_list = aligned_conf_int_df.iloc[:, 1].fillna(float('nan')).tolist()
|
219 |
+
else:
|
220 |
+
nan_list_ci = [float('nan')] * len(param_names_list)
|
221 |
+
param_conf_ints_low_list = nan_list_ci
|
222 |
+
param_conf_ints_high_list = nan_list_ci
|
223 |
+
|
224 |
+
# Placeholder for the new prompt template tailored for this extraction task
|
225 |
+
# MOVED TO causalscientist/auto_causal/prompts/regression_prompts.py
|
226 |
+
|
227 |
+
is_multilevel_case_for_prompt = bool(treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded)
|
228 |
+
reference_level_for_prompt_str = str(treatment_reference_level) if is_multilevel_case_for_prompt else "N/A"
|
229 |
+
|
230 |
+
indexed_param_names_for_prompt = [f"{idx}: '{name}'" for idx, name in enumerate(param_names_list)]
|
231 |
+
indexed_param_names_str_for_prompt = "\n".join(indexed_param_names_for_prompt)
|
232 |
+
|
233 |
+
prompt_text_for_identification = STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE.format(
|
234 |
+
user_query=query_str,
|
235 |
+
treatment_patsy_term=treatment_patsy_term,
|
236 |
+
treatment_col_name=treatment_col_name,
|
237 |
+
is_multilevel_case=is_multilevel_case_for_prompt,
|
238 |
+
reference_level_for_prompt=reference_level_for_prompt_str,
|
239 |
+
indexed_param_names_str=indexed_param_names_str_for_prompt, # Pass the indexed list as a string
|
240 |
+
llm_response_schema_json=json.dumps(LLMIdentifiedRelevantParams.model_json_schema(), indent=2)
|
241 |
+
)
|
242 |
+
|
243 |
+
llm_identification_response = _call_llm_for_var(llm, prompt_text_for_identification, LLMIdentifiedRelevantParams)
|
244 |
+
|
245 |
+
if llm_identification_response and llm_identification_response.identified_params:
|
246 |
+
logger.info("LLM identified relevant parameters. Proceeding with programmatic extraction.")
|
247 |
+
for item in llm_identification_response.identified_params:
|
248 |
+
param_idx = item.param_index
|
249 |
+
# Validate index against actual list length
|
250 |
+
if 0 <= param_idx < len(results.params.index):
|
251 |
+
actual_param_name = results.params.index[param_idx]
|
252 |
+
# Sanity check if LLM returned name matches actual name at index
|
253 |
+
if item.param_name != actual_param_name:
|
254 |
+
logger.warning(f"LLM returned param_name '{item.param_name}' but name at index {param_idx} is '{actual_param_name}'. Using actual name from results.")
|
255 |
+
|
256 |
+
current_effect_stats = {
|
257 |
+
'estimate': results.params.iloc[param_idx],
|
258 |
+
'p_value': results.pvalues.iloc[param_idx],
|
259 |
+
'conf_int': results.conf_int(alpha=0.05).iloc[param_idx].tolist(),
|
260 |
+
'std_err': results.bse.iloc[param_idx]
|
261 |
+
}
|
262 |
+
|
263 |
+
key_for_effect_dict = 'treatment_effect' # Default for single/binary
|
264 |
+
if is_multilevel_case_for_prompt: # If it was a multi-level case
|
265 |
+
match = re.search(r'\[T\.([^]]+)]', actual_param_name) # Use actual_param_name
|
266 |
+
if match:
|
267 |
+
level = match.group(1)
|
268 |
+
if level != reference_level_for_prompt_str: # Ensure it's not the ref level itself
|
269 |
+
key_for_effect_dict = level
|
270 |
+
else:
|
271 |
+
logger.warning(f"Could not parse level from LLM-identified param: {actual_param_name}. Storing with raw name.")
|
272 |
+
key_for_effect_dict = actual_param_name # Fallback key
|
273 |
+
|
274 |
+
effect_estimates_by_level[key_for_effect_dict] = current_effect_stats
|
275 |
+
else:
|
276 |
+
logger.warning(f"LLM returned an invalid parameter index: {param_idx}. Skipping.")
|
277 |
+
|
278 |
+
if effect_estimates_by_level: # If any effects were successfully processed
|
279 |
+
all_params_extracted = llm_identification_response.all_parameters_successfully_identified
|
280 |
+
llm_extraction_successful = True
|
281 |
+
logger.info(f"Successfully processed LLM-identified parameters. all_parameters_successfully_identified={all_params_extracted}")
|
282 |
+
print(f"effect_estimates_by_level: {effect_estimates_by_level}")
|
283 |
+
else:
|
284 |
+
logger.warning("LLM identified parameters, but none could be processed into effects_estimates_by_level. Falling back to regex.")
|
285 |
+
else:
|
286 |
+
logger.warning("LLM parameter identification did not yield usable parameters. Falling back to regex.")
|
287 |
+
|
288 |
+
except Exception as e_llm:
|
289 |
+
logger.warning(f"LLM-based result extraction failed: {e_llm}. Falling back to regex.", exc_info=True)
|
290 |
+
|
291 |
+
|
292 |
+
# --- End of Existing Regex Logic Block ---
|
293 |
+
|
294 |
+
# Primary effect_estimate for simple reporting (e.g. first level or the only one)
|
295 |
+
# For multi-level, this is ambiguous. For now, let's report None or the first one.
|
296 |
+
# The full details are in effect_estimates_by_level.
|
297 |
+
main_effect_estimate = None
|
298 |
+
main_p_value = None
|
299 |
+
main_conf_int = [None, None] # Default for single or if no effects
|
300 |
+
main_std_err = None
|
301 |
+
|
302 |
+
if effect_estimates_by_level:
|
303 |
+
if 'treatment_effect' in effect_estimates_by_level: # Single effect case
|
304 |
+
single_effect_data = effect_estimates_by_level['treatment_effect']
|
305 |
+
main_effect_estimate = single_effect_data['estimate']
|
306 |
+
main_p_value = single_effect_data['p_value']
|
307 |
+
main_conf_int = single_effect_data['conf_int']
|
308 |
+
main_std_err = single_effect_data['std_err']
|
309 |
+
else: # Multi-level case
|
310 |
+
logger.info("Multi-level treatment effects extracted. Populating dicts for main estimate fields.")
|
311 |
+
effect_estimate_dict = {}
|
312 |
+
p_value_dict = {}
|
313 |
+
conf_int_dict = {}
|
314 |
+
std_err_dict = {}
|
315 |
+
for level, stats in effect_estimates_by_level.items():
|
316 |
+
effect_estimate_dict[level] = stats.get('estimate')
|
317 |
+
p_value_dict[level] = stats.get('p_value')
|
318 |
+
conf_int_dict[level] = stats.get('conf_int') # This is already a list [low, high]
|
319 |
+
std_err_dict[level] = stats.get('std_err')
|
320 |
+
|
321 |
+
main_effect_estimate = effect_estimate_dict
|
322 |
+
main_p_value = p_value_dict
|
323 |
+
main_conf_int = conf_int_dict
|
324 |
+
main_std_err = std_err_dict
|
325 |
+
|
326 |
+
interpretation_details = {}
|
327 |
+
if actual_interaction_term_added_to_formula and actual_interaction_term_added_to_formula in results.params.index:
|
328 |
+
interpretation_details['interaction_term_coefficient'] = results.params[actual_interaction_term_added_to_formula]
|
329 |
+
interpretation_details['interaction_term_p_value'] = results.pvalues[actual_interaction_term_added_to_formula]
|
330 |
+
logger.info(f"Interaction term '{actual_interaction_term_added_to_formula}' coeff: {interpretation_details['interaction_term_coefficient']}")
|
331 |
+
|
332 |
+
diag_results = {}
|
333 |
+
interpretation = "Interpretation not available."
|
334 |
+
|
335 |
+
output_dict = {
|
336 |
+
'effect_estimate': main_effect_estimate,
|
337 |
+
'p_value': main_p_value,
|
338 |
+
'confidence_interval': main_conf_int,
|
339 |
+
'standard_error': main_std_err,
|
340 |
+
'estimated_effects_by_level': effect_estimates_by_level if (treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded and effect_estimates_by_level) else None,
|
341 |
+
'reference_level_used': treatment_reference_level if (treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded) else None,
|
342 |
+
'formula': formula,
|
343 |
+
'model_summary_text': results.summary().as_text(), # Store as text for easier serialization
|
344 |
+
'diagnostics': diag_results,
|
345 |
+
'interpretation_details': interpretation_details, # Added interaction details
|
346 |
+
'interpretation': interpretation,
|
347 |
+
'method_used': 'Linear Regression (OLS)'
|
348 |
+
}
|
349 |
+
if not all_params_extracted:
|
350 |
+
output_dict['warnings'] = ["Could not reliably extract all requested parameters from model results. Please check model_summary_text."]
|
351 |
+
return output_dict
|
352 |
+
|
353 |
+
except Exception as e:
|
354 |
+
logger.error(f"Linear Regression failed: {e}")
|
355 |
+
raise # Re-raise the exception after logging
|
auto_causal/methods/linear_regression/llm_assist.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM assistance functions for Linear Regression analysis.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import List, Dict, Any, Optional
|
6 |
+
import logging
|
7 |
+
|
8 |
+
# Imported for type hinting
|
9 |
+
from langchain.chat_models.base import BaseChatModel
|
10 |
+
from statsmodels.regression.linear_model import RegressionResultsWrapper
|
11 |
+
|
12 |
+
# Import shared LLM helpers
|
13 |
+
from auto_causal.utils.llm_helpers import call_llm_with_json_output
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
def suggest_lr_covariates(
|
18 |
+
df_cols: List[str],
|
19 |
+
treatment: str,
|
20 |
+
outcome: str,
|
21 |
+
query: str,
|
22 |
+
llm: Optional[BaseChatModel] = None
|
23 |
+
) -> List[str]:
|
24 |
+
"""
|
25 |
+
(Placeholder) Use LLM to suggest relevant covariates for linear regression.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
df_cols: List of available column names.
|
29 |
+
treatment: Treatment variable name.
|
30 |
+
outcome: Outcome variable name.
|
31 |
+
query: User's causal query text.
|
32 |
+
llm: Optional LLM model instance.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
List of suggested covariate names.
|
36 |
+
"""
|
37 |
+
logger.info("LLM covariate suggestion for LR is not implemented yet.")
|
38 |
+
if llm:
|
39 |
+
# Placeholder: Call LLM here in future
|
40 |
+
pass
|
41 |
+
return []
|
42 |
+
|
43 |
+
def interpret_lr_results(
|
44 |
+
results: RegressionResultsWrapper,
|
45 |
+
diagnostics: Dict[str, Any],
|
46 |
+
treatment_var: str, # Need treatment variable name to extract coefficient
|
47 |
+
llm: Optional[BaseChatModel] = None
|
48 |
+
) -> str:
|
49 |
+
"""
|
50 |
+
Use LLM to interpret Linear Regression results.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
results: Fitted statsmodels OLS results object.
|
54 |
+
diagnostics: Dictionary of diagnostic test results.
|
55 |
+
treatment_var: Name of the treatment variable.
|
56 |
+
llm: Optional LLM model instance.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
String containing natural language interpretation.
|
60 |
+
"""
|
61 |
+
default_interpretation = "LLM interpretation not available for Linear Regression."
|
62 |
+
if llm is None:
|
63 |
+
logger.info("LLM not provided for LR interpretation.")
|
64 |
+
return default_interpretation
|
65 |
+
|
66 |
+
try:
|
67 |
+
# --- Prepare summary for LLM ---
|
68 |
+
results_summary = {}
|
69 |
+
treatment_val = results.params.get(treatment_var)
|
70 |
+
pval_val = results.pvalues.get(treatment_var)
|
71 |
+
|
72 |
+
if treatment_val is not None:
|
73 |
+
results_summary['Treatment Effect Estimate'] = f"{treatment_val:.3f}"
|
74 |
+
else:
|
75 |
+
logger.warning(f"Treatment variable '{treatment_var}' not found in regression parameters.")
|
76 |
+
results_summary['Treatment Effect Estimate'] = "Not Found"
|
77 |
+
|
78 |
+
if pval_val is not None:
|
79 |
+
results_summary['Treatment P-value'] = f"{pval_val:.3f}"
|
80 |
+
else:
|
81 |
+
logger.warning(f"P-value for treatment variable '{treatment_var}' not found in regression results.")
|
82 |
+
results_summary['Treatment P-value'] = "Not Found"
|
83 |
+
|
84 |
+
try:
|
85 |
+
conf_int = results.conf_int().loc[treatment_var]
|
86 |
+
results_summary['Treatment 95% CI'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
|
87 |
+
except KeyError:
|
88 |
+
logger.warning(f"Confidence interval for treatment variable '{treatment_var}' not found.")
|
89 |
+
results_summary['Treatment 95% CI'] = "Not Found"
|
90 |
+
except Exception as ci_e:
|
91 |
+
logger.warning(f"Could not extract confidence interval for '{treatment_var}': {ci_e}")
|
92 |
+
results_summary['Treatment 95% CI'] = "Error"
|
93 |
+
|
94 |
+
results_summary['R-squared'] = f"{results.rsquared:.3f}"
|
95 |
+
results_summary['Adj. R-squared'] = f"{results.rsquared_adj:.3f}"
|
96 |
+
|
97 |
+
diag_summary = {}
|
98 |
+
if diagnostics.get("status") == "Success":
|
99 |
+
diag_details = diagnostics.get("details", {})
|
100 |
+
# Format p-values only if they are numbers
|
101 |
+
jb_p = diag_details.get('residuals_normality_jb_p_value')
|
102 |
+
bp_p = diag_details.get('homoscedasticity_bp_lm_p_value')
|
103 |
+
diag_summary['Residuals Normality (Jarque-Bera P-value)'] = f"{jb_p:.3f}" if isinstance(jb_p, (int, float)) else str(jb_p)
|
104 |
+
diag_summary['Homoscedasticity (Breusch-Pagan P-value)'] = f"{bp_p:.3f}" if isinstance(bp_p, (int, float)) else str(bp_p)
|
105 |
+
diag_summary['Homoscedasticity Status'] = diag_details.get('homoscedasticity_status', 'N/A')
|
106 |
+
diag_summary['Residuals Normality Status'] = diag_details.get('residuals_normality_status', 'N/A')
|
107 |
+
else:
|
108 |
+
diag_summary['Status'] = diagnostics.get("status", "Unknown")
|
109 |
+
if "error" in diagnostics:
|
110 |
+
diag_summary['Error'] = diagnostics["error"]
|
111 |
+
|
112 |
+
# --- Construct Prompt ---
|
113 |
+
prompt = f"""
|
114 |
+
You are assisting with interpreting Linear Regression (OLS) results for causal inference.
|
115 |
+
|
116 |
+
Model Results Summary:
|
117 |
+
{results_summary}
|
118 |
+
|
119 |
+
Model Diagnostics Summary:
|
120 |
+
{diag_summary}
|
121 |
+
|
122 |
+
Explain these results in 2-4 concise sentences. Focus on:
|
123 |
+
1. The estimated causal effect of the treatment variable '{treatment_var}' (magnitude, direction, statistical significance based on p-value < 0.05).
|
124 |
+
2. Overall model fit (using R-squared as a rough guide).
|
125 |
+
3. Key diagnostic findings (specifically, mention if residuals are non-normal or if heteroscedasticity is detected, as these violate OLS assumptions and can affect inference).
|
126 |
+
|
127 |
+
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
|
128 |
+
{{
|
129 |
+
"interpretation": "<your concise interpretation text>"
|
130 |
+
}}
|
131 |
+
"""
|
132 |
+
|
133 |
+
# --- Call LLM ---
|
134 |
+
response = call_llm_with_json_output(llm, prompt)
|
135 |
+
|
136 |
+
# --- Process Response ---
|
137 |
+
if response and isinstance(response, dict) and \
|
138 |
+
"interpretation" in response and isinstance(response["interpretation"], str):
|
139 |
+
return response["interpretation"]
|
140 |
+
else:
|
141 |
+
logger.warning(f"Failed to get valid interpretation from LLM. Response: {response}")
|
142 |
+
return default_interpretation
|
143 |
+
|
144 |
+
except Exception as e:
|
145 |
+
logger.error(f"Error during LLM interpretation for LR: {e}")
|
146 |
+
return f"Error generating interpretation: {e}"
|
auto_causal/methods/propensity_score/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import estimate_propensity_scores
|
2 |
+
from .matching import estimate_effect as estimate_matching_effect
|
3 |
+
from .weighting import estimate_effect as estimate_weighting_effect
|
4 |
+
from .diagnostics import assess_balance, plot_overlap, plot_balance
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
"estimate_propensity_scores",
|
8 |
+
"estimate_matching_effect",
|
9 |
+
"estimate_weighting_effect",
|
10 |
+
"assess_balance",
|
11 |
+
"plot_overlap",
|
12 |
+
"plot_balance"
|
13 |
+
]
|
auto_causal/methods/propensity_score/base.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base functionality for Propensity Score methods
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.linear_model import LogisticRegression
|
5 |
+
from sklearn.preprocessing import StandardScaler
|
6 |
+
from typing import List, Optional, Dict, Any
|
7 |
+
|
8 |
+
# Placeholder for LLM interaction to select model type
|
9 |
+
def select_propensity_model(df: pd.DataFrame, treatment: str, covariates: List[str],
|
10 |
+
query: Optional[str] = None) -> str:
|
11 |
+
'''Selects the appropriate propensity score model type (e.g., logistic, GBM).
|
12 |
+
|
13 |
+
Placeholder: Currently defaults to Logistic Regression.
|
14 |
+
'''
|
15 |
+
# TODO: Implement LLM call or heuristic to select model based on data characteristics
|
16 |
+
return "logistic"
|
17 |
+
|
18 |
+
def estimate_propensity_scores(df: pd.DataFrame, treatment: str,
|
19 |
+
covariates: List[str], model_type: str = 'logistic',
|
20 |
+
**kwargs) -> np.ndarray:
|
21 |
+
'''Estimate propensity scores using a specified model.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
df: DataFrame containing the data
|
25 |
+
treatment: Name of the treatment variable
|
26 |
+
covariates: List of covariate variable names
|
27 |
+
model_type: Type of model to use ('logistic' supported for now)
|
28 |
+
**kwargs: Additional arguments for the model
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
Array of propensity scores
|
32 |
+
'''
|
33 |
+
|
34 |
+
X = df[covariates]
|
35 |
+
y = df[treatment]
|
36 |
+
|
37 |
+
# Standardize covariates for logistic regression
|
38 |
+
scaler = StandardScaler()
|
39 |
+
X_scaled = scaler.fit_transform(X)
|
40 |
+
|
41 |
+
if model_type.lower() == 'logistic':
|
42 |
+
# Fit logistic regression
|
43 |
+
model = LogisticRegression(max_iter=kwargs.get('max_iter', 1000),
|
44 |
+
solver=kwargs.get('solver', 'liblinear'), # Use liblinear for L1/L2
|
45 |
+
C=kwargs.get('C', 1.0),
|
46 |
+
penalty=kwargs.get('penalty', 'l2'))
|
47 |
+
model.fit(X_scaled, y)
|
48 |
+
|
49 |
+
# Predict probabilities
|
50 |
+
propensity_scores = model.predict_proba(X_scaled)[:, 1]
|
51 |
+
# TODO: Add other model types like Gradient Boosting, etc.
|
52 |
+
# elif model_type.lower() == 'gbm':
|
53 |
+
# from sklearn.ensemble import GradientBoostingClassifier
|
54 |
+
# model = GradientBoostingClassifier(...)
|
55 |
+
# model.fit(X, y)
|
56 |
+
# propensity_scores = model.predict_proba(X)[:, 1]
|
57 |
+
else:
|
58 |
+
raise ValueError(f"Unsupported propensity score model type: {model_type}")
|
59 |
+
|
60 |
+
# Clip scores to avoid extremes which can cause issues in weighting/matching
|
61 |
+
propensity_scores = np.clip(propensity_scores, 0.01, 0.99)
|
62 |
+
|
63 |
+
return propensity_scores
|
64 |
+
|
65 |
+
# Common formatting function (can be expanded)
|
66 |
+
def format_ps_results(effect_estimate: float, effect_se: float,
|
67 |
+
diagnostics: Dict[str, Any], method_details: str,
|
68 |
+
parameters: Dict[str, Any]) -> Dict[str, Any]:
|
69 |
+
'''Standard formatter for PS method results.'''
|
70 |
+
ci_lower = effect_estimate - 1.96 * effect_se
|
71 |
+
ci_upper = effect_estimate + 1.96 * effect_se
|
72 |
+
return {
|
73 |
+
"effect_estimate": float(effect_estimate),
|
74 |
+
"effect_se": float(effect_se),
|
75 |
+
"confidence_interval": [float(ci_lower), float(ci_upper)],
|
76 |
+
"diagnostics": diagnostics,
|
77 |
+
"method_details": method_details,
|
78 |
+
"parameters": parameters
|
79 |
+
# Add p-value if needed (can be calculated from estimate and SE)
|
80 |
+
}
|
auto_causal/methods/propensity_score/diagnostics.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Balance and sensitivity analysis diagnostics for Propensity Score methods
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from typing import Dict, List, Optional, Any
|
6 |
+
|
7 |
+
# Import necessary plotting libraries if visualizations are needed
|
8 |
+
# import matplotlib.pyplot as plt
|
9 |
+
# import seaborn as sns
|
10 |
+
|
11 |
+
# Import utility for standardized differences if needed
|
12 |
+
from auto_causal.methods.utils import calculate_standardized_differences
|
13 |
+
|
14 |
+
def assess_balance(df_original: pd.DataFrame, df_matched_or_weighted: pd.DataFrame,
|
15 |
+
treatment: str, covariates: List[str],
|
16 |
+
method: str,
|
17 |
+
propensity_scores_original: Optional[np.ndarray] = None,
|
18 |
+
propensity_scores_matched: Optional[np.ndarray] = None,
|
19 |
+
weights: Optional[np.ndarray] = None) -> Dict[str, Any]:
|
20 |
+
'''Assesses covariate balance before and after matching/weighting.
|
21 |
+
|
22 |
+
Placeholder: Returns dummy diagnostic data.
|
23 |
+
'''
|
24 |
+
print(f"Assessing balance for {method}...")
|
25 |
+
# TODO: Implement actual balance checking using standardized differences,
|
26 |
+
# variance ratios, KS tests, etc.
|
27 |
+
# Example using standardized differences (needs calculate_standardized_differences):
|
28 |
+
# std_diff_before = calculate_standardized_differences(df_original, treatment, covariates)
|
29 |
+
# std_diff_after = calculate_standardized_differences(df_matched_or_weighted, treatment, covariates, weights=weights)
|
30 |
+
|
31 |
+
dummy_balance_metric = {cov: np.random.rand() * 0.1 for cov in covariates} # Simulate good balance
|
32 |
+
|
33 |
+
return {
|
34 |
+
"balance_metrics": dummy_balance_metric,
|
35 |
+
"balance_achieved": True, # Placeholder
|
36 |
+
"problematic_covariates": [], # Placeholder
|
37 |
+
# Add plots or paths to plots if generated
|
38 |
+
"plots": {
|
39 |
+
"balance_plot": "balance_plot.png",
|
40 |
+
"overlap_plot": "overlap_plot.png"
|
41 |
+
}
|
42 |
+
}
|
43 |
+
|
44 |
+
def assess_weight_distribution(weights: np.ndarray, treatment_indicator: pd.Series) -> Dict[str, Any]:
|
45 |
+
'''Assesses the distribution of IPW weights.
|
46 |
+
|
47 |
+
Placeholder: Returns dummy diagnostic data.
|
48 |
+
'''
|
49 |
+
print("Assessing weight distribution...")
|
50 |
+
# TODO: Implement checks for extreme weights, effective sample size, etc.
|
51 |
+
return {
|
52 |
+
"min_weight": float(np.min(weights)),
|
53 |
+
"max_weight": float(np.max(weights)),
|
54 |
+
"mean_weight": float(np.mean(weights)),
|
55 |
+
"std_dev_weight": float(np.std(weights)),
|
56 |
+
"effective_sample_size": len(weights) / (1 + np.std(weights)**2 / np.mean(weights)**2), # Kish's ESS approx
|
57 |
+
"potential_issues": np.max(weights) > 20 # Example check
|
58 |
+
}
|
59 |
+
|
60 |
+
def plot_overlap(df: pd.DataFrame, treatment: str, propensity_scores: np.ndarray, save_path: str = 'overlap_plot.png'):
|
61 |
+
'''Generates plot showing propensity score overlap.
|
62 |
+
Placeholder: Does nothing.
|
63 |
+
'''
|
64 |
+
print(f"Generating overlap plot (placeholder) -> {save_path}")
|
65 |
+
# TODO: Implement actual plotting (e.g., using seaborn histplot or kdeplot)
|
66 |
+
pass
|
67 |
+
|
68 |
+
def plot_balance(balance_metrics_before: Dict[str, float], balance_metrics_after: Dict[str, float], save_path: str = 'balance_plot.png'):
|
69 |
+
'''Generates plot showing covariate balance before/after.
|
70 |
+
Placeholder: Does nothing.
|
71 |
+
'''
|
72 |
+
print(f"Generating balance plot (placeholder) -> {save_path}")
|
73 |
+
# TODO: Implement actual plotting (e.g., Love plot)
|
74 |
+
pass
|
auto_causal/methods/propensity_score/llm_assist.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LLM Integration points for Propensity Score methods
|
2 |
+
import pandas as pd
|
3 |
+
from typing import List, Optional, Dict, Any
|
4 |
+
|
5 |
+
def determine_optimal_caliper(df: pd.DataFrame, treatment: str,
|
6 |
+
covariates: List[str],
|
7 |
+
query: Optional[str] = None) -> float:
|
8 |
+
'''Determines optimal caliper for PSM using data or LLM.
|
9 |
+
|
10 |
+
Placeholder: Returns a default value.
|
11 |
+
'''
|
12 |
+
# TODO: Implement data-driven (e.g., based on PS distribution) or LLM-assisted caliper selection.
|
13 |
+
# Common rule of thumb is 0.2 * std dev of logit(PS), but that requires calculating PS first.
|
14 |
+
return 0.2
|
15 |
+
|
16 |
+
def determine_optimal_weight_type(df: pd.DataFrame, treatment: str,
|
17 |
+
query: Optional[str] = None) -> str:
|
18 |
+
'''Determines the optimal type of IPW weights (ATE, ATT, etc.).
|
19 |
+
|
20 |
+
Placeholder: Defaults to ATE.
|
21 |
+
'''
|
22 |
+
# TODO: Implement LLM or rule-based selection.
|
23 |
+
return "ATE"
|
24 |
+
|
25 |
+
def determine_optimal_trim_threshold(df: pd.DataFrame, treatment: str,
|
26 |
+
propensity_scores: Optional[pd.Series] = None,
|
27 |
+
query: Optional[str] = None) -> Optional[float]:
|
28 |
+
'''Determines optimal threshold for trimming extreme propensity scores.
|
29 |
+
|
30 |
+
Placeholder: Defaults to no trimming (None).
|
31 |
+
'''
|
32 |
+
# TODO: Implement data-driven or LLM-assisted threshold selection (e.g., based on score distribution).
|
33 |
+
return None # Corresponds to no trimming by default
|
34 |
+
|
35 |
+
# Placeholder for calling LLM to get parameters (can use the one in utils if general enough)
|
36 |
+
def get_llm_parameters(df: pd.DataFrame, query: str, method: str) -> Dict[str, Any]:
|
37 |
+
'''Placeholder to get parameters via LLM based on dataset and query.'''
|
38 |
+
# In reality, call something like analyze_dataset_for_method from utils.llm_helpers
|
39 |
+
print(f"Simulating LLM call to get parameters for {method}...")
|
40 |
+
if method == "PS.Matching":
|
41 |
+
return {"parameters": {"caliper": 0.15}, "validation": {"check_balance": True}}
|
42 |
+
elif method == "PS.Weighting":
|
43 |
+
return {"parameters": {"weight_type": "ATE", "trim_threshold": 0.05}, "validation": {"check_weights": True}}
|
44 |
+
else:
|
45 |
+
return {"parameters": {}, "validation": {}}
|
auto_causal/methods/propensity_score/matching.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Propensity Score Matching Implementation
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.neighbors import NearestNeighbors
|
5 |
+
import statsmodels.api as sm # For bias adjustment regression
|
6 |
+
import logging # For logging fallback
|
7 |
+
from typing import Dict, List, Optional, Any
|
8 |
+
|
9 |
+
# Import DoWhy
|
10 |
+
from dowhy import CausalModel
|
11 |
+
|
12 |
+
from .base import estimate_propensity_scores, format_ps_results, select_propensity_model
|
13 |
+
from .diagnostics import assess_balance #, plot_overlap, plot_balance # Import diagnostic functions
|
14 |
+
# Remove determine_optimal_caliper, it will be replaced by a heuristic
|
15 |
+
from .llm_assist import get_llm_parameters # Import LLM helpers
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
def _calculate_logit(pscore):
|
20 |
+
"""Calculate logit of propensity score, clipping to avoid inf."""
|
21 |
+
# Clip pscore to prevent log(0) or log(1) issues which lead to inf
|
22 |
+
epsilon = 1e-7
|
23 |
+
pscore_clipped = np.clip(pscore, epsilon, 1 - epsilon)
|
24 |
+
return np.log(pscore_clipped / (1 - pscore_clipped))
|
25 |
+
|
26 |
+
def _perform_matching_and_get_att(
|
27 |
+
df_sample: pd.DataFrame,
|
28 |
+
treatment: str,
|
29 |
+
outcome: str,
|
30 |
+
covariates: List[str],
|
31 |
+
propensity_model_type: str,
|
32 |
+
n_neighbors: int,
|
33 |
+
caliper: float,
|
34 |
+
perform_bias_adjustment: bool,
|
35 |
+
**kwargs
|
36 |
+
) -> float:
|
37 |
+
"""
|
38 |
+
Helper to perform Custom KNN PSM and calculate ATT, potentially with bias adjustment.
|
39 |
+
Returns the ATT estimate.
|
40 |
+
"""
|
41 |
+
df_ps = df_sample.copy()
|
42 |
+
try:
|
43 |
+
propensity_scores = estimate_propensity_scores(
|
44 |
+
df_ps, treatment, covariates, model_type=propensity_model_type, **kwargs
|
45 |
+
)
|
46 |
+
except Exception as e:
|
47 |
+
logger.warning(f"Propensity score estimation failed in helper: {e}")
|
48 |
+
return np.nan # Cannot proceed without propensity scores
|
49 |
+
|
50 |
+
df_ps['propensity_score'] = propensity_scores
|
51 |
+
|
52 |
+
treated = df_ps[df_ps[treatment] == 1]
|
53 |
+
control = df_ps[df_ps[treatment] == 0]
|
54 |
+
|
55 |
+
if treated.empty or control.empty:
|
56 |
+
return np.nan
|
57 |
+
|
58 |
+
nn = NearestNeighbors(n_neighbors=n_neighbors, radius=caliper if caliper is not None else np.inf, metric='minkowski', p=2)
|
59 |
+
try:
|
60 |
+
# Ensure control PS are valid before fitting
|
61 |
+
control_ps_values = control[['propensity_score']].values
|
62 |
+
if np.isnan(control_ps_values).any():
|
63 |
+
logger.warning("NaN values found in control propensity scores before NN fitting.")
|
64 |
+
return np.nan
|
65 |
+
nn.fit(control_ps_values)
|
66 |
+
|
67 |
+
# Ensure treated PS are valid before querying
|
68 |
+
treated_ps_values = treated[['propensity_score']].values
|
69 |
+
if np.isnan(treated_ps_values).any():
|
70 |
+
logger.warning("NaN values found in treated propensity scores before NN query.")
|
71 |
+
return np.nan
|
72 |
+
distances, indices = nn.kneighbors(treated_ps_values)
|
73 |
+
|
74 |
+
except ValueError as e:
|
75 |
+
# Handles case where control group might be too small or have NaN PS scores
|
76 |
+
logger.warning(f"NearestNeighbors fitting/query failed: {e}")
|
77 |
+
return np.nan
|
78 |
+
|
79 |
+
matched_outcomes_treated = []
|
80 |
+
matched_outcomes_control_means = []
|
81 |
+
propensity_diffs = []
|
82 |
+
|
83 |
+
for i in range(len(treated)):
|
84 |
+
treated_unit = treated.iloc[[i]]
|
85 |
+
valid_neighbors_mask = distances[i] <= (caliper if caliper is not None else np.inf)
|
86 |
+
valid_neighbors_idx = indices[i][valid_neighbors_mask]
|
87 |
+
|
88 |
+
if len(valid_neighbors_idx) > 0:
|
89 |
+
matched_controls_for_this_treated = control.iloc[valid_neighbors_idx]
|
90 |
+
if matched_controls_for_this_treated.empty:
|
91 |
+
continue # Should not happen with valid_neighbors_idx check, but safety
|
92 |
+
|
93 |
+
matched_outcomes_treated.append(treated_unit[outcome].values[0])
|
94 |
+
matched_outcomes_control_means.append(matched_controls_for_this_treated[outcome].mean())
|
95 |
+
|
96 |
+
if perform_bias_adjustment:
|
97 |
+
# Ensure PS scores are valid before calculating difference
|
98 |
+
treated_ps = treated_unit['propensity_score'].values[0]
|
99 |
+
control_ps_mean = matched_controls_for_this_treated['propensity_score'].mean()
|
100 |
+
if np.isnan(treated_ps) or np.isnan(control_ps_mean):
|
101 |
+
logger.warning("NaN propensity score encountered during bias adjustment calculation.")
|
102 |
+
# Cannot perform bias adjustment for this unit, potentially skip or handle
|
103 |
+
# For now, let's skip adding to propensity_diffs if NaN found
|
104 |
+
continue
|
105 |
+
propensity_diff = treated_ps - control_ps_mean
|
106 |
+
propensity_diffs.append(propensity_diff)
|
107 |
+
|
108 |
+
if not matched_outcomes_treated:
|
109 |
+
return np.nan
|
110 |
+
|
111 |
+
raw_att_components = np.array(matched_outcomes_treated) - np.array(matched_outcomes_control_means)
|
112 |
+
|
113 |
+
if perform_bias_adjustment:
|
114 |
+
# Ensure lengths match *after* potential skips due to NaNs
|
115 |
+
if not propensity_diffs or len(raw_att_components) != len(propensity_diffs):
|
116 |
+
logger.warning("Bias adjustment skipped due to inconsistent data lengths after NaN checks.")
|
117 |
+
return np.mean(raw_att_components)
|
118 |
+
|
119 |
+
try:
|
120 |
+
X_bias_adj = sm.add_constant(np.array(propensity_diffs))
|
121 |
+
y_bias_adj = raw_att_components
|
122 |
+
# Add check for NaNs/Infs in inputs to OLS
|
123 |
+
if np.isnan(X_bias_adj).any() or np.isnan(y_bias_adj).any() or \
|
124 |
+
np.isinf(X_bias_adj).any() or np.isinf(y_bias_adj).any():
|
125 |
+
logger.warning("NaN/Inf values detected in OLS inputs for bias adjustment. Falling back.")
|
126 |
+
return np.mean(raw_att_components)
|
127 |
+
|
128 |
+
bias_model = sm.OLS(y_bias_adj, X_bias_adj).fit()
|
129 |
+
bias_adjusted_att = bias_model.params[0]
|
130 |
+
return bias_adjusted_att
|
131 |
+
except Exception as e:
|
132 |
+
logger.warning(f"OLS for bias adjustment failed: {e}. Falling back to raw ATT.")
|
133 |
+
return np.mean(raw_att_components)
|
134 |
+
else:
|
135 |
+
return np.mean(raw_att_components)
|
136 |
+
|
137 |
+
def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
|
138 |
+
covariates: List[str], **kwargs) -> Dict[str, Any]:
|
139 |
+
'''Estimate ATT using Propensity Score Matching.
|
140 |
+
Tries DoWhy's PSM first, falls back to custom implementation if DoWhy fails.
|
141 |
+
Uses bootstrap SE based on the custom implementation regardless.
|
142 |
+
'''
|
143 |
+
query = kwargs.get('query')
|
144 |
+
n_bootstraps = kwargs.get('n_bootstraps', 100)
|
145 |
+
|
146 |
+
# --- Parameter Setup (as before) ---
|
147 |
+
llm_params = get_llm_parameters(df, query, "PS.Matching")
|
148 |
+
llm_suggested_params = llm_params.get("parameters", {})
|
149 |
+
|
150 |
+
caliper = kwargs.get('caliper', llm_suggested_params.get('caliper'))
|
151 |
+
temp_propensity_scores_for_caliper = None
|
152 |
+
try:
|
153 |
+
temp_propensity_scores_for_caliper = estimate_propensity_scores(
|
154 |
+
df, treatment, covariates,
|
155 |
+
model_type=llm_suggested_params.get('propensity_model_type', 'logistic'),
|
156 |
+
**kwargs
|
157 |
+
)
|
158 |
+
if caliper is None and temp_propensity_scores_for_caliper is not None:
|
159 |
+
logit_ps = _calculate_logit(temp_propensity_scores_for_caliper)
|
160 |
+
if not np.isnan(logit_ps).all(): # Check if logit calculation was successful
|
161 |
+
caliper = 0.2 * np.nanstd(logit_ps) # Use nanstd for robustness
|
162 |
+
else:
|
163 |
+
logger.warning("Logit of propensity scores resulted in NaNs, cannot calculate heuristic caliper.")
|
164 |
+
caliper = None
|
165 |
+
elif caliper is None:
|
166 |
+
logger.warning("Could not estimate propensity scores for caliper heuristic.")
|
167 |
+
caliper = None
|
168 |
+
|
169 |
+
except Exception as e:
|
170 |
+
logger.warning(f"Failed to estimate initial propensity scores for caliper heuristic: {e}. Caliper set to None.")
|
171 |
+
caliper = None # Proceed without caliper if heuristic fails
|
172 |
+
|
173 |
+
n_neighbors = kwargs.get('n_neighbors', llm_suggested_params.get('n_neighbors', 1))
|
174 |
+
propensity_model_type = kwargs.get('propensity_model_type',
|
175 |
+
llm_suggested_params.get('propensity_model_type',
|
176 |
+
select_propensity_model(df, treatment, covariates, query)))
|
177 |
+
|
178 |
+
# --- Attempt DoWhy PSM for Point Estimate ---
|
179 |
+
att_estimate = np.nan
|
180 |
+
method_used_for_att = "Fallback Custom PSM"
|
181 |
+
dowhy_model = None
|
182 |
+
identified_estimand = None
|
183 |
+
|
184 |
+
try:
|
185 |
+
logger.info("Attempting estimation using DoWhy Propensity Score Matching...")
|
186 |
+
dowhy_model = CausalModel(
|
187 |
+
data=df,
|
188 |
+
treatment=treatment,
|
189 |
+
outcome=outcome,
|
190 |
+
common_causes=covariates,
|
191 |
+
estimand_type='nonparametric-ate' # Provide list of names directly
|
192 |
+
)
|
193 |
+
# Identify estimand (optional step, but good practice)
|
194 |
+
identified_estimand = dowhy_model.identify_effect(proceed_when_unidentifiable=True)
|
195 |
+
logger.info(f"DoWhy identified estimand: {identified_estimand}")
|
196 |
+
|
197 |
+
# Estimate effect using DoWhy's PSM
|
198 |
+
estimate = dowhy_model.estimate_effect(
|
199 |
+
identified_estimand,
|
200 |
+
method_name="backdoor.propensity_score_matching",
|
201 |
+
target_units="att",
|
202 |
+
method_params={}
|
203 |
+
)
|
204 |
+
att_estimate = estimate.value
|
205 |
+
method_used_for_att = "DoWhy PSM"
|
206 |
+
logger.info(f"DoWhy PSM successful. ATT Estimate: {att_estimate}")
|
207 |
+
|
208 |
+
except Exception as e:
|
209 |
+
logger.warning(f"DoWhy PSM failed: {e}. Falling back to custom PSM implementation.")
|
210 |
+
# Fallback is triggered implicitly if att_estimate remains NaN
|
211 |
+
|
212 |
+
# --- Fallback or if DoWhy failed ---
|
213 |
+
if np.isnan(att_estimate):
|
214 |
+
logger.info("Calculating ATT estimate using fallback custom PSM...")
|
215 |
+
att_estimate = _perform_matching_and_get_att(
|
216 |
+
df, treatment, outcome, covariates,
|
217 |
+
propensity_model_type, n_neighbors, caliper,
|
218 |
+
perform_bias_adjustment=True, **kwargs # Bias adjust the fallback
|
219 |
+
)
|
220 |
+
method_used_for_att = "Fallback Custom PSM" # Confirm it's fallback
|
221 |
+
if np.isnan(att_estimate):
|
222 |
+
raise ValueError("Fallback custom PSM estimation also failed. Cannot proceed.")
|
223 |
+
logger.info(f"Fallback Custom PSM successful. ATT Estimate: {att_estimate}")
|
224 |
+
|
225 |
+
# --- Bootstrap SE (using custom helper for consistency) ---
|
226 |
+
logger.info(f"Calculating Bootstrap SE using custom helper ({n_bootstraps} iterations)...")
|
227 |
+
bootstrap_atts = []
|
228 |
+
for i in range(n_bootstraps):
|
229 |
+
try:
|
230 |
+
# Ensure bootstrap samples are drawn correctly
|
231 |
+
df_boot = df.sample(n=len(df), replace=True, random_state=np.random.randint(1000000) + i)
|
232 |
+
# Bias adjustment in bootstrap can be slow, optionally disable it
|
233 |
+
boot_att = _perform_matching_and_get_att(
|
234 |
+
df_boot, treatment, outcome, covariates,
|
235 |
+
propensity_model_type, n_neighbors, caliper,
|
236 |
+
perform_bias_adjustment=False, **kwargs # Set bias adjustment to False for speed in bootstrap
|
237 |
+
)
|
238 |
+
if not np.isnan(boot_att):
|
239 |
+
bootstrap_atts.append(boot_att)
|
240 |
+
except Exception as boot_e:
|
241 |
+
logger.warning(f"Bootstrap iteration {i+1} failed: {boot_e}")
|
242 |
+
continue # Skip failed bootstrap iteration
|
243 |
+
|
244 |
+
att_se = np.nanstd(bootstrap_atts) if bootstrap_atts else np.nan # Use nanstd
|
245 |
+
actual_bootstrap_iterations = len(bootstrap_atts)
|
246 |
+
logger.info(f"Bootstrap SE calculated: {att_se} from {actual_bootstrap_iterations} successful iterations.")
|
247 |
+
|
248 |
+
# --- Diagnostics (using custom matching logic for consistency) ---
|
249 |
+
logger.info("Performing diagnostic checks using custom matching logic...")
|
250 |
+
diagnostics = {"error": "Diagnostics failed to run."}
|
251 |
+
propensity_scores_orig = temp_propensity_scores_for_caliper # Reuse if available and not None
|
252 |
+
|
253 |
+
if propensity_scores_orig is None:
|
254 |
+
try:
|
255 |
+
propensity_scores_orig = estimate_propensity_scores(
|
256 |
+
df, treatment, covariates, model_type=propensity_model_type, **kwargs
|
257 |
+
)
|
258 |
+
except Exception as e:
|
259 |
+
logger.error(f"Failed to estimate propensity scores for diagnostics: {e}")
|
260 |
+
propensity_scores_orig = None
|
261 |
+
|
262 |
+
if propensity_scores_orig is not None and not np.isnan(propensity_scores_orig).all():
|
263 |
+
df_ps_orig = df.copy()
|
264 |
+
df_ps_orig['propensity_score'] = propensity_scores_orig
|
265 |
+
treated_orig = df_ps_orig[df_ps_orig[treatment] == 1]
|
266 |
+
control_orig = df_ps_orig[df_ps_orig[treatment] == 0]
|
267 |
+
unmatched_treated_count = 0
|
268 |
+
|
269 |
+
# Drop rows with NaN propensity scores before diagnostics
|
270 |
+
treated_orig = treated_orig.dropna(subset=['propensity_score'])
|
271 |
+
control_orig = control_orig.dropna(subset=['propensity_score'])
|
272 |
+
|
273 |
+
if not treated_orig.empty and not control_orig.empty:
|
274 |
+
try:
|
275 |
+
nn_diag = NearestNeighbors(n_neighbors=n_neighbors, radius=caliper if caliper is not None else np.inf, metric='minkowski', p=2)
|
276 |
+
nn_diag.fit(control_orig[['propensity_score']].values)
|
277 |
+
distances_diag, indices_diag = nn_diag.kneighbors(treated_orig[['propensity_score']].values)
|
278 |
+
|
279 |
+
matched_treated_indices_diag = []
|
280 |
+
matched_control_indices_diag = []
|
281 |
+
|
282 |
+
for i in range(len(treated_orig)):
|
283 |
+
valid_neighbors_mask_diag = distances_diag[i] <= (caliper if caliper is not None else np.inf)
|
284 |
+
valid_neighbors_idx_diag = indices_diag[i][valid_neighbors_mask_diag]
|
285 |
+
if len(valid_neighbors_idx_diag) > 0:
|
286 |
+
# Get original DataFrame indices from control_orig based on iloc indices
|
287 |
+
selected_control_original_indices = control_orig.index[valid_neighbors_idx_diag]
|
288 |
+
matched_treated_indices_diag.extend([treated_orig.index[i]] * len(selected_control_original_indices))
|
289 |
+
matched_control_indices_diag.extend(selected_control_original_indices)
|
290 |
+
else:
|
291 |
+
unmatched_treated_count += 1
|
292 |
+
|
293 |
+
if matched_control_indices_diag:
|
294 |
+
# Use unique indices for creating the diagnostic dataframe
|
295 |
+
unique_matched_control_indices = list(set(matched_control_indices_diag))
|
296 |
+
unique_matched_treated_indices = list(set(matched_treated_indices_diag))
|
297 |
+
|
298 |
+
matched_control_df_diag = df.loc[unique_matched_control_indices]
|
299 |
+
matched_treated_df_for_diag = df.loc[unique_matched_treated_indices]
|
300 |
+
matched_df_diag = pd.concat([matched_treated_df_for_diag, matched_control_df_diag]).drop_duplicates()
|
301 |
+
|
302 |
+
# Retrieve propensity scores for the specific units in matched_df_diag
|
303 |
+
ps_matched_for_diag = propensity_scores_orig.loc[matched_df_diag.index]
|
304 |
+
|
305 |
+
diagnostics = assess_balance(df, matched_df_diag, treatment, covariates,
|
306 |
+
method="PSM",
|
307 |
+
propensity_scores_original=propensity_scores_orig,
|
308 |
+
propensity_scores_matched=ps_matched_for_diag)
|
309 |
+
else:
|
310 |
+
diagnostics = {"message": "No units could be matched for diagnostic assessment."}
|
311 |
+
# If no controls were matched, all treated were unmatched
|
312 |
+
unmatched_treated_count = len(treated_orig) if not treated_orig.empty else 0
|
313 |
+
except Exception as diag_e:
|
314 |
+
logger.error(f"Error during diagnostic matching/balance assessment: {diag_e}")
|
315 |
+
diagnostics = {"error": f"Diagnostics failed: {diag_e}"}
|
316 |
+
else:
|
317 |
+
diagnostics = {"message": "Treatment or control group empty after dropping NaN PS, diagnostics skipped."}
|
318 |
+
unmatched_treated_count = len(treated_orig) if not treated_orig.empty else 0
|
319 |
+
|
320 |
+
# Ensure unmatched count calculation is safe
|
321 |
+
if 'unmatched_treated_count' not in locals():
|
322 |
+
unmatched_treated_count = 0 # Initialize if loop didn't run
|
323 |
+
diagnostics["unmatched_treated_count"] = unmatched_treated_count
|
324 |
+
diagnostics["percent_treated_matched"] = (len(treated_orig) - unmatched_treated_count) / len(treated_orig) * 100 if len(treated_orig) > 0 else 0
|
325 |
+
else:
|
326 |
+
diagnostics = {"error": "Propensity scores could not be estimated for diagnostics."}
|
327 |
+
|
328 |
+
# Add final details to diagnostics
|
329 |
+
diagnostics["att_estimation_method"] = method_used_for_att
|
330 |
+
diagnostics["propensity_score_model"] = propensity_model_type
|
331 |
+
diagnostics["bootstrap_iterations_for_se"] = actual_bootstrap_iterations
|
332 |
+
diagnostics["final_caliper_used"] = caliper
|
333 |
+
|
334 |
+
# --- Format and return results ---
|
335 |
+
logger.info(f"Formatting results. ATT Estimate: {att_estimate}, SE: {att_se}, Method: {method_used_for_att}")
|
336 |
+
return format_ps_results(att_estimate, att_se, diagnostics,
|
337 |
+
method_details=f"PSM ({method_used_for_att})",
|
338 |
+
parameters={"caliper": caliper,
|
339 |
+
"n_neighbors": n_neighbors, # n_neighbors used in fallback/bootstrap/diag
|
340 |
+
"propensity_model": propensity_model_type,
|
341 |
+
"n_bootstraps_config": n_bootstraps})
|