FireShadow committed
Commit 1721aea · 0 parents

Initial clean commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
Files changed (50)
  1. .env.example +0 -0
  2. .gitattributes +37 -0
  3. .gitignore +217 -0
  4. README copy.md +198 -0
  5. README.md +13 -0
  6. app.py +339 -0
  7. auto_causal/__init__.py +50 -0
  8. auto_causal/agent.py +394 -0
  9. auto_causal/components/__init__.py +28 -0
  10. auto_causal/components/dataset_analyzer.py +853 -0
  11. auto_causal/components/decision_tree.py +366 -0
  12. auto_causal/components/decision_tree_llm.py +218 -0
  13. auto_causal/components/explanation_generator.py +404 -0
  14. auto_causal/components/input_parser.py +456 -0
  15. auto_causal/components/method_validator.py +327 -0
  16. auto_causal/components/output_formatter.py +138 -0
  17. auto_causal/components/query_interpreter.py +580 -0
  18. auto_causal/components/state_manager.py +40 -0
  19. auto_causal/config.py +97 -0
  20. auto_causal/methods/__init__.py +44 -0
  21. auto_causal/methods/backdoor_adjustment/__init__.py +0 -0
  22. auto_causal/methods/backdoor_adjustment/diagnostics.py +92 -0
  23. auto_causal/methods/backdoor_adjustment/estimator.py +105 -0
  24. auto_causal/methods/backdoor_adjustment/llm_assist.py +176 -0
  25. auto_causal/methods/causal_method.py +88 -0
  26. auto_causal/methods/diff_in_means/__init__.py +0 -0
  27. auto_causal/methods/diff_in_means/diagnostics.py +60 -0
  28. auto_causal/methods/diff_in_means/estimator.py +107 -0
  29. auto_causal/methods/diff_in_means/llm_assist.py +95 -0
  30. auto_causal/methods/difference_in_differences/diagnostics.py +345 -0
  31. auto_causal/methods/difference_in_differences/estimator.py +463 -0
  32. auto_causal/methods/difference_in_differences/llm_assist.py +362 -0
  33. auto_causal/methods/difference_in_differences/utils.py +65 -0
  34. auto_causal/methods/generalized_propensity_score/__init__.py +3 -0
  35. auto_causal/methods/generalized_propensity_score/diagnostics.py +196 -0
  36. auto_causal/methods/generalized_propensity_score/estimator.py +386 -0
  37. auto_causal/methods/generalized_propensity_score/llm_assist.py +208 -0
  38. auto_causal/methods/instrumental_variable/__init__.py +1 -0
  39. auto_causal/methods/instrumental_variable/diagnostics.py +218 -0
  40. auto_causal/methods/instrumental_variable/estimator.py +370 -0
  41. auto_causal/methods/instrumental_variable/llm_assist.py +240 -0
  42. auto_causal/methods/linear_regression/__init__.py +0 -0
  43. auto_causal/methods/linear_regression/diagnostics.py +76 -0
  44. auto_causal/methods/linear_regression/estimator.py +355 -0
  45. auto_causal/methods/linear_regression/llm_assist.py +146 -0
  46. auto_causal/methods/propensity_score/__init__.py +13 -0
  47. auto_causal/methods/propensity_score/base.py +80 -0
  48. auto_causal/methods/propensity_score/diagnostics.py +74 -0
  49. auto_causal/methods/propensity_score/llm_assist.py +45 -0
  50. auto_causal/methods/propensity_score/matching.py +341 -0
.env.example ADDED
File without changes
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,217 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be added to the global gitignore or merged into this project gitignore. For a PyCharm
+ # project, it is recommended to include directory-based project settings:
+ .idea/
+
+ # VS Code
+ .vscode/
+
+ # Data files
+ *.csv
+ *.xlsx
+ *.xls
+ *.json
+ *.parquet
+ *.pickle
+ *.pkl
+ *.h5
+ *.hdf5
+
+ # Model files
+ *.model
+ *.joblib
+ *.sav
+
+ # Output directories
+ outputs/
+ results/
+ logs/
+ checkpoints/
+ wandb/
+
+ # Temporary files
+ *.tmp
+ *.temp
+ *~
+
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # LLM API keys and secrets
+ .env.local
+ .env.production
+ secrets.json
+ api_keys.txt
+
+ # Experiment tracking
+ mlruns/
+ .mlflow/
+
+ # Large files (adjust sizes as needed)
+ *.zip
+ *.tar.gz
+ *.rar
+
+ # Project specific
+ tests/output/
README copy.md ADDED
@@ -0,0 +1,198 @@
+ <h1 align="center">
+ <img src="blob/main/asset/cais.png" width="400" alt="CAIS" />
+ <br>
+ Causal AI Scientist: Facilitating Causal Data Science with
+ Large Language Models
+ </h1>
+ <!-- <p align="center">
+ <a href="https://causalcopilot.com/"><b>[Demo]</b></a> •
+ <a href="https://github.com/Lancelot39/Causal-Copilot"><b>[Code]</b></a> •
+ <a href=""><b>[arXiv (coming soon)]</b></a>
+ </p> -->
+
+ **Causal AI Scientist (CAIS)** is an LLM-powered tool for generating data-driven answers to natural language causal queries. It takes as input a natural language query (for example, "Does participating in a job training program lead to higher income?"), an accompanying dataset, and a description of that dataset. CAIS then frames a suitable causal estimation problem by selecting appropriate treatment and outcome variables, chooses a suitable method for causal effect estimation, implements it, runs diagnostic tests, and finally interprets the numerical results in the context of the original query.
+
+ This repo includes instructions both for using the tool to perform causal analysis on a dataset of interest and for reproducing the results from our paper.
+
+ **Note**: This repository is a work in progress and will be updated with additional instructions and files.
+
+ <!-- ## 1. Introduction
+
+ Causal effect estimation is central to evidence-based decision-making across domains like social sciences, healthcare, and economics. However, it requires specialized expertise to select the right inference method, identify valid variables, and validate results.
+
+ **CAIS (Causal AI Scientist)** automates this process using Large Language Models (LLMs) to:
+ - Parse a natural language causal query.
+ - Analyze the dataset characteristics.
+ - Select the appropriate causal inference method via a decision tree and prompting strategies.
+ - Execute the method using pre-defined code templates.
+ - Validate and interpret the results.
+
+ <div style="text-align: center;">
+ <img src="blob/main/asset/CAIS-arch.png" width="990" alt="CAIS" />
+ </div>
+
+ **Key Features:**
+ - End-to-end causal estimation with minimal user input.
+ - Supports a wide range of methods:
+   - **Econometric:** Difference-in-Differences (DiD), Instrumental Variables (IV), Ordinary Least Squares (OLS), Regression Discontinuity Design (RDD).
+   - **Causal Graph-based:** Backdoor adjustment, Frontdoor adjustment.
+ - Combines structured reasoning (decision tree) with LLM-powered interpretation.
+ - Works on clean textbook datasets, messy real-world datasets, and synthetic scenarios.
+
+
+ CAIS consists of three main stages, powered by a **decision-tree-driven reasoning pipeline**:
+
+ ### **Stage 1: Variable and Method Selection**
+ 1. **Dataset & Query Analysis**
+    - The LLM inspects the dataset description, variable names, and statistical summaries.
+    - Identifies treatment, outcome, and covariates.
+ 2. **Property Detection**
+    - Uses targeted prompts to detect dataset properties:
+      - Randomized vs. observational
+      - Presence of temporal/running variables
+      - Availability of valid instruments
+ 3. **Decision Tree Traversal**
+    - Traverses a predefined causal inference decision tree (Fig. B in the paper).
+    - Maps detected properties to the most appropriate estimation method.
+
+ ---
+
+ ### **Stage 2: Causal Inference Execution**
+ 1. **Template-based Code Generation**
+    - Predefined Python templates for each method (e.g., DiD, IV, OLS).
+    - Variables from Stage 1 are substituted into the templates.
+ 2. **Diagnostics & Validation**
+    - Runs statistical tests and checks assumptions where applicable.
+    - Handles basic data preprocessing (e.g., type conversion for DoWhy).
+
+ ---
+
+ ### **Stage 3: Result Interpretation**
+ - The LLM interprets numerical results and diagnostics in the context of the user’s causal query.
+ - Outputs:
+   - Estimated causal effect (ATE, ATT, or LATE).
+   - Standard errors, confidence intervals.
+   - Plain-language explanation.
+
+ ---
+ ## 3. Evaluation
+
+ We evaluate **CAIS** across three diverse dataset collections:
+ 1. **QRData (Textbook Examples)** – curated, clean datasets with known causal effects.
+ 2. **Real-World Studies** – empirical datasets from research papers (economics, health, political science).
+ 3. **Synthetic Data** – generated with controlled causal structures to ensure balanced method coverage.
+
+ ### **Metrics**
+ We assess CAIS on:
+ - **Method Selection Accuracy (MSA)** – the percentage of cases in which CAIS selects the correct inference method, as given by the reference.
+ - **Mean Relative Error (MRE)** – the average relative error between CAIS’s estimated causal effect and the reference value.
+
+
+ <p align="center">
+ <table>
+ <tr>
+ <td align="center">
+ <img src="blob/main/asset/CAIS-MRE.png" width="450" alt="CAIS MRE"/>
+ </td>
+ <td align="center">
+ <img src="blob/main/asset/CAIS-msa.png" width="450" alt="CAIS MSA"/>
+ </td>
+ </tr>
+ </table>
+ </p>
+ -->
+
+ ## Getting Started
+
+ #### 🔧 Environment Installation
+
+
+ **Prerequisites:**
+ - **Python 3.10** (create a new conda environment first)
+ - Required Python libraries (specified in `requirements.txt`)
+
+
+ **Step 1: Copy the example configuration**
+ ```bash
+ cp .env.example .env
+ ```
+
+ **Step 2: Create a Python 3.10 environment**
+ ```bash
+ # Create a new conda environment with Python 3.10
+ conda create -n auto_causal python=3.10
+ conda activate auto_causal
+ pip install -r requirements.txt
+ ```
+
+ **Step 3: Set up the auto_causal library**
+ ```bash
+ pip install -e .
+ ```
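+
+ As a quick sanity check that the editable install worked (a minimal sketch, assuming the dependencies from `requirements.txt` are installed so the package imports cleanly):
+ ```python
+ # Verify that the auto_causal package is importable after `pip install -e .`
+ import auto_causal
+
+ print(auto_causal.__version__)  # the package declares version "0.1.0"
+ ```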
+
+ ## Dataset Information
+
+ All datasets used to evaluate CAIS and the baseline models are available in the `data/` directory. Specifically:
+
+ * `all_data`: Folder containing all CSV files from the QRData and real-world study collections.
+ * `synthetic_data`: Folder containing all CSV files corresponding to the synthetic datasets.
+ * `qr_info.csv`: Metadata for the QRData files. For each file, this includes the filename, description, causal query, reference causal effect, intended inference method, and additional remarks.
+ * `real_info.csv`: Metadata for the real-world datasets.
+ * `synthetic_info.csv`: Metadata for the synthetic datasets.
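+
+ To preview any of these metadata files before launching an analysis, they can be loaded with pandas (a minimal sketch; the field contents are described above, but confirm the exact column names against the CSV header):
+ ```python
+ import pandas as pd
+
+ # Inspect the QRData metadata described above
+ qr_info = pd.read_csv("data/qr_info.csv")
+ print(qr_info.columns.tolist())
+ print(qr_info.head())
+ ```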
+
+ ## Run
+ To execute CAIS, run:
+ ```bash
+ python main/run_cais.py \
+     --metadata_path {path_to_metadata} \
+     --data_dir {path_to_data_folder} \
+     --output_dir {output_folder} \
+     --output_name {output_filename} \
+     --llm_name {llm_name}
+ ```
+ Args:
+
+ * `metadata_path` (str): Path to the CSV file containing the queries, dataset descriptions, and data file names
+ * `data_dir` (str): Path to the folder containing the data in CSV format
+ * `output_dir` (str): Path to the folder where the output JSON results will be saved
+ * `output_name` (str): Name of the JSON file where the outputs will be saved
+ * `llm_name` (str): Name of the LLM to be used (e.g., 'gpt-4', 'claude-3', etc.)
+
+ A specific example:
+ ```bash
+ python main/run_cais.py \
+     --metadata_path "data/qr_info.csv" \
+     --data_dir "data/all_data" \
+     --output_dir "output" \
+     --output_name "results_qr_4o" \
+     --llm_name "gpt-4o-mini"
+ ```
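+
+ CAIS can also be invoked programmatically through the `run_causal_analysis` entry point, which is how the Gradio demo in `app.py` calls it. A minimal sketch (the dataset path here is hypothetical, and the exact shape of the returned dictionary should be checked against your own run):
+ ```python
+ from auto_causal.agent import run_causal_analysis
+
+ result = run_causal_analysis(
+     query="Does participating in a job training program lead to higher income?",
+     dataset_path="data/all_data/your_dataset.csv",  # hypothetical file name
+     dataset_description="Observational data on program participation and income.",
+ )
+
+ # app.py reads nested fields such as result["results"]["results"]["method_used"]
+ # and the effect estimates / confidence intervals from the same structure.
+ print(result)
+ ```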
+
+
+ ## Reproducing paper results
+ **Will be updated soon**
+
+ **⚠️ Important Notes:**
+ - Keep your `.env` file secure and never commit it to version control.
+
+ ## License
+
+ Distributed under the MIT License. See `LICENSE` for more information.
+
+
+
+ <!--## Contributors
+
+
+
+ **Core Contributors**: Vishal Verma, Sawal Acharya, Devansh Bhardwaj
+
+ **Other Contributors**: Zhijing Jin, Ana Hagihat, Samuel Simko
+
+ ---
+
+ ## Contact
+
+ For additional information, questions, or feedback, please contact **[Vishal Verma]([email protected])**, **[Sawal Acharya]([email protected])**, or **[Devansh Bhardwaj]([email protected])**. We welcome contributions! Come and join us!
+ -->
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Causal AI Scientist
+ emoji: 🌍
+ colorFrom: green
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 5.41.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,339 @@
+ import os
+ import sys
+ import json
+ from pathlib import Path
+ import gradio as gr
+ import time
+
+ # Make your repo importable (expecting a folder named causal-agent at repo root)
+ sys.path.append(str(Path(__file__).parent / "causal-agent"))
+
+ from auto_causal.agent import run_causal_analysis  # uses env for provider/model
+
+ # -------- LLM config (OpenAI only; key via HF Secrets) --------
+ os.environ.setdefault("LLM_PROVIDER", "openai")
+ os.environ.setdefault("LLM_MODEL", "gpt-4o")
+
+ # Lazy import to avoid import-time errors if key missing
+ def _get_openai_client():
+     if os.getenv("LLM_PROVIDER", "openai") != "openai":
+         raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
+     if not os.getenv("OPENAI_API_KEY"):
+         raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
+     try:
+         # OpenAI SDK v1+
+         from openai import OpenAI
+         return OpenAI()
+     except Exception as e:
+         raise RuntimeError(f"OpenAI SDK not available: {e}")
+
+ # -------- System prompt you asked for (verbatim) --------
+ SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
+ You will be given:
+ 1) The original research question.
+ 2) The analysis method used.
+ 3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
+ 4) A brief dataset description.
+
+ Your task is to produce a clear, concise, and non-technical summary that:
+ - Directly answers the research question.
+ - States whether the effect is statistically significant.
+ - Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
+ - Mentions the method used in one sentence.
+ - Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
+
+ Formatting rules:
+ - Use bullet points or short paragraphs.
+ - Report effect sizes to two decimal places.
+ - Clearly state the interpretation in plain English without technical jargon.
+
+ Example Output Structure:
+ - **Method:** [Name of method + 1-line rationale]
+ - **Key Finding:** [Main answer to the research question]
+ - **Details:**
+ - [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 — [Significance comment]
+ - …
+ - **Rank Order of Effects:** [Largest → Smallest]
+ """
+
+ def _extract_minimal_payload(agent_result: dict) -> dict:
+     """
+     Extract the minimal, LLM-friendly payload from run_causal_analysis output.
+     Falls back gracefully if any fields are missing.
+     """
+     # Try both top-level and nested (your JSON showed both patterns)
+     res = agent_result or {}
+     results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
+     inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
+     vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {}
+     dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}
+
+     # Pull best-available fields
+     question = (
+         results.get("original_query")
+         or dataset_analysis.get("original_query")
+         or res.get("query")
+         or "N/A"
+     )
+     method = (
+         inner.get("method_used")
+         or res.get("method_used")
+         or results.get("method_used")
+         or "N/A"
+     )
+
+     effect_estimate = (
+         inner.get("effect_estimate")
+         or res.get("effect_estimate")
+         or {}
+     )
+     confidence_interval = (
+         inner.get("confidence_interval")
+         or res.get("confidence_interval")
+         or {}
+     )
+     standard_error = (
+         inner.get("standard_error")
+         or res.get("standard_error")
+         or {}
+     )
+     p_value = (
+         inner.get("p_value")
+         or res.get("p_value")
+         or {}
+     )
+
+     dataset_desc = (
+         results.get("dataset_description")
+         or res.get("dataset_description")
+         or "N/A"
+     )
+
+     return {
+         "original_question": question,
+         "method_used": method,
+         "estimates": {
+             "effect_estimate": effect_estimate,
+             "confidence_interval": confidence_interval,
+             "standard_error": standard_error,
+             "p_value": p_value,
+         },
+         "dataset_description": dataset_desc,
+     }
+
+ def _format_effects_md(effect_estimate: dict) -> str:
+     """
+     Minimal human-readable view of effect estimates for display.
+     """
+     if not effect_estimate or not isinstance(effect_estimate, dict):
+         return "_No effect estimates found._"
+     # Render as bullet list
+     lines = []
+     for k, v in effect_estimate.items():
+         try:
+             lines.append(f"- **{k}**: {float(v):+.4f}")
+         except Exception:
+             lines.append(f"- **{k}**: {v}")
+     return "\n".join(lines)
+
+ def _summarize_with_llm(payload: dict) -> str:
+     """
+     Calls OpenAI with the provided SYSTEM_PROMPT and the JSON payload as the user message.
+     Returns the model's text, or raises on error.
+     """
+     client = _get_openai_client()
+     model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")
+
+     user_content = (
+         "Summarize the following causal analysis results:\n\n"
+         + json.dumps(payload, indent=2, ensure_ascii=False)
+     )
+
+     # Use Chat Completions for broad compatibility
+     resp = client.chat.completions.create(
+         model=model_name,
+         messages=[
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": user_content},
+         ],
+         temperature=0
+     )
+     text = resp.choices[0].message.content.strip()
+     return text
+
+ def run_agent(query: str, csv_path: str, dataset_description: str):
+     """
+     Modified to use yield for progressive updates and immediate feedback
+     """
+     # Immediate feedback - show processing has started
+     processing_html = """
+     <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
+         <div style='font-size: 16px; margin-bottom: 5px;'>🔄 Analysis in Progress...</div>
+         <div style='font-size: 14px; color: #666;'>This may take 1-2 minutes depending on dataset size</div>
+     </div>
+     """
+
+     yield (
+         processing_html,  # method_out
+         processing_html,  # effects_out
+         processing_html,  # explanation_out
+         {"status": "Processing started..."}  # raw_results
+     )
+
+     # Input validation
+     if not os.getenv("OPENAI_API_KEY"):
+         error_html = "<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>⚠️ Set a Space Secret named OPENAI_API_KEY</div>"
+         yield (error_html, "", "", {})
+         return
+
+     if not csv_path:
+         error_html = "<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>Please upload a CSV dataset.</div>"
+         yield (error_html, "", "", {})
+         return
+
+     try:
+         # Update status to show causal analysis is running
+         analysis_html = """
+         <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
+             <div style='font-size: 16px; margin-bottom: 5px;'>📊 Running Causal Analysis...</div>
+             <div style='font-size: 14px; color: #666;'>Analyzing dataset and selecting optimal method</div>
+         </div>
+         """
+
+         yield (
+             analysis_html,
+             analysis_html,
+             analysis_html,
+             {"status": "Running causal analysis..."}
+         )
+
+         result = run_causal_analysis(
+             query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
+             dataset_path=csv_path,
+             dataset_description=(dataset_description or "").strip(),
+         )
+
+         # Update to show LLM summarization step
+         llm_html = """
+         <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
+             <div style='font-size: 16px; margin-bottom: 5px;'>🤖 Generating Summary...</div>
+             <div style='font-size: 14px; color: #666;'>Creating human-readable interpretation</div>
+         </div>
+         """
+
+         yield (
+             llm_html,
+             llm_html,
+             llm_html,
+             {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}
+         )
+
+     except Exception as e:
+         error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Error: {e}</div>"
+         yield (error_html, "", "", {})
+         return
+
+     try:
+         payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
+         method = payload.get("method_used", "N/A")
+
+         # Format method output with simple styling
+         method_html = f"""
+         <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
+             <h3 style='margin: 0 0 10px 0; font-size: 18px;'>Selected Method</h3>
+             <p style='margin: 0; font-size: 16px;'>{method}</p>
+         </div>
+         """
+
+         # Format effects with simple styling
+         effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
+         if effect_estimate:
+             effects_html = "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
+             effects_html += "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Effect Estimates</h3>"
+             # for k, v in effect_estimate.items():
+             #     try:
+             #         value = f"{float(v):+.4f}"
+             #         effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #ffffff;'><strong>{k}:</strong> <span style='font-size: 16px;'>{value}</span></div>"
+             #     except:
+             effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'>{effect_estimate}</div>"
+             effects_html += "</div>"
+         else:
+             effects_html = "<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; color: #666; font-style: italic; background-color: #333333;'>No effect estimates found</div>"
+
+         # Generate explanation and format it
+         try:
+             explanation = _summarize_with_llm(payload)
+             explanation_html = f"""
+             <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
+                 <h3 style='margin: 0 0 15px 0; font-size: 18px;'>Detailed Explanation</h3>
+                 <div style='line-height: 1.6; white-space: pre-wrap;'>{explanation}</div>
+             </div>
+             """
+         except Exception as e:
+             explanation_html = f"<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>⚠️ LLM summary failed: {e}</div>"
+
+     except Exception as e:
+         error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Failed to parse results: {e}</div>"
+         yield (error_html, "", "", {})
+         return
+
+     # Final result
+     yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Causal Agent")
+     gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.")
+
+     with gr.Row():
+         query = gr.Textbox(
+             label="Your causal question (natural language)",
+             placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
+             lines=2,
+         )
+
+     with gr.Row():
+         csv_file = gr.File(
+             label="Dataset (CSV)",
+             file_types=[".csv"],
+             type="filepath"
+         )
+
+     dataset_description = gr.Textbox(
+         label="Dataset description (optional)",
+         placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
+         lines=4,
+     )
+
+     run_btn = gr.Button("Run analysis", variant="primary")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             method_out = gr.HTML(label="Selected Method")
+         with gr.Column(scale=1):
+             effects_out = gr.HTML(label="Effect Estimates")
+
+     with gr.Row():
+         explanation_out = gr.HTML(label="Detailed Explanation")
+
+     # Add the collapsible raw results section
+     with gr.Accordion("Raw Results (Advanced)", open=False):
+         raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
+
+     run_btn.click(
+         fn=run_agent,
+         inputs=[query, csv_file, dataset_description],
+         outputs=[method_out, effects_out, explanation_out, raw_results],
+         show_progress=True
+     )
+
+     gr.Markdown(
+         """
+         **Tips:**
+         - Be specific about your treatment, outcome, and control variables
+         - Include relevant context in the dataset description
+         - The analysis may take 1-2 minutes for complex datasets
+         """
+     )
+
+ if __name__ == "__main__":
+     demo.queue().launch()
auto_causal/__init__.py ADDED
@@ -0,0 +1,50 @@
+ """
+ Auto Causal module for causal inference.
+
+ This module provides automated causal inference capabilities
+ through a pipeline that selects and applies appropriate causal methods.
+ """
+
+ __version__ = "0.1.0"
+
+ # Import components
+ from auto_causal.components import (
+     parse_input,
+     analyze_dataset,
+     interpret_query,
+     validate_method,
+     generate_explanation,
+     format_output,
+     create_workflow_state_update
+ )
+
+ # Import tools
+ from auto_causal.tools import (
+     input_parser_tool,
+     dataset_analyzer_tool,
+     query_interpreter_tool,
+     method_selector_tool,
+     method_validator_tool,
+     method_executor_tool,
+     explanation_generator_tool,
+     output_formatter_tool
+ )
+
+ # Import the main agent function
+ from .agent import run_causal_analysis
+
+ # Remove backward compatibility for old pipeline
+ # try:
+ #     from .pipeline import CausalInferencePipeline
+ # except ImportError:
+ #     # Define a placeholder class if the old pipeline doesn't exist
+ #     class CausalInferencePipeline:
+ #         """Placeholder for CausalInferencePipeline."""
+ #
+ #         def __init__(self, *args, **kwargs):
+ #             pass
+
+ # Update __all__ to export the main function
+ __all__ = [
+     'run_causal_analysis'
+ ]
auto_causal/agent.py ADDED
@@ -0,0 +1,394 @@
1
+ """
2
+ LangChain agent for the auto_causal module.
3
+
4
+ This module configures a LangChain agent with specialized tools for causal inference,
5
+ allowing for an interactive approach to analyzing datasets and applying appropriate
6
+ causal inference methods.
7
+ """
8
+
9
+ import logging
10
+ from typing import Dict, List, Any, Optional
11
+ from langchain.agents.react.agent import create_react_agent
12
+ from langchain.agents import AgentExecutor, create_structured_chat_agent, create_tool_calling_agent
13
+ from langchain.chains.conversation.memory import ConversationBufferMemory
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
16
+ from langchain.tools import tool
17
+ # Import the callback handler
18
+ from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
19
+ # Import tool rendering utility
20
+ from langchain.tools.render import render_text_description
21
+ # Import LCEL components
22
+ from langchain.agents.format_scratchpad.tools import format_to_tool_messages
23
+ from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
24
+ from langchain_core.runnables import RunnablePassthrough
25
+ from langchain_core.language_models import BaseChatModel
26
+ from langchain_anthropic.chat_models import convert_to_anthropic_tool
27
+ import os
28
+ # Import actual tools from the tools directory
29
+ from auto_causal.tools.input_parser_tool import input_parser_tool
30
+ from auto_causal.tools.dataset_analyzer_tool import dataset_analyzer_tool
31
+ from auto_causal.tools.query_interpreter_tool import query_interpreter_tool
32
+ from auto_causal.tools.method_selector_tool import method_selector_tool
33
+ from auto_causal.tools.method_validator_tool import method_validator_tool
34
+ from auto_causal.tools.method_executor_tool import method_executor_tool
35
+ from auto_causal.tools.explanation_generator_tool import explanation_generator_tool
36
+ from auto_causal.tools.output_formatter_tool import output_formatter_tool
37
+ #from auto_causal.prompts import SYSTEM_PROMPT # Assuming SYSTEM_PROMPT is defined here or imported
38
+ from langchain_core.output_parsers import StrOutputParser
39
+ # Import the centralized factory function
40
+ from .config import get_llm_client
41
+ #from .prompts import SYSTEM_PROMPT
42
+ from langchain_core.messages import AIMessage, AIMessageChunk
43
+ import re
44
+ import json
45
+ from typing import Union
46
+ from langchain_core.output_parsers import BaseOutputParser
47
+ from langchain.schema import AgentAction, AgentFinish
48
+ from langchain_anthropic.output_parsers import ToolsOutputParser
49
+ from langchain.agents.react.output_parser import ReActOutputParser
50
+ from langchain.agents import AgentOutputParser
51
+ from langchain.agents.agent import AgentAction, AgentFinish, OutputParserException
52
+ import re
53
+ from typing import Union, List
54
+ from auto_causal.models import *
55
+
56
+ from langchain_core.agents import AgentAction, AgentFinish
57
+ from langchain_core.exceptions import OutputParserException
58
+
59
+ from langchain.agents.agent import AgentOutputParser
60
+ from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
61
+
62
+ FINAL_ANSWER_ACTION = "Final Answer:"
63
+ MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
64
+ "Invalid Format: Missing 'Action:' after 'Thought:'"
65
+ )
66
+ MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
67
+ "Invalid Format: Missing 'Action Input:' after 'Action:'"
68
+ )
69
+ FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
70
+ "Parsing LLM output produced both a final answer and parse-able actions"
71
+ )
72
+
73
+
74
+ class ReActMultiInputOutputParser(AgentOutputParser):
75
+ """Parses ReAct-style output that may contain multiple tool calls."""
76
+
77
+ def get_format_instructions(self) -> str:
78
+ # You can reuse the original FORMAT_INSTRUCTIONS,
79
+ # but let the model know it may emit multiple actions.
80
+ return FORMAT_INSTRUCTIONS + (
81
+ "\n\nIf you need to call more than one tool, simply repeat:\n"
82
+ "Action: <tool_name>\n"
83
+ "Action Input: <json or text>\n"
84
+ "…for each tool in sequence."
85
+ )
86
+
87
+ @property
88
+ def _type(self) -> str:
89
+ return "react-multi-input"
90
+
91
+ def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
92
+ includes_answer = FINAL_ANSWER_ACTION in text
93
+ print('-------------------')
94
+ print(text)
95
+ print('-------------------')
96
+ # Grab every Action / Action Input block
97
+ pattern = (
98
+ r"Action\s*\d*\s*:[\s]*(.*?)\s*"
99
+ r"Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*?)(?=(?:Action\s*\d*\s*:|$))"
100
+ )
101
+ matches = list(re.finditer(pattern, text, re.DOTALL))
102
+
103
+ # If we found tool calls…
104
+ if matches:
105
+ if includes_answer:
106
+ # both a final answer *and* tool calls is ambiguous
107
+ raise OutputParserException(
108
+ f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
109
+ )
110
+
111
+ actions: List[AgentAction] = []
112
+ for m in matches:
113
+ tool_name = m.group(1).strip()
114
+ tool_input = m.group(2).strip().strip('"')
115
+ print('\n--------------------------')
116
+ print(tool_input)
117
+ print('--------------------------')
118
+ actions.append(AgentAction(tool_name, json.loads(tool_input), text))
119
+
120
+ return actions
121
+
122
+ # Otherwise, if there's a final answer, finish
123
+ if includes_answer:
124
+ answer = text.split(FINAL_ANSWER_ACTION, 1)[1].strip()
125
+ return AgentFinish({"output": answer}, text)
126
+
127
+ # No calls and no final answer → figure out which error to throw
128
+ if not re.search(r"Action\s*\d*\s*Input\s*\d*:", text):
129
+ raise OutputParserException(
130
+ f"Could not parse LLM output: `{text}`",
131
+ observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
132
+ llm_output=text,
133
+ send_to_llm=True,
134
+ )
135
+
136
+ # Fallback
137
+ raise OutputParserException(f"Could not parse LLM output: `{text}`")
138
+
139
+ # Set up basic logging
140
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
141
+ logger = logging.getLogger(__name__)
142
+
143
+
144
+ def create_agent_prompt(tools: List[tool]) -> ChatPromptTemplate:
145
+ """Create the prompt template for the causal inference agent, emphasizing workflow and data handoff.
146
+ (This is the version required by the LCEL agent structure below)
147
+ """
148
+ # Get the tool descriptions
149
+ tool_description = render_text_description(tools)
150
+ tool_names = ", ".join([t.name for t in tools])
151
+
152
+ # Define the system prompt template string
153
+ system_template = """
154
+ You are a causal inference expert helping users answer causal questions by following a strict workflow using specialized tools.
155
+
156
+ Remember that you always have to generate the Thought, Action and Action Input block.
157
+ TOOLS:
158
+ ------
159
+ You have access to the following tools:
160
+
161
+ {tools}
162
+
163
+ To use a tool, please use the following format:
164
+
165
+ Thought: Do I need to use a tool? Yes
166
+ Action: the action to take, should be one of [{tool_names}]
167
+ Action Input: the input to the action, as a single, valid JSON object string. Check the tool definition for required arguments and structure.
168
+ Observation: the result of the action, often containing structured data like 'variables', 'dataset_analysis', 'method_info', etc.
169
+
170
+ When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
171
+
172
+ Thought: Do I need to use a tool? No
173
+ Final Answer: [your response here]
174
+
175
+ DO NOT UNDER ANY CIRCUMSTANCE CALL MORE THAN ONE TOOL IN A STEP
176
+
177
+ **IMPORTANT TOOL USAGE:**
178
+ 1. **Action Input Format:** The value for 'Action Input' MUST be a single, valid JSON object string. Do NOT include any other text or formatting around the JSON string.
179
+ 2. **Argument Gathering:** You MUST gather ALL required arguments for the Action Input JSON from the initial Human input AND the 'Observation' outputs of PREVIOUS steps. Look carefully at the required arguments for the tool you are calling.
180
+ 3. **Data Handoff:** The 'Observation' from a previous step often contains structured data needed by the next tool. For example, the 'variables' output from `query_interpreter_tool` contains fields like `treatment_variable`, `outcome_variable`, `covariates`, `time_variable`, `instrument_variable`, `running_variable`, `cutoff_value`, and `is_rct`. When calling `method_selector_tool`, you MUST construct its required `variables` input argument by including **ALL** these relevant fields identified by the `query_interpreter_tool` in the previous Observation. Similarly, pass the full `dataset_analysis`, `dataset_description`, and `original_query` when required by the next tool.
181
+
182
+ IMPORTANT WORKFLOW:
183
+ -------------------
184
+ You must follow this exact workflow, selecting the appropriate tool for each step:
185
+
186
+ 1. ALWAYS start with `input_parser_tool` to understand the query
187
+ 2. THEN use `dataset_analyzer_tool` to analyze the dataset
188
+ 3. THEN use `query_interpreter_tool` to identify variables (output includes `variables` and `dataset_analysis`)
189
+ 4. THEN use `method_selector_tool` (input requires `variables` and `dataset_analysis` from previous step)
190
+ 5. THEN use `method_validator_tool` (input requires `method_info` and `variables` from previous step)
191
+ 6. THEN use `method_executor_tool` (input requires `method`, `variables`, `dataset_path`)
192
+ 7. THEN use `explanation_generator_tool` (input requires results, method_info, variables, etc.)
193
+ 8. FINALLY use `output_formatter_tool` to return the results
194
+
195
+ REASONING PROCESS:
196
+ ------------------
197
+ EXPLICITLY REASON about:
198
+ 1. What step you're currently on (based on previous tool's Observation)
199
+ 2. Why you're selecting a particular tool (should follow the workflow)
200
+ 3. How the output of the previous tool (especially structured data like `variables`, `dataset_analysis`, `method_info`) informs the inputs required for the current tool.
201
+
202
+ IMPORTANT RULES:
203
+ 1. Do not make more than one tool call in a single step.
204
+ 2. Do not include ``` in your output at all.
205
+ 3. Don't use action names like default_api.dataset_analyzer_tool; instead, use tool names like dataset_analyzer_tool.
206
+ 4. Always start Thought, Action, and Observation on a new line.
207
+ 5. Don't use '\\' before double quotes
208
+ 6. Don't include ```json for Action Input. Also ensure that Action Input is valid JSON. Do not add any text after Action Input.
209
+ 7. You have to always choose one of the tools unless it's the final answer.
210
+ Begin!
211
+ """
212
+
213
+ # Create the prompt template
214
+ prompt = ChatPromptTemplate.from_messages([
215
+ ("system", system_template),
216
+ MessagesPlaceholder("chat_history", optional=True), # Use MessagesPlaceholder
217
+ # MessagesPlaceholder("agent_scratchpad"),
218
+
219
+ ("human", "{input}\n Thought:{agent_scratchpad}"),
220
+ # ("ai", "{agent_scratchpad}"),
221
+ # MessagesPlaceholder("agent_scratchpad" ), # Use MessagesPlaceholder
222
+ # "agent_scratchpad"
223
+ ])
224
+ return prompt
225
+
226
+ def create_causal_agent(llm: BaseChatModel) -> AgentExecutor:
227
+ """
228
+ Create and configure the LangChain agent with causal inference tools.
229
+ (Using explicit LCEL construction, compatible with shared LLM client)
230
+ """
231
+ # Define tools available to the agent
232
+ agent_tools = [
233
+ input_parser_tool,
234
+ dataset_analyzer_tool,
235
+ query_interpreter_tool,
236
+ method_selector_tool,
237
+ method_validator_tool,
238
+ method_executor_tool,
239
+ explanation_generator_tool,
240
+ output_formatter_tool
241
+ ]
242
+ # anthropic_agent_tools = [ convert_to_anthropic_tool(anthropic_tool) for anthropic_tool in agent_tools]
243
+ # Create the prompt using the helper
244
+ prompt = create_agent_prompt(agent_tools)
245
+ # Bind tools to the LLM (using the passed shared instance)
246
+
247
+
248
+ # Create memory
249
+ # Consider if memory needs to be passed in or created here
250
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
251
+
252
+ # Manually construct the agent runnable using LCEL
253
+ from langchain_anthropic.output_parsers import ToolsOutputParser
254
+ from langchain.agents.output_parsers.json import JSONAgentOutputParser
255
+ # from langchain.agents.react.output_parser import MultiActionAgentOutputParsers ReActMultiInputOutputParser
256
+ provider = os.getenv("LLM_PROVIDER", "openai")
257
+ if provider == "gemini":
258
+ base_parser=ReActMultiInputOutputParser()
259
+ llm_with_tools = llm.bind_tools(agent_tools)
260
+ else:
261
+ base_parser=ToolsAgentOutputParser()
262
+ llm_with_tools = llm.bind_tools(agent_tools, tool_choice="any")
263
+ agent = create_react_agent(llm_with_tools, agent_tools, prompt, output_parser=base_parser)
264
+
265
+
266
+ # Create executor (should now work with the manually constructed agent)
267
+ executor = AgentExecutor(
268
+ agent=agent,
269
+ tools=agent_tools,
270
+ memory=memory, # Pass the memory object
271
+ verbose=True,
272
+ callbacks=[ConsoleCallbackHandler()], # Optional: for console debugging
273
+ handle_parsing_errors=True, # Let AE handle parsing errors
274
+ max_retries = 100
275
+ )
276
+
277
+ return executor
278
+
279
+ def run_causal_analysis(query: str, dataset_path: str,
280
+ dataset_description: Optional[str] = None,
281
+ api_key: Optional[str] = None) -> Dict[str, Any]:
282
+ """
283
+ Run causal analysis on a dataset based on a user query.
284
+
285
+ Args:
286
+ query: User's causal question
287
+ dataset_path: Path to the dataset
288
+ dataset_description: Optional textual description of the dataset
289
+ api_key: Optional OpenAI API key (DEPRECATED - will be ignored)
290
+
291
+ Returns:
292
+ Dictionary containing the final formatted analysis results from the agent's last step.
293
+ """
294
+ # Log the start of the analysis
295
+ logger.info("Starting causal analysis run...")
296
+
297
+ try:
298
+ # --- Instantiate the shared LLM client ---
299
+ model_name = os.getenv("LLM_MODEL", "gpt-4")
300
+ if model_name in ['o3', 'o4-mini', 'o3-mini']:
301
+ print('-------------------------')
302
+ shared_llm = get_llm_client()
303
+ else:
304
+ shared_llm = get_llm_client(temperature=0) # Or read provider/model from env
305
+
306
+ # --- Dependency Injection Note (REMAINS RELEVANT) ---
307
+ # If tools need the LLM, they must be adapted. Example using partial:
308
+ # from functools import partial
309
+ # from .components import input_parser
310
+ # # Assume input_parser.parse_input needs llm
311
+ # input_parser_tool_with_llm = tool(partial(input_parser.parse_input, llm=shared_llm))
312
+ # Use input_parser_tool_with_llm in the tools list passed to the agent below.
313
+ # Similar adjustments needed for decision_tree._recommend_ps_method if used.
314
+ # --- End Note ---
315
+
316
+ # --- Create agent using the shared LLM ---
317
+ # agent_executor = create_causal_agent(shared_llm)
318
+
319
+ # Construct input, including description if available
320
+ # IMPORTANT: Agent now expects 'input' and potentially 'chat_history'
321
+ # The input needs to contain all initial info the first tool might need.
322
+ input_text = f"My question is: {query}\n"
323
+ input_text += f"The dataset is located at: {dataset_path}\n"
324
+ if dataset_description:
325
+ input_text += f"Dataset Description: {dataset_description}\n"
326
+ input_text += "Please perform the causal analysis following the workflow."
327
+
328
+ # Log the constructed input text
329
+ logger.info(f"Constructed input for agent: \n{input_text}")
330
+
331
+ input_parsing_result = input_parser_tool(input_text)
332
+ dataset_analysis_result = dataset_analyzer_tool.func(dataset_path=input_parsing_result["dataset_path"], dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).analysis_results
333
+ query_info = QueryInfo(
334
+ query_text=input_parsing_result["original_query"],
335
+ potential_treatments=input_parsing_result["extracted_variables"].get("treatment"),
336
+ potential_outcomes=input_parsing_result["extracted_variables"].get("outcome"),
337
+ covariates_hints=input_parsing_result["extracted_variables"].get("covariates_mentioned"),
338
+ instrument_hints=input_parsing_result["extracted_variables"].get("instruments_mentioned")
339
+ )
340
+
341
+ query_interpreter_output = query_interpreter_tool.func(query_info=query_info, dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query = input_parsing_result["original_query"]).variables
342
+ method_selector_output = method_selector_tool.func(variables=query_interpreter_output,
343
+ dataset_analysis=dataset_analysis_result,
344
+ dataset_description=input_parsing_result["dataset_description"],
345
+ original_query = input_parsing_result["original_query"],
346
+ excluded_methods=None)
347
+ method_info = MethodInfo(
348
+ **method_selector_output['method_info']
349
+ )
350
+ method_validator_input = MethodValidatorInput(
351
+ method_info=method_info,
352
+ variables=query_interpreter_output,
353
+ dataset_analysis=dataset_analysis_result,
354
+ dataset_description=input_parsing_result["dataset_description"],
355
+ original_query = input_parsing_result["original_query"]
356
+ )
357
+ method_validator_output = method_validator_tool.func(method_validator_input)
358
+ method_executor_input = MethodExecutorInput(
359
+ **method_validator_output
360
+ )
361
+ method_executor_output = method_executor_tool.func(method_executor_input, original_query = input_parsing_result["original_query"])
362
+
363
+ explainer_output = explanation_generator_tool.func( method_info=method_info,
364
+ validation_info=method_validator_output,
365
+ variables=query_interpreter_output,
366
+ results=method_executor_output,
367
+ dataset_analysis=dataset_analysis_result,
368
+ dataset_description=input_parsing_result["dataset_description"],
369
+ original_query = input_parsing_result["original_query"])
370
+ result = explainer_output
371
+ result['results']['results']["method_used"] = method_validator_output['method']
372
+ logger.info(result)
373
+ logger.info("Causal analysis run finished.")
374
+
375
+ # Ensure result is a dict and extract the 'output' part
376
+ if isinstance(result, dict):
377
+ final_output = result
378
+ if isinstance(final_output, dict):
379
+ return final_output # Return only the dictionary from the final tool
380
+ else:
381
+ logger.error(f"Agent result['output'] was not a dictionary: {type(final_output)}. Returning error dict.")
382
+ return {"error": "Agent did not produce the expected dictionary output in the 'output' key.", "raw_agent_result": result}
383
+ else:
384
+ logger.error(f"Agent returned non-dict type: {type(result)}. Returning error dict.")
385
+ return {"error": "Agent did not return expected dictionary output.", "raw_output": str(result)}
386
+
387
+ except ValueError as e:
388
+ logger.error(f"Configuration Error: {e}")
389
+ # Return an error dictionary in case of exception too
390
+ return {"error": f"Error: Configuration issue - {e}"} # Ensure consistent error return type
391
+ except Exception as e:
392
+ logger.error(f"An unexpected error occurred during causal analysis: {e}", exc_info=True)
393
+ # Return an error dictionary in case of exception too
394
+ return {"error": f"An unexpected error occurred: {e}"}
auto_causal/components/__init__.py ADDED
@@ -0,0 +1,28 @@
+ """
+ Auto Causal components package.
+
+ This package contains the core components for the auto_causal module,
+ each handling a specific part of the causal inference workflow.
+ """
+
+ from auto_causal.components.input_parser import parse_input
+ from auto_causal.components.dataset_analyzer import analyze_dataset
+ from auto_causal.components.query_interpreter import interpret_query
+ from auto_causal.components.decision_tree import select_method
+ from auto_causal.components.method_validator import validate_method
+ from auto_causal.components.explanation_generator import generate_explanation
+ from auto_causal.components.output_formatter import format_output
+ from auto_causal.components.state_manager import create_workflow_state_update
+
+ __all__ = [
+     "parse_input",
+     "analyze_dataset",
+     "interpret_query",
+     "select_method",
+     "validate_method",
+     "generate_explanation",
+     "format_output",
+     "create_workflow_state_update"
+ ]
+
+ # This file makes Python treat the directory as a package.
auto_causal/components/dataset_analyzer.py ADDED
@@ -0,0 +1,853 @@
1
+ """
2
+ Dataset analyzer component for causal inference.
3
+
4
+ This module provides functionality to analyze datasets to detect characteristics
5
+ relevant for causal inference methods, including temporal structure, potential
6
+ instrumental variables, discontinuities, and variable relationships.
7
+ """
8
+
9
+ import os
10
+ import pandas as pd
11
+ import numpy as np
12
+ from typing import Dict, List, Any, Optional, Tuple
13
+ from scipy import stats
14
+ import logging
15
+ import json
16
+ from langchain_core.language_models import BaseChatModel
17
+ from auto_causal.utils.llm_helpers import llm_identify_temporal_and_unit_vars
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ def _calculate_per_group_stats(df: pd.DataFrame, potential_treatments: List[str]) -> Dict[str, Dict]:
22
+ """Calculates summary stats for numeric covariates grouped by potential binary treatments."""
23
+ stats_dict = {}
24
+ numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
25
+
26
+ for treat_var in potential_treatments:
27
+ if treat_var not in df.columns:
28
+ logger.warning(f"Potential treatment '{treat_var}' not found in DataFrame columns.")
29
+ continue
30
+
31
+ # Ensure treatment is binary (0/1 or similar)
32
+ unique_vals = df[treat_var].dropna().unique()
33
+ if len(unique_vals) != 2:
34
+ logger.info(f"Skipping stats for potential treatment '{treat_var}' as it is not binary ({len(unique_vals)} unique values).")
35
+ continue
36
+
37
+ # Attempt to map values to 0 and 1 if possible
38
+ try:
39
+ # Ensure boolean is converted to int
40
+ if df[treat_var].dtype == 'bool':
41
+ df[treat_var] = df[treat_var].astype(int)
42
+ unique_vals = df[treat_var].dropna().unique()
43
+
44
+ # Basic check if values are interpretable as 0/1
45
+ if not set(unique_vals).issubset({0, 1}):
46
+ # Attempt conversion if possible (e.g., True/False strings?)
47
+ logger.warning(f"Potential treatment '{treat_var}' has values {unique_vals}, not {0, 1}. Cannot calculate group stats reliably.")
48
+ continue
49
+ except Exception as e:
50
+ logger.warning(f"Could not process potential treatment '{treat_var}' values ({unique_vals}): {e}")
51
+ continue
52
+
53
+ logger.info(f"Calculating group stats for treatment: '{treat_var}'")
54
+ treat_stats = {'group_sizes': {}, 'covariate_stats': {}}
55
+
56
+ try:
57
+ grouped = df.groupby(treat_var)
58
+ sizes = grouped.size()
59
+ treat_stats['group_sizes']['treated'] = int(sizes.get(1, 0))
60
+ treat_stats['group_sizes']['control'] = int(sizes.get(0, 0))
61
+
62
+ if treat_stats['group_sizes']['treated'] == 0 or treat_stats['group_sizes']['control'] == 0:
63
+ logger.warning(f"Treatment '{treat_var}' has zero samples in one group. Skipping covariate stats.")
64
+ stats_dict[treat_var] = treat_stats
65
+ continue
66
+
67
+ # Calculate mean and std for numeric covariates
68
+ cov_stats = grouped[numeric_cols].agg(['mean', 'std']).unstack()
69
+
70
+ for cov in numeric_cols:
71
+ if cov == treat_var: continue # Skip treatment variable itself
72
+
73
+ # Keys follow pandas' unstack order: (covariate, statistic, group)
+ mean_control = cov_stats.get((cov, 'mean', 0), np.nan)
74
+ std_control = cov_stats.get((cov, 'std', 0), np.nan)
75
+ mean_treated = cov_stats.get((cov, 'mean', 1), np.nan)
76
+ std_treated = cov_stats.get((cov, 'std', 1), np.nan)
77
+
78
+ treat_stats['covariate_stats'][cov] = {
79
+ 'mean_control': float(mean_control) if pd.notna(mean_control) else None,
80
+ 'std_control': float(std_control) if pd.notna(std_control) else None,
81
+ 'mean_treat': float(mean_treated) if pd.notna(mean_treated) else None,
82
+ 'std_treat': float(std_treated) if pd.notna(std_treated) else None,
83
+ }
84
+ stats_dict[treat_var] = treat_stats
85
+ except Exception as e:
86
+ logger.error(f"Error calculating stats for treatment '{treat_var}': {e}", exc_info=True)
87
+ # Store partial info if possible
88
+ if treat_var not in stats_dict:
89
+ stats_dict[treat_var] = {'error': str(e)}
90
+ elif 'error' not in stats_dict[treat_var]:
91
+ stats_dict[treat_var]['error'] = str(e)
92
+
93
+ return stats_dict
94
+
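A quick illustration of the structure _calculate_per_group_stats returns on a toy DataFrame (a sketch, not part of the committed file; the numbers simply follow from the toy data):

import pandas as pd

toy = pd.DataFrame({"treat": [0, 1, 0, 1], "age": [30, 35, 40, 45]})
stats = _calculate_per_group_stats(toy, ["treat"])
# stats["treat"]["group_sizes"]                            -> {"treated": 2, "control": 2}
# stats["treat"]["covariate_stats"]["age"]["mean_control"] -> 35.0
# stats["treat"]["covariate_stats"]["age"]["mean_treat"]   -> 40.0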
95
+ def analyze_dataset(
96
+ dataset_path: str,
97
+ llm_client: Optional[BaseChatModel] = None,
98
+ dataset_description: Optional[str] = None,
99
+ original_query: Optional[str] = None
100
+ ) -> Dict[str, Any]:
101
+ """
102
+ Analyze a dataset to identify important characteristics for causal inference.
103
+
104
+ Args:
105
+ dataset_path: Path to the dataset file
106
+ llm_client: Optional LLM client for enhanced analysis
107
+ dataset_description: Optional description of the dataset for context
108
+
109
+ Returns:
110
+ Dict containing dataset analysis results:
111
+ - dataset_info: Basic information about the dataset
112
+ - columns: List of column names
113
+ - potential_treatments: List of potential treatment variables (possibly LLM augmented)
114
+ - potential_outcomes: List of potential outcome variables (possibly LLM augmented)
115
+ - temporal_structure_detected: Whether temporal structure was detected
116
+ - panel_data_detected: Whether panel data structure was detected
117
+ - potential_instruments_detected: Whether potential instruments were detected
118
+ - discontinuities_detected: Whether discontinuities were detected
119
+ - llm_augmentation: Status of LLM augmentation if used
120
+ """
121
+ llm_augmentation = "Not used" if not llm_client else "Initialized"
122
+
123
+ # Check if file exists
124
+ if not os.path.exists(dataset_path):
125
+ logger.error(f"Dataset file not found at {dataset_path}")
126
+ return {"error": f"Dataset file not found at {dataset_path}"}
127
+
128
+ try:
129
+ # Load the dataset
130
+ df = pd.read_csv(dataset_path)
131
+
132
+ # Basic dataset information
133
+ sample_size = len(df)
134
+ columns_list = df.columns.tolist()
135
+ num_covariates = len(columns_list) - 2 # Rough estimate (total - T - Y)
136
+ dataset_info = {
137
+ "num_rows": sample_size,
138
+ "num_columns": len(columns_list),
139
+ "file_path": dataset_path,
140
+ "file_name": os.path.basename(dataset_path)
141
+ }
142
+
143
+ # --- Detailed Analysis (Keep internal) ---
144
+ column_types_detailed = {col: str(df[col].dtype) for col in df.columns}
145
+ missing_values_detailed = df.isnull().sum().to_dict()
146
+ column_categories_detailed = _categorize_columns(df)
147
+ column_nunique_counts_detailed = {col: df[col].nunique() for col in df.columns} # Calculate nunique
148
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
149
+ correlations_detailed = df[numeric_cols].corr() if numeric_cols else pd.DataFrame()
150
+ temporal_structure_detailed = detect_temporal_structure(df, llm_client, dataset_description, original_query)
151
+
152
+ # First, identify potential treatment and outcome variables
153
+ potential_variables = _identify_potential_variables(
154
+ df,
155
+ column_categories_detailed,
156
+ llm_client=llm_client,
157
+ dataset_description=dataset_description
158
+ )
159
+
160
+ if llm_client:
161
+ llm_augmentation = "Used for variable identification"
162
+
163
+ # Then use that info to help find potential instrumental variables
164
+ potential_instruments_detailed = find_potential_instruments(
165
+ df,
166
+ llm_client=llm_client,
167
+ potential_treatments=potential_variables.get("potential_treatments", []),
168
+ potential_outcomes=potential_variables.get("potential_outcomes", []),
169
+ dataset_description=dataset_description
170
+ )
171
+
172
+ # Other analyses
173
+ discontinuities_detailed = detect_discontinuities(df)
174
+ variable_relationships_detailed = assess_variable_relationships(df, correlations_detailed)
175
+
176
+ # Calculate per-group stats for potential binary treatments
177
+ potential_binary_treatments = [
178
+ t for t in potential_variables["potential_treatments"]
179
+ if column_categories_detailed.get(t) == 'binary'
180
+ or column_categories_detailed.get(t) == 'binary_categorical'
181
+ ]
182
+ per_group_stats = _calculate_per_group_stats(df.copy(), potential_binary_treatments)
183
+
184
+ # --- Summarized Analysis (For Output) ---
185
+
186
+ # Get boolean flags and essential lists
187
+ has_temporal = temporal_structure_detailed.get("has_temporal_structure", False)
188
+ is_panel = temporal_structure_detailed.get("is_panel_data", False)
189
+ logger.info(f"iv is {potential_instruments_detailed}")
190
+ has_instruments = len(potential_instruments_detailed) > 0
191
+ has_discontinuities = discontinuities_detailed.get("has_discontinuities", False)
192
+
193
+ # --- Extract only instrument names for the final output ---
194
+ potential_instrument_names = [
195
+ inst_dict.get('variable')
196
+ for inst_dict in potential_instruments_detailed
197
+ if isinstance(inst_dict, dict) and 'variable' in inst_dict
198
+ ]
199
+ logger.info(f"iv is {potential_instrument_names}")
200
+ # --- Final Output Dictionary (Highly Summarized) ---
201
+ return {
202
+ "dataset_info": dataset_info, # Keep basic info
203
+ "columns": columns_list,
204
+ "potential_treatments": potential_variables["potential_treatments"],
205
+ "potential_outcomes": potential_variables["potential_outcomes"],
206
+ # Return concise flags instead of detailed dicts/lists
207
+ "temporal_structure_detected": has_temporal,
208
+ "panel_data_detected": is_panel,
209
+ "potential_instruments_detected": has_instruments,
210
+ "discontinuities_detected": has_discontinuities,
211
+ # Use the extracted list of names here
212
+ "potential_instruments": potential_instrument_names,
213
+ "discontinuities": discontinuities_detailed,
214
+ "temporal_structure": temporal_structure_detailed,
215
+ "column_categories": column_categories_detailed,
216
+ "column_nunique_counts": column_nunique_counts_detailed, # Add nunique counts to output
217
+ "sample_size": sample_size,
218
+ "num_covariates_estimate": num_covariates,
219
+ "llm_augmentation": llm_augmentation
220
+ }
221
+
222
+ except Exception as e:
223
+ logger.error(f"Error analyzing dataset '{dataset_path}': {e}", exc_info=True)
224
+ return {
225
+ "error": f"Error analyzing dataset: {str(e)}",
226
+ "llm_augmentation": llm_augmentation
227
+ }
228
+
229
+
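A minimal usage sketch for analyze_dataset (the path and description are placeholders):

analysis = analyze_dataset(
    dataset_path="data/study.csv",   # placeholder path
    llm_client=None,                 # heuristics only; pass a chat model to enable LLM augmentation
    dataset_description="Observational study of a job training program",
)
if "error" not in analysis:
    print(analysis["potential_treatments"], analysis["potential_outcomes"])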
230
+ def _categorize_columns(df: pd.DataFrame) -> Dict[str, str]:
231
+ """
232
+ Categorize columns into types relevant for causal inference.
233
+
234
+ Args:
235
+ df: DataFrame to analyze
236
+
237
+ Returns:
238
+ Dict mapping column names to their types
239
+ """
240
+ result = {}
241
+
242
+ for col in df.columns:
243
+ # Check if column is numeric
244
+ if pd.api.types.is_numeric_dtype(df[col]):
245
+ # Count number of unique values
246
+ n_unique = df[col].nunique()
247
+
248
+ # Binary numeric variable
249
+ if n_unique == 2:
250
+ result[col] = "binary"
251
+ # Likely categorical represented as numeric
252
+ elif n_unique < 10:
253
+ result[col] = "categorical_numeric"
254
+ # Discrete numeric (integers)
255
+ elif pd.api.types.is_integer_dtype(df[col]):
256
+ result[col] = "discrete_numeric"
257
+ # Continuous numeric
258
+ else:
259
+ result[col] = "continuous_numeric"
260
+
261
+ # Check for datetime
262
+ elif pd.api.types.is_datetime64_any_dtype(df[col]) or _is_date_string(df, col):
263
+ result[col] = "datetime"
264
+
265
+ # Check for categorical
266
+ elif pd.api.types.is_categorical_dtype(df[col]) or df[col].nunique() < 20:
267
+ if df[col].nunique() == 2:
268
+ result[col] = "binary_categorical"
269
+ else:
270
+ result[col] = "categorical"
271
+
272
+ # Must be text or other
273
+ else:
274
+ result[col] = "text_or_other"
275
+
276
+ return result
277
+
278
+
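A toy illustration of the labels _categorize_columns assigns (illustrative only):

import pandas as pd

toy = pd.DataFrame({
    "treated": [0, 1, 1, 0],             # numeric with 2 unique values -> "binary"
    "region": ["N", "S", "E", "W"],      # strings with few unique values -> "categorical"
    "income": [41.2, 52.9, 38.5, 60.1],  # numeric; fewer than 10 unique values in this toy data -> "categorical_numeric"
})
print(_categorize_columns(toy))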
279
+ def _is_date_string(df: pd.DataFrame, col: str) -> bool:
280
+ """
281
+ Check if a column contains date strings.
282
+
283
+ Args:
284
+ df: DataFrame to check
285
+ col: Column name to check
286
+
287
+ Returns:
288
+ True if the column appears to contain date strings
289
+ """
290
+ # Try to convert to datetime
291
+ if not pd.api.types.is_string_dtype(df[col]):
292
+ return False
293
+
294
+ # Check sample of values
295
+ sample = df[col].dropna().sample(min(10, len(df[col].dropna()))).tolist()
296
+
297
+ try:
298
+ for val in sample:
299
+ pd.to_datetime(val)
300
+ return bool(sample)  # an all-null column yields an empty sample and is not treated as dates
301
+ except Exception:
302
+ return False
303
+
304
+
305
+ def _identify_potential_variables(
306
+ df: pd.DataFrame,
307
+ column_categories: Dict[str, str],
308
+ llm_client: Optional[BaseChatModel] = None,
309
+ dataset_description: Optional[str] = None
310
+ ) -> Dict[str, List[str]]:
311
+ """
312
+ Identify potential treatment and outcome variables in the dataset, using LLM if available.
313
+ Falls back to heuristic method if LLM fails or is not available.
314
+
315
+ Args:
316
+ df: DataFrame to analyze
317
+ column_categories: Dictionary mapping column names to their types
318
+ llm_client: Optional LLM client for enhanced identification
319
+ dataset_description: Optional description of the dataset for context
320
+
321
+ Returns:
322
+ Dict with potential treatment and outcome variables
323
+ """
324
+ # Try LLM approach if client is provided
325
+ if llm_client:
326
+ try:
327
+ logger.info("Using LLM to identify potential treatment and outcome variables")
328
+
329
+ # Create a concise prompt with just column information
330
+ columns_list = df.columns.tolist()
331
+ column_types = {col: str(df[col].dtype) for col in columns_list}
332
+
333
+ # Get binary columns for extra context
334
+ binary_cols = [col for col in columns_list
335
+ if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
336
+
337
+ # Add dataset description if available
338
+ description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
339
+
340
+ prompt = f"""
341
+ You are an expert causal inference data scientist. Identify potential treatment and outcome variables from this dataset.{description_text}
342
+
343
+ Dataset columns:
344
+ {columns_list}
345
+
346
+ Column types:
347
+ {column_types}
348
+
349
+ Binary columns (good treatment candidates):
350
+ {binary_cols}
351
+
352
+ Instructions:
353
+ 1. Identify TREATMENT variables: interventions, treatments, programs, policies, or binary state changes.
354
+ Look for binary variables or names with 'treatment', 'intervention', 'program', 'policy', etc.
355
+
356
+ 2. Identify OUTCOME variables: results, effects, or responses to treatments.
357
+ Look for numeric variables (especially non-binary) or names with 'outcome', 'result', 'effect', 'score', etc.
358
+
359
+ Return ONLY a valid JSON object with two lists: "potential_treatments" and "potential_outcomes".
360
+ Example: {{"potential_treatments": ["treatment_a", "program_b"], "potential_outcomes": ["result_score", "outcome_measure"]}}
361
+ """
362
+
363
+ # Call the LLM and parse the response
364
+ response = llm_client.invoke(prompt)
365
+ response_text = response.content if hasattr(response, 'content') else str(response)
366
+
367
+ # Extract JSON from the response text
368
+ import re
369
+ json_match = re.search(r'{.*}', response_text, re.DOTALL)
370
+
371
+ if json_match:
372
+ result = json.loads(json_match.group(0))
373
+
374
+ # Validate the response
375
+ if (isinstance(result, dict) and
376
+ "potential_treatments" in result and
377
+ "potential_outcomes" in result and
378
+ isinstance(result["potential_treatments"], list) and
379
+ isinstance(result["potential_outcomes"], list)):
380
+
381
+ # Ensure all suggestions are valid columns
382
+ valid_treatments = [col for col in result["potential_treatments"] if col in df.columns]
383
+ valid_outcomes = [col for col in result["potential_outcomes"] if col in df.columns]
384
+
385
+ if valid_treatments and valid_outcomes:
386
+ logger.info(f"LLM identified {len(valid_treatments)} treatments and {len(valid_outcomes)} outcomes")
387
+ return {
388
+ "potential_treatments": valid_treatments,
389
+ "potential_outcomes": valid_outcomes
390
+ }
391
+ else:
392
+ logger.warning("LLM suggested invalid columns, falling back to heuristic method")
393
+ else:
394
+ logger.warning("Invalid LLM response format, falling back to heuristic method")
395
+ else:
396
+ logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
397
+
398
+ except Exception as e:
399
+ logger.error(f"Error in LLM identification: {e}", exc_info=True)
400
+ logger.info("Falling back to heuristic method")
401
+
402
+ # Fallback to heuristic method
403
+ logger.info("Using heuristic method to identify potential treatment and outcome variables")
404
+
405
+ # Identify potential treatment variables
406
+ potential_treatments = []
407
+
408
+ # Look for binary variables (good treatment candidates)
409
+ binary_cols = [col for col in df.columns
410
+ if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
411
+
412
+ # Look for variables with names suggesting treatment
413
+ treatment_keywords = ['treatment', 'treat', 'intervention', 'program', 'policy',
414
+ 'exposed', 'assigned', 'received', 'participated']
415
+
416
+ for col in df.columns:
417
+ col_lower = col.lower()
418
+ if any(keyword in col_lower for keyword in treatment_keywords):
419
+ potential_treatments.append(col)
420
+
421
+ # Add binary variables if we don't have enough candidates
422
+ if len(potential_treatments) < 3:
423
+ for col in binary_cols:
424
+ if col not in potential_treatments:
425
+ potential_treatments.append(col)
426
+ if len(potential_treatments) >= 3:
427
+ break
428
+
429
+ # Identify potential outcome variables
430
+ potential_outcomes = []
431
+
432
+ # Look for numeric variables that aren't binary
433
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
434
+ non_binary_numeric = [col for col in numeric_cols if col not in binary_cols]
435
+
436
+ # Look for variables with names suggesting outcomes
437
+ outcome_keywords = ['outcome', 'result', 'effect', 'response', 'score', 'performance',
438
+ 'achievement', 'success', 'failure', 'improvement']
439
+
440
+ for col in df.columns:
441
+ col_lower = col.lower()
442
+ if any(keyword in col_lower for keyword in outcome_keywords):
443
+ potential_outcomes.append(col)
444
+
445
+ # Add numeric non-binary variables if we don't have enough candidates
446
+ if len(potential_outcomes) < 3:
447
+ for col in non_binary_numeric:
448
+ if col not in potential_outcomes and col not in potential_treatments:
449
+ potential_outcomes.append(col)
450
+ if len(potential_outcomes) >= 3:
451
+ break
452
+
453
+ return {
454
+ "potential_treatments": potential_treatments,
455
+ "potential_outcomes": potential_outcomes
456
+ }
457
+
458
+
459
+ def detect_temporal_structure(
460
+ df: pd.DataFrame,
461
+ llm_client: Optional[BaseChatModel] = None,
462
+ dataset_description: Optional[str] = None,
463
+ original_query: Optional[str] = None
464
+ ) -> Dict[str, Any]:
465
+ """
466
+ Detect temporal structure in the dataset, using LLM for enhanced identification.
467
+
468
+ Args:
469
+ df: DataFrame to analyze
470
+ llm_client: Optional LLM client for enhanced identification
471
+ dataset_description: Optional description of the dataset for context
472
+
473
+ Returns:
474
+ Dict with information about temporal structure:
475
+ - has_temporal_structure: Whether temporal structure exists
476
+ - temporal_columns: Primary time column identified (or list if multiple from heuristic)
477
+ - is_panel_data: Whether data is in panel format
478
+ - time_column: Primary time column identified for panel data
479
+ - id_column: Primary unit ID column identified for panel data
480
+ - time_periods: Number of time periods (if panel data)
481
+ - units: Number of unique units (if panel data)
482
+ - identification_method: How time/unit vars were identified ('LLM', 'Heuristic', 'None')
483
+ """
484
+ result = {
485
+ "has_temporal_structure": False,
486
+ "temporal_columns": [], # Will store primary time column or heuristic list
487
+ "is_panel_data": False,
488
+ "time_column": None,
489
+ "id_column": None,
490
+ "time_periods": None,
491
+ "units": None,
492
+ "identification_method": "None"
493
+ }
494
+
495
+ # --- Step 1: Heuristic identification (currently disabled; kept below for reference) ---
496
+ #heuristic_datetime_cols = []
497
+ #for col in df.columns:
498
+ # if pd.api.types.is_datetime64_any_dtype(df[col]):
499
+ # heuristic_datetime_cols.append(col)
500
+ # elif pd.api.types.is_string_dtype(df[col]):
501
+ # try:
502
+ # if pd.to_datetime(df[col], errors='coerce').notna().any():
503
+ # heuristic_datetime_cols.append(col)
504
+ # except:
505
+ # pass # Ignore conversion errors
506
+
507
+ #time_keywords = ['year', 'month', 'day', 'date', 'time', 'period', 'quarter', 'week']
508
+ #for col in df.columns:
509
+ # col_lower = col.lower()
510
+ # if any(keyword in col_lower for keyword in time_keywords) and col not in heuristic_datetime_cols:
511
+ # heuristic_datetime_cols.append(col)
512
+
513
+ #id_keywords = ['id', 'individual', 'person', 'unit', 'entity', 'firm', 'company', 'state', 'country']
514
+ #heuristic_potential_id_cols = []
515
+ #for col in df.columns:
516
+ # col_lower = col.lower()
517
+ # # Exclude columns already identified as time-related by heuristics
518
+ # if any(keyword in col_lower for keyword in id_keywords) and col not in heuristic_datetime_cols:
519
+ # heuristic_potential_id_cols.append(col)
520
+
521
+ # --- Step 2: LLM-assisted identification ---
522
+ llm_identified_time_var = None
523
+ llm_identified_unit_var = None
524
+ heuristic_datetime_cols = []
525
+ heuristic_potential_id_cols = []
526
+ dataset_summary = df.describe(include='all')
527
+
528
+ if llm_client:
529
+ logger.info("Attempting LLM-assisted identification of temporal/unit variables.")
530
+ column_names = df.columns.tolist()
531
+ column_dtypes_dict = {col: str(df[col].dtype) for col in column_names}
532
+
533
+ try:
534
+ llm_suggestions = llm_identify_temporal_and_unit_vars(
535
+ column_names=column_names,
536
+ column_dtypes=column_dtypes_dict,
537
+ dataset_description=dataset_description if dataset_description else "No dataset description provided.",
538
+ dataset_summary=dataset_summary,
539
+ heuristic_time_candidates=heuristic_datetime_cols,
540
+ heuristic_id_candidates=heuristic_potential_id_cols,
541
+ query=original_query if original_query else "No query provided.",
542
+ llm=llm_client
543
+ )
544
+ llm_identified_time_var = llm_suggestions.get("time_variable")
545
+ llm_identified_unit_var = llm_suggestions.get("unit_variable")
546
+ result["identification_method"] = "LLM"
547
+
548
+ if not llm_identified_time_var and not llm_identified_unit_var:
549
+ result["identification_method"] = "LLM_NoIdentification"
550
+ except Exception as e:
551
+ logger.warning(f"LLM call for temporal/unit vars failed: {e}. Falling back to heuristics.")
552
+ result["identification_method"] = "Heuristic_LLM_Error"
553
+ else:
554
+ result["identification_method"] = "Heuristic_NoLLM"
555
+
556
+ # --- Step 3: Combine LLM and Heuristic Results ---
557
+ final_time_var = None
558
+ final_unit_var = None
559
+
560
+ if llm_identified_time_var:
561
+ final_time_var = llm_identified_time_var
562
+ logger.info(f"Prioritizing LLM identified time variable: {final_time_var}")
563
+ elif heuristic_datetime_cols:
564
+ final_time_var = heuristic_datetime_cols[0] # Fallback to first heuristic time col
565
+ logger.info(f"Using heuristic time variable: {final_time_var}")
566
+
567
+ if llm_identified_unit_var:
568
+ final_unit_var = llm_identified_unit_var
569
+ logger.info(f"Prioritizing LLM identified unit variable: {final_unit_var}")
570
+ elif heuristic_potential_id_cols:
571
+ final_unit_var = heuristic_potential_id_cols[0] # Fallback to first heuristic ID col
572
+ logger.info(f"Using heuristic unit variable: {final_unit_var}")
573
+
574
+ # Update results based on final selections
575
+ if final_time_var:
576
+ result["has_temporal_structure"] = True
577
+ result["temporal_columns"] = [final_time_var] # Store as a list with the primary time var
578
+ result["time_column"] = final_time_var
579
+ else: # If no time var found by LLM or heuristic, use original heuristic list for temporal_columns
580
+ if heuristic_datetime_cols:
581
+ result["has_temporal_structure"] = True
582
+ result["temporal_columns"] = heuristic_datetime_cols
583
+ # time_column remains None
584
+
585
+ if final_unit_var:
586
+ result["id_column"] = final_unit_var
587
+
588
+ # --- Step 4: Update Panel Data Logic (based on final_time_var and final_unit_var) ---
589
+ if final_time_var and final_unit_var:
590
+ # Check if there are multiple time periods per unit using the identified variables
591
+ try:
592
+ # Ensure columns exist before groupby
593
+ if final_time_var in df.columns and final_unit_var in df.columns:
594
+ if df.groupby(final_unit_var)[final_time_var].nunique().mean() > 1.0:
595
+ result["is_panel_data"] = True
596
+ result["time_periods"] = df[final_time_var].nunique()
597
+ result["units"] = df[final_unit_var].nunique()
598
+ logger.info(f"Panel data detected: Time='{final_time_var}', Unit='{final_unit_var}', Periods={result['time_periods']}, Units={result['units']}")
599
+ else:
600
+ logger.info("Not panel data: Each unit does not have multiple time periods.")
601
+ else:
602
+ logger.warning(f"Final time ('{final_time_var}') or unit ('{final_unit_var}') var not in DataFrame. Cannot confirm panel structure.")
603
+ except Exception as e:
604
+ logger.error(f"Error checking panel data structure with time='{final_time_var}', unit='{final_unit_var}': {e}")
605
+ result["is_panel_data"] = False # Default to false on error
606
+ else:
607
+ logger.info("Not panel data: Missing either time or unit variable for panel structure.")
608
+
609
+ logger.debug(f"Final temporal structure detection result: {result}")
610
+ return result
611
+
612
+
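For reference, a sketch of the dictionary detect_temporal_structure returns for yearly state-level panel data (values are illustrative):

# {
#     "has_temporal_structure": True,
#     "temporal_columns": ["year"],
#     "is_panel_data": True,
#     "time_column": "year",
#     "id_column": "state",
#     "time_periods": 20,
#     "units": 50,
#     "identification_method": "LLM",
# }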
613
+ def find_potential_instruments(
614
+ df: pd.DataFrame,
615
+ llm_client: Optional[BaseChatModel] = None,
616
+ potential_treatments: List[str] = None,
617
+ potential_outcomes: List[str] = None,
618
+ dataset_description: Optional[str] = None
619
+ ) -> List[Dict[str, Any]]:
620
+ """
621
+ Find potential instrumental variables in the dataset, using LLM if available.
622
+ Falls back to heuristic method if LLM fails or is not available.
623
+
624
+ Args:
625
+ df: DataFrame to analyze
626
+ llm_client: Optional LLM client for enhanced identification
627
+ potential_treatments: Optional list of potential treatment variables
628
+ potential_outcomes: Optional list of potential outcome variables
629
+ dataset_description: Optional description of the dataset for context
630
+
631
+ Returns:
632
+ List of potential instrumental variables with their properties
633
+ """
634
+ # Try LLM approach if client is provided
635
+ if llm_client:
636
+ try:
637
+ logger.info("Using LLM to identify potential instrumental variables")
638
+
639
+ # Create a concise prompt with just column information
640
+ columns_list = df.columns.tolist()
641
+
642
+ # Exclude known treatment and outcome variables from consideration
643
+ excluded_columns = []
644
+ if potential_treatments:
645
+ excluded_columns.extend(potential_treatments)
646
+ if potential_outcomes:
647
+ excluded_columns.extend(potential_outcomes)
648
+
649
+ # Filter columns to exclude treatments and outcomes
650
+ candidate_columns = [col for col in columns_list if col not in excluded_columns]
651
+
652
+ if not candidate_columns:
653
+ logger.warning("No eligible columns for instrumental variables after filtering treatments and outcomes")
654
+ return []
655
+
656
+ # Get column types for context
657
+ column_types = {col: str(df[col].dtype) for col in candidate_columns}
658
+
659
+ # Add dataset description if available
660
+ description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
661
+
662
+ prompt = f"""
663
+ You are an expert causal inference data scientist. Identify potential instrumental variables from this dataset.{description_text}
664
+
665
+ DEFINITION: Instrumental variables must:
666
+ 1. Be correlated with the treatment variable (relevance)
667
+ 2. Only affect the outcome through the treatment (exclusion restriction)
668
+ 3. Not be correlated with unmeasured confounders (exogeneity)
669
+
670
+ Treatment variables: {potential_treatments if potential_treatments else "Unknown"}
671
+ Outcome variables: {potential_outcomes if potential_outcomes else "Unknown"}
672
+
673
+ Available columns (excluding treatments and outcomes):
674
+ {candidate_columns}
675
+
676
+ Column types:
677
+ {column_types}
678
+
679
+ Look for variables likely to be:
680
+ - Random assignments
681
+ - Policy changes
682
+ - Geographic or temporal variations
683
+ - Variables with names containing: 'instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous'
684
+
685
+ Return ONLY a JSON array of objects, each with "variable", "reason", and "data_type" fields.
686
+ Example:
687
+ [
688
+ {{"variable": "random_assignment", "reason": "Random assignment variable", "data_type": "int64"}},
689
+ {{"variable": "distance_to_facility", "reason": "Geographic variation", "data_type": "float64"}}
690
+ ]
691
+ """
692
+
693
+ # Call the LLM and parse the response
694
+ response = llm_client.invoke(prompt)
695
+ response_text = response.content if hasattr(response, 'content') else str(response)
696
+
697
+ # Extract JSON from the response text
698
+ import re
699
+ json_match = re.search(r'\[\s*{.*}\s*\]', response_text, re.DOTALL)
700
+
701
+ if json_match:
702
+ result = json.loads(json_match.group(0))
703
+
704
+ # Validate the response
705
+ if isinstance(result, list) and len(result) > 0:
706
+ # Filter for valid entries
707
+ valid_instruments = []
708
+ for item in result:
709
+ if not isinstance(item, dict) or "variable" not in item:
710
+ continue
711
+
712
+ if item["variable"] not in df.columns:
713
+ continue
714
+
715
+ # Ensure all required fields are present
716
+ if "reason" not in item:
717
+ item["reason"] = "Identified by LLM"
718
+ if "data_type" not in item:
719
+ item["data_type"] = str(df[item["variable"]].dtype)
720
+
721
+ valid_instruments.append(item)
722
+
723
+ if valid_instruments:
724
+ logger.info(f"LLM identified {len(valid_instruments)} potential instrumental variables {valid_instruments}")
725
+ return valid_instruments
726
+ else:
727
+ logger.warning("No valid instruments found by LLM, falling back to heuristic method")
728
+ else:
729
+ logger.warning("Invalid LLM response format, falling back to heuristic method")
730
+ else:
731
+ logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
732
+
733
+ except Exception as e:
734
+ logger.error(f"Error in LLM identification of instruments: {e}", exc_info=True)
735
+ logger.info("Falling back to heuristic method")
736
+
737
+ # Fallback to heuristic method
738
+ logger.info("Using heuristic method to identify potential instrumental variables")
739
+ potential_instruments = []
740
+
741
+ # Look for variables with instrumental-related names
742
+ instrument_keywords = ['instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous']
743
+
744
+ for col in df.columns:
745
+ # Skip treatment and outcome variables
746
+ if potential_treatments and col in potential_treatments:
747
+ continue
748
+ if potential_outcomes and col in potential_outcomes:
749
+ continue
750
+
751
+ col_lower = col.lower()
752
+ if any(keyword in col_lower for keyword in instrument_keywords):
753
+ instrument_info = {
754
+ "variable": col,
755
+ "reason": f"Name contains instrument-related keyword",
756
+ "data_type": str(df[col].dtype)
757
+ }
758
+ potential_instruments.append(instrument_info)
759
+
760
+ return potential_instruments
761
+
762
+
763
+ def detect_discontinuities(df: pd.DataFrame) -> Dict[str, Any]:
764
+ """
765
+ Identify discontinuities in continuous variables (for RDD).
766
+
767
+ Args:
768
+ df: DataFrame to analyze
769
+
770
+ Returns:
771
+ Dict with information about detected discontinuities
772
+ """
773
+ discontinuities = []
774
+
775
+ # For each numeric column, check for potential discontinuities
776
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
777
+
778
+ for col in numeric_cols:
779
+ # Skip columns with too many unique values
780
+ if df[col].nunique() > 100:
781
+ continue
782
+
783
+ values = df[col].dropna().sort_values().values
784
+
785
+ # Calculate gaps between consecutive values
786
+ if len(values) > 10:
787
+ gaps = np.diff(values)
788
+ mean_gap = np.mean(gaps)
789
+ std_gap = np.std(gaps)
790
+
791
+ # Look for unusually large gaps (potential discontinuities)
792
+ large_gaps = np.where(gaps > mean_gap + 2*std_gap)[0]
793
+
794
+ if len(large_gaps) > 0:
795
+ for idx in large_gaps:
796
+ cutpoint = (values[idx] + values[idx+1]) / 2
797
+ discontinuities.append({
798
+ "variable": col,
799
+ "cutpoint": float(cutpoint),
800
+ "gap_size": float(gaps[idx]),
801
+ "mean_gap": float(mean_gap)
802
+ })
803
+
804
+ return {
805
+ "has_discontinuities": len(discontinuities) > 0,
806
+ "discontinuities": discontinuities
807
+ }
808
+
809
+
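A small sketch of what the gap heuristic in detect_discontinuities flags on synthetic data (illustrative):

import numpy as np
import pandas as pd

toy = pd.DataFrame({"score": np.concatenate([np.linspace(0, 10, 50), np.linspace(20, 30, 50)])})
result = detect_discontinuities(toy)
# result["has_discontinuities"] -> True, with one entry for "score" and a cutpoint near 15.0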
810
+ def assess_variable_relationships(df: pd.DataFrame, corr_matrix: pd.DataFrame) -> Dict[str, Any]:
811
+ """
812
+ Assess relationships between variables in the dataset.
813
+
814
+ Args:
815
+ df: DataFrame to analyze
816
+ corr_matrix: Precomputed correlation matrix for numeric columns
817
+
818
+ Returns:
819
+ Dict with information about variable relationships:
820
+ - strongly_correlated_pairs: Pairs of strongly correlated variables
821
+ - potential_confounders: Variables that might be confounders
822
+ """
823
+ result = {"strongly_correlated_pairs": [], "potential_confounders": []}
824
+
825
+ numeric_cols = corr_matrix.columns.tolist()
826
+ if len(numeric_cols) < 2:
827
+ return result
828
+
829
+ # Use the precomputed correlation matrix
830
+ corr_matrix_abs = corr_matrix.abs()
831
+
832
+ # Find strongly correlated variable pairs
833
+ for i in range(len(numeric_cols)):
834
+ for j in range(i+1, len(numeric_cols)):
835
+ if abs(corr_matrix_abs.iloc[i, j]) > 0.7: # Correlation threshold
836
+ result["strongly_correlated_pairs"].append({
837
+ "variables": [numeric_cols[i], numeric_cols[j]],
838
+ "correlation": float(corr_matrix.iloc[i, j])
839
+ })
840
+
841
+ # Identify potential confounders (variables correlated with multiple others)
842
+ confounder_counts = {col: 0 for col in numeric_cols}
843
+
844
+ for pair in result["strongly_correlated_pairs"]:
845
+ confounder_counts[pair["variables"][0]] += 1
846
+ confounder_counts[pair["variables"][1]] += 1
847
+
848
+ # Variables correlated with multiple others are potential confounders
849
+ for col, count in confounder_counts.items():
850
+ if count >= 2:
851
+ result["potential_confounders"].append({"variable": col, "num_correlations": count})
852
+
853
+ return result
auto_causal/components/decision_tree.py ADDED
@@ -0,0 +1,366 @@
1
+ """
2
+ Decision tree component for selecting causal inference methods.
3
+
4
+ This module implements the decision tree logic to select the most appropriate
5
+ causal inference method based on dataset characteristics and available variables.
6
+ """
7
+
8
+ import logging
9
+ from typing import Dict, List, Any, Optional
10
+ import pandas as pd
11
+
12
+ # define method names
13
+ BACKDOOR_ADJUSTMENT = "backdoor_adjustment"
14
+ LINEAR_REGRESSION = "linear_regression"
15
+ DIFF_IN_MEANS = "diff_in_means"
16
+ DIFF_IN_DIFF = "difference_in_differences"
17
+ REGRESSION_DISCONTINUITY = "regression_discontinuity_design"
18
+ PROPENSITY_SCORE_MATCHING = "propensity_score_matching"
19
+ INSTRUMENTAL_VARIABLE = "instrumental_variable"
20
+ CORRELATION_ANALYSIS = "correlation_analysis"
21
+ PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting"
22
+ GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score"
23
+ FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment"
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # method assumptions mapping
29
+ METHOD_ASSUMPTIONS = {
30
+ BACKDOOR_ADJUSTMENT: [
31
+ "no unmeasured confounders (conditional ignorability given covariates)",
32
+ "correct model specification for outcome conditional on treatment and covariates",
33
+ "positivity/overlap (for all covariate values, units could potentially receive either treatment level)"
34
+ ],
35
+ LINEAR_REGRESSION: [
36
+ "linear relationship between treatment, covariates, and outcome",
37
+ "no unmeasured confounders (if observational)",
38
+ "correct model specification",
39
+ "homoscedasticity of errors",
40
+ "normally distributed errors (for inference)"
41
+ ],
42
+ DIFF_IN_MEANS: [
43
+ "treatment is randomly assigned (or as-if random)",
44
+ "no spillover effects",
45
+ "stable unit treatment value assumption (SUTVA)"
46
+ ],
47
+ DIFF_IN_DIFF: [
48
+ "parallel trends between treatment and control groups before treatment",
49
+ "no spillover effects between groups",
50
+ "no anticipation effects before treatment",
51
+ "stable composition of treatment and control groups",
52
+ "treatment timing is exogenous"
53
+ ],
54
+ REGRESSION_DISCONTINUITY: [
55
+ "units cannot precisely manipulate the running variable around the cutoff",
56
+ "continuity of conditional expectation functions of potential outcomes at the cutoff",
57
+ "no other changes occurring precisely at the cutoff"
58
+ ],
59
+ PROPENSITY_SCORE_MATCHING: [
60
+ "no unmeasured confounders (conditional ignorability)",
61
+ "sufficient overlap (common support) between treatment and control groups",
62
+ "correct propensity score model specification"
63
+ ],
64
+ INSTRUMENTAL_VARIABLE: [
65
+ "instrument is correlated with treatment (relevance)",
66
+ "instrument affects outcome only through treatment (exclusion restriction)",
67
+ "instrument is independent of unmeasured confounders (exogeneity/independence)"
68
+ ],
69
+ CORRELATION_ANALYSIS: [
70
+ "data represents a sample from the population of interest",
71
+ "variables are measured appropriately"
72
+ ],
73
+ PROPENSITY_SCORE_WEIGHTING: [
74
+ "no unmeasured confounders (conditional ignorability)",
75
+ "sufficient overlap (common support) between treatment and control groups",
76
+ "correct propensity score model specification",
77
+ "weights correctly specified (e.g., ATE, ATT)"
78
+ ],
79
+ GENERALIZED_PROPENSITY_SCORE: [
80
+ "conditional mean independence",
81
+ "positivity/common support for GPS",
82
+ "correct specification of the GPS model",
83
+ "correct specification of the outcome model",
84
+ "no unmeasured confounders affecting both treatment and outcome, given X",
85
+ "treatment variable is continuous"
86
+ ],
87
+ FRONTDOOR_ADJUSTMENT: [
88
+ "mediator is affected by treatment and affects outcome",
89
+ "mediator is not affected by any confounders of the treatment-outcome relationship"
90
+ ]
91
+ }
92
+
93
+
94
+ def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
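+ """
+ Select a causal inference method from dataset properties.
+
+ Requires 'treatment_variable' and 'outcome_variable' in dataset_properties; optional keys
+ such as 'instrument_variable', 'running_variable'/'cutoff_value', 'time_variable', 'is_rct',
+ 'has_temporal_structure', 'frontdoor_criterion', 'covariates', 'covariate_overlap_score'
+ and 'treatment_variable_type' refine the choice. Methods listed in excluded_methods are
+ filtered out. Returns a dict with the selected method, its justification and assumptions,
+ ranked alternatives, and the applied exclusions.
+ """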
95
+ excluded_methods = set(excluded_methods or [])
96
+ logger.info(f"Excluded methods: {sorted(excluded_methods)}")
97
+
98
+ treatment = dataset_properties.get("treatment_variable")
99
+ outcome = dataset_properties.get("outcome_variable")
100
+ if not treatment or not outcome:
101
+ raise ValueError("Both treatment and outcome variables must be specified")
102
+
103
+ instrument_var = dataset_properties.get("instrument_variable")
104
+ running_var = dataset_properties.get("running_variable")
105
+ cutoff_val = dataset_properties.get("cutoff_value")
106
+ time_var = dataset_properties.get("time_variable")
107
+ is_rct = dataset_properties.get("is_rct", False)
108
+ has_temporal = dataset_properties.get("has_temporal_structure", False)
109
+ frontdoor = dataset_properties.get("frontdoor_criterion", False)
110
+ covariate_overlap_result = dataset_properties.get("covariate_overlap_score")
111
+ covariates = dataset_properties.get("covariates", [])
112
+ treatment_variable_type = dataset_properties.get("treatment_variable_type", "binary")
113
+
114
+ # Helpers to collect candidates
115
+ candidates = [] # list of (method, priority_index)
116
+ justifications: Dict[str, str] = {}
117
+ assumptions: Dict[str, List[str]] = {}
118
+
119
+ def add(method: str, justification: str, prio_order: List[str]):
120
+ if method in justifications: # already added
121
+ return
122
+ justifications[method] = justification
123
+ assumptions[method] = METHOD_ASSUMPTIONS[method]
124
+ # priority index from provided order (fallback large if not present)
125
+ try:
126
+ idx = prio_order.index(method)
127
+ except ValueError:
128
+ idx = 10**6
129
+ candidates.append((method, idx))
130
+
131
+ # ----- Build candidate set (no returns here) -----
132
+
133
+ # RCT branch
134
+ if is_rct:
135
+ logger.info("Dataset is from a randomized controlled trial (RCT)")
136
+ rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS]
137
+
138
+ if instrument_var and instrument_var != treatment:
139
+ add(INSTRUMENTAL_VARIABLE,
140
+ f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.",
141
+ rct_priority)
142
+
143
+ if covariates:
144
+ add(LINEAR_REGRESSION,
145
+ "RCT with covariates—use OLS for precision.",
146
+ rct_priority)
147
+ else:
148
+ add(DIFF_IN_MEANS,
149
+ "Pure RCT without covariates—difference-in-means.",
150
+ rct_priority)
151
+
152
+ # Observational branch
153
+ obs_priority_binary = [
154
+ INSTRUMENTAL_VARIABLE,
155
+ PROPENSITY_SCORE_MATCHING,
156
+ PROPENSITY_SCORE_WEIGHTING,
157
+ FRONTDOOR_ADJUSTMENT,
158
+ LINEAR_REGRESSION,
159
+ ]
160
+ obs_priority_nonbinary = [
161
+ INSTRUMENTAL_VARIABLE,
162
+ FRONTDOOR_ADJUSTMENT,
163
+ LINEAR_REGRESSION,
164
+ ]
165
+
166
+ # Common early structural signals first (still only add as candidates)
167
+ if has_temporal and time_var:
168
+ add(DIFF_IN_DIFF,
169
+ f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).",
170
+ [DIFF_IN_DIFF]) # highest among itself
171
+
172
+ if running_var and cutoff_val is not None:
173
+ add(REGRESSION_DISCONTINUITY,
174
+ f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.",
175
+ [REGRESSION_DISCONTINUITY])
176
+
177
+ # Binary vs non-binary pathways
178
+ if treatment_variable_type == "binary":
179
+ if instrument_var:
180
+ add(INSTRUMENTAL_VARIABLE,
181
+ f"Instrumental variable '{instrument_var}' available.",
182
+ obs_priority_binary)
183
+
184
+ # Propensity score methods only if covariates exist
185
+ if covariates:
186
+ if covariate_overlap_result is not None:
187
+ ps_method = (PROPENSITY_SCORE_WEIGHTING
188
+ if covariate_overlap_result < 0.1
189
+ else PROPENSITY_SCORE_MATCHING)
190
+ else:
191
+ ps_method = PROPENSITY_SCORE_MATCHING
192
+ add(ps_method,
193
+ "Covariates observed; PS method chosen based on overlap.",
194
+ obs_priority_binary)
195
+
196
+ if frontdoor:
197
+ add(FRONTDOOR_ADJUSTMENT,
198
+ "Front-door criterion satisfied.",
199
+ obs_priority_binary)
200
+
201
+ add(LINEAR_REGRESSION,
202
+ "OLS as a fallback specification.",
203
+ obs_priority_binary)
204
+
205
+ else:
206
+ logger.info(f"Non-binary treatment variable detected: {treatment_variable_type}")
207
+ if instrument_var:
208
+ add(INSTRUMENTAL_VARIABLE,
209
+ f"Instrument '{instrument_var}' candidate for non-binary treatment.",
210
+ obs_priority_nonbinary)
211
+ if frontdoor:
212
+ add(FRONTDOOR_ADJUSTMENT,
213
+ "Front-door criterion satisfied.",
214
+ obs_priority_nonbinary)
215
+ add(LINEAR_REGRESSION,
216
+ "Fallback for non-binary treatment without stronger identification.",
217
+ obs_priority_nonbinary)
218
+
219
+ # ----- Centralized exclusion handling -----
220
+ # Remove excluded
221
+ filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods]
222
+
223
+ # If nothing survives, attempt a safe fallback not excluded
224
+ if not filtered:
225
+ logger.warning(f"All candidates excluded. Candidates were: {[m for m,_ in candidates]}. Excluded: {sorted(excluded_methods)}")
226
+ fallback_order = [
227
+ LINEAR_REGRESSION,
228
+ DIFF_IN_MEANS,
229
+ PROPENSITY_SCORE_MATCHING,
230
+ PROPENSITY_SCORE_WEIGHTING,
231
+ DIFF_IN_DIFF,
232
+ REGRESSION_DISCONTINUITY,
233
+ INSTRUMENTAL_VARIABLE,
234
+ FRONTDOOR_ADJUSTMENT,
235
+ ]
236
+ fallback = next((m for m in fallback_order if m in justifications and m not in excluded_methods), None)
237
+ if not fallback:
238
+ # truly nothing left; raise with context
239
+ raise RuntimeError("No viable method remains after exclusions.")
240
+ selected_method = fallback
241
+ alternatives = []
242
+ justifications[selected_method] = justifications.get(selected_method, "Fallback after exclusions.")
243
+ else:
244
+ # Pick by smallest priority index, then stable by insertion
245
+ filtered.sort(key=lambda x: x[1])
246
+ selected_method = filtered[0][0]
247
+ alternatives = [m for (m, _) in filtered[1:] if m != selected_method]
248
+
249
+ logger.info(f"Selected method: {selected_method}; alternatives: {alternatives}")
250
+
251
+ return {
252
+ "selected_method": selected_method,
253
+ "method_justification": justifications[selected_method],
254
+ "method_assumptions": assumptions[selected_method],
255
+ "alternatives": alternatives,
256
+ "excluded_methods": sorted(excluded_methods),
257
+ }
258
+
259
+
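An example invocation of select_method for observational data with a binary treatment and covariates (a sketch; the variable names are placeholders):

decision = select_method({
    "treatment_variable": "treated",
    "outcome_variable": "earnings",
    "covariates": ["age", "education"],
    "treatment_variable_type": "binary",
    "is_rct": False,
    "covariate_overlap_score": 0.4,
})
# decision["selected_method"] -> "propensity_score_matching"
# decision["alternatives"]    -> ["linear_regression"]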
260
+
261
+ def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None):
262
+ """
263
+ Wrapper function to select a causal method based on dataset properties and the query.
264
+
265
+ Args:
266
+ dataset_analysis (Dict): results of dataset analysis
267
+ variables (Dict): dictionary of variable names and types
268
+ is_rct (bool): whether the dataset is from a randomized controlled trial
269
+ llm (BaseChatModel): language model instance for generating prompts
270
+ dataset_description (str): description of the dataset
271
+ original_query (str): the original user query
272
+ excluded_methods (List[str], optional): list of methods to exclude from selection
273
+ """
274
+
275
+ logger.info("Running rule-based method selection")
276
+
277
+
278
+ properties = {"treatment_variable": variables.get("treatment_variable"), "instrument_variable":variables.get("instrument_variable"),
279
+ "covariates": variables.get("covariates", []), "outcome_variable": variables.get("outcome_variable"),
280
+ "time_variable": variables.get("time_variable"), "running_variable": variables.get("running_variable"),
281
+ "treatment_variable_type": variables.get("treatment_variable_type", "binary"),
282
+ "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
283
+ "frontdoor_criterion": variables.get("frontdoor_criterion", False),
284
+ "cutoff_value": variables.get("cutoff_value"),
285
+ "covariate_overlap_score": variables.get("covariate_overlap_result", 0)}
286
+
287
+ properties["is_rct"] = is_rct
288
+ logger.info(f"Dataset properties for method selection: {properties}")
289
+
290
+ return select_method(properties, excluded_methods)
291
+
292
+
293
+
294
+ class DecisionTreeEngine:
295
+ """
296
+ Engine for applying decision trees to select appropriate causal methods.
297
+
298
+ This class wraps the functional decision tree implementation to provide
299
+ an object-oriented interface for method selection.
300
+ """
301
+
302
+ def __init__(self, verbose=False):
303
+ self.verbose = verbose
304
+
305
+ def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str],
306
+ dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]:
307
+ """
308
+ Apply decision tree to select appropriate causal method.
309
+ """
310
+
311
+ if self.verbose:
312
+ print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}")
313
+ print(f"Available covariates: {covariates}")
314
+
315
+ treatment_variable_type = query_details.get("treatment_variable_type")
316
+ covariate_overlap_result = query_details.get("covariate_overlap_result")
317
+ info = {"treatment_variable": treatment, "outcome_variable": outcome,
318
+ "covariates": covariates, "time_variable": query_details.get("time_variable"),
319
+ "group_variable": query_details.get("group_variable"),
320
+ "instrument_variable": query_details.get("instrument_variable"),
321
+ "running_variable": query_details.get("running_variable"),
322
+ "cutoff_value": query_details.get("cutoff_value"),
323
+ "is_rct": query_details.get("is_rct", False),
324
+ "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
325
+ "frontdoor_criterion": query_details.get("frontdoor_criterion", False),
326
+ "covariate_overlap_score": covariate_overlap_result,
327
+ "treatment_variable_type": treatment_variable_type}
328
+
329
+ result = select_method(info)
330
+
331
+ if self.verbose:
332
+ print(f"Selected method: {result['selected_method']}")
333
+ print(f"Justification: {result['method_justification']}")
334
+
335
+ result["decision_path"] = self._get_decision_path(result["selected_method"])
336
+ return result
337
+
338
+
339
+ def _get_decision_path(self, method):
340
+ if method == "linear_regression":
341
+ return ["Check if randomized experiment", "Data appears to be from a randomized experiment with covariates"]
342
+ elif method == "propensity_score_matching":
343
+ return ["Check if randomized experiment", "Data is observational",
344
+ "Check for sufficient covariate overlap", "Sufficient overlap exists"]
345
+ elif method == "propensity_score_weighting":
346
+ return ["Check if randomized experiment", "Data is observational",
347
+ "Check for sufficient covariate overlap", "Low overlap—weighting preferred"]
348
+ elif method == "backdoor_adjustment":
349
+ return ["Check if randomized experiment", "Data is observational",
350
+ "Check for sufficient covariate overlap", "Adjusting for covariates"]
351
+ elif method == "instrumental_variable":
352
+ return ["Check if randomized experiment", "Data is observational",
353
+ "Check for instrumental variables", "Instrument is available"]
354
+ elif method == "regression_discontinuity_design":
355
+ return ["Check if randomized experiment", "Data is observational",
356
+ "Check for discontinuity", "Discontinuity exists"]
357
+ elif method == "difference_in_differences":
358
+ return ["Check if randomized experiment", "Data is observational",
359
+ "Check for temporal structure", "Panel data structure exists"]
360
+ elif method == "frontdoor_adjustment":
361
+ return ["Check if randomized experiment", "Data is observational",
362
+ "Check front-door criterion", "Front-door path identified"]
363
+ elif method == "diff_in_means":
364
+ return ["Check if randomized experiment", "Pure RCT without covariates"]
365
+ else:
366
+ return ["Default method selection"]
auto_causal/components/decision_tree_llm.py ADDED
@@ -0,0 +1,218 @@
1
+ """
2
+ LLM-based Decision tree component for selecting causal inference methods.
3
+
4
+ This module implements the decision tree logic via an LLM prompt
5
+ to select the most appropriate causal inference method based on
6
+ dataset characteristics and available variables.
7
+ """
8
+
9
+ import logging
10
+ import json
11
+ from typing import Dict, Any, Optional, List
12
+
13
+ from langchain_core.messages import HumanMessage
14
+ from langchain_core.language_models import BaseChatModel
15
+
16
+ # Import constants and assumptions from the original decision_tree module
17
+ from .decision_tree import (
18
+ METHOD_ASSUMPTIONS,
19
+ BACKDOOR_ADJUSTMENT,
20
+ LINEAR_REGRESSION,
21
+ DIFF_IN_MEANS,
22
+ DIFF_IN_DIFF,
23
+ REGRESSION_DISCONTINUITY,
24
+ PROPENSITY_SCORE_MATCHING,
25
+ INSTRUMENTAL_VARIABLE,
26
+ CORRELATION_ANALYSIS,
27
+ PROPENSITY_SCORE_WEIGHTING,
28
+ GENERALIZED_PROPENSITY_SCORE
29
+ )
30
+
31
+ # Configure logging
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Define a list of all known methods for the LLM prompt
35
+ ALL_METHODS = [
36
+ DIFF_IN_MEANS,
37
+ LINEAR_REGRESSION,
38
+ DIFF_IN_DIFF,
39
+ REGRESSION_DISCONTINUITY,
40
+ INSTRUMENTAL_VARIABLE,
41
+ PROPENSITY_SCORE_MATCHING,
42
+ PROPENSITY_SCORE_WEIGHTING,
43
+ GENERALIZED_PROPENSITY_SCORE,
44
+ BACKDOOR_ADJUSTMENT, # Often a general approach rather than a specific model.
45
+ CORRELATION_ANALYSIS,
46
+ ]
47
+
48
+ METHOD_DESCRIPTIONS_FOR_LLM = {
49
+ DIFF_IN_MEANS: "Appropriate for Randomized Controlled Trials (RCTs) with no covariates. Compares the average outcome between treated and control groups.",
50
+ LINEAR_REGRESSION: "Can be used for RCTs with covariates to increase precision, or for observational data assuming linear relationships and no unmeasured confounders. Models the outcome as a linear function of treatment and covariates.",
51
+ DIFF_IN_DIFF: "Suitable for observational data with a temporal structure (e.g., panel data with pre/post treatment periods). Requires the 'parallel trends' assumption: treatment and control groups would have followed similar trends in the outcome in the absence of treatment.",
52
+ REGRESSION_DISCONTINUITY: "Applicable when treatment assignment is determined by whether an observed 'running variable' crosses a specific cutoff point. Assumes individuals cannot precisely manipulate the running variable.",
53
+ INSTRUMENTAL_VARIABLE: "Used when there's an 'instrument' variable that is correlated with the treatment, affects the outcome only through the treatment, and is not confounded with the outcome. Useful for handling unobserved confounding.",
54
+ PROPENSITY_SCORE_MATCHING: "For observational data with covariates. Estimates the probability of receiving treatment (propensity score) for each unit and then matches treated and control units with similar scores. Aims to create balanced groups.",
55
+ PROPENSITY_SCORE_WEIGHTING: "Similar to PSM, for observational data with covariates. Uses propensity scores to weight units to create a pseudo-population where confounders are balanced. Can estimate ATE, ATT, or ATC.",
56
+ GENERALIZED_PROPENSITY_SCORE: "An extension of propensity scores for continuous treatment variables. Aims to estimate the dose-response function, assuming unconfoundedness given covariates.",
57
+ BACKDOOR_ADJUSTMENT: "A general strategy for causal inference in observational studies that involves statistically controlling for all common causes (confounders) of the treatment and outcome. Specific methods like regression or matching implement this.",
58
+ CORRELATION_ANALYSIS: "A fallback method when causal inference is not feasible due to data limitations (e.g., no clear design, no covariates for adjustment). Measures the statistical association between variables, but does not imply causation."
59
+ }
60
+
61
+
62
+ class DecisionTreeLLMEngine:
63
+ """
64
+ Engine for applying an LLM-based decision tree to select appropriate causal methods.
65
+ """
66
+
67
+ def __init__(self, verbose: bool = False):
68
+ """
69
+ Initialize the LLM decision tree engine.
70
+
71
+ Args:
72
+ verbose: Whether to print verbose information.
73
+ """
74
+ self.verbose = verbose
75
+
76
+ def _construct_prompt(self, dataset_analysis: Dict[str, Any], variables: Dict[str, Any], is_rct: bool, excluded_methods: Optional[List[str]] = None) -> str:
77
+ """
78
+ Constructs the detailed prompt for the LLM.
79
+ """
80
+ # Filter out excluded methods
81
+ excluded_methods = excluded_methods or []
82
+ available_methods = [method for method in ALL_METHODS if method not in excluded_methods]
83
+ methods_list_str = "\n".join([f"- {method}: {METHOD_DESCRIPTIONS_FOR_LLM[method]}" for method in available_methods if method in METHOD_DESCRIPTIONS_FOR_LLM])
84
+
85
+ excluded_info = ""
86
+ if excluded_methods:
87
+ excluded_info = f"\nEXCLUDED METHODS (do not select these): {', '.join(excluded_methods)}\nReason: These methods failed validation in previous attempts.\n"
88
+
89
+ prompt = f"""You are an expert in causal inference. Your task is to select the most appropriate causal inference method based on the provided dataset analysis and variable information.
90
+
91
+ Dataset Analysis:
92
+ {json.dumps(dataset_analysis, indent=2)}
93
+
94
+ Identified Variables:
95
+ {json.dumps(variables, indent=2)}
96
+
97
+ Is the data from a Randomized Controlled Trial (RCT)? {'Yes' if is_rct else 'No'}{excluded_info}
98
+
99
+ Available Causal Inference Methods and their descriptions:
100
+ {methods_list_str}
101
+
102
+ Instructions:
103
+ 1. Carefully review all the provided information: dataset analysis, variables, and RCT status.
104
+ 2. Reason step-by-step to determine the most suitable method. Consider the hierarchy of methods (e.g., specific designs like DiD, RDD, IV before general adjustment methods).
105
+ 3. Explain your reasoning for selecting a particular method.
106
+ 4. Identify any potential alternative methods if applicable.
107
+ 5. State the key assumptions for your *selected* method by referring to the general list of assumptions for all methods that will be provided to you separately (you don't need to list them here, just be aware that you need to select a method for which assumptions are known).
108
+
109
+ Output your final decision as a JSON object with the following exact keys:
110
+ - "selected_method": string (must be one of {', '.join(available_methods)})
111
+ - "method_justification": string (your detailed reasoning)
112
+ - "alternative_methods": list of strings (alternative method names, can be empty)
113
+
114
+ Example JSON output format:
115
+ {{
116
+ "selected_method": "difference_in_differences",
117
+ "method_justification": "The dataset has a clear time variable and group variable, indicating a panel structure suitable for DiD. The parallel trends assumption will need to be checked.",
118
+ "alternative_methods": ["instrumental_variable"]
119
+ }}
120
+
121
+ Please provide only the JSON object in your response.
122
+ """
123
+ return prompt
124
+
125
+ def select_method_llm(self, dataset_analysis: Dict[str, Any], variables: Dict[str, Any], is_rct: bool = False, llm: Optional[BaseChatModel] = None, excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
126
+ """
127
+ Apply LLM-based decision tree to select appropriate causal method.
128
+
129
+ Args:
130
+ dataset_analysis: Dataset analysis results.
131
+ variables: Identified variables from query_interpreter.
132
+ is_rct: Boolean indicating if the data comes from an RCT.
133
+ llm: Langchain BaseChatModel instance for making the call.
134
+ excluded_methods: Optional list of method names to exclude from selection.
135
+
136
+ Returns:
137
+ Dict with selected method, justification, and assumptions.
138
+ Example:
139
+ {
140
+ "selected_method": "difference_in_differences",
141
+ "method_justification": "Reasoning...",
142
+ "method_assumptions": ["Assumption 1", ...],
143
+ "alternative_methods": ["instrumental_variable"]
144
+ }
145
+ """
146
+ if not llm:
147
+ logger.error("LLM client not provided to DecisionTreeLLMEngine. Cannot select method.")
148
+ return {
149
+ "selected_method": CORRELATION_ANALYSIS,
150
+ "method_justification": "LLM client not provided. Defaulting to Correlation Analysis as causal inference method selection is not possible. This indicates association, not causation.",
151
+ "method_assumptions": METHOD_ASSUMPTIONS.get(CORRELATION_ANALYSIS, []),
152
+ "alternative_methods": []
153
+ }
154
+
155
+ prompt = self._construct_prompt(dataset_analysis, variables, is_rct, excluded_methods)
156
+ if self.verbose:
157
+ logger.info("LLM Prompt for method selection:")
158
+ logger.info(prompt)
159
+
160
+ messages = [HumanMessage(content=prompt)]
161
+
162
+ llm_output_str = "" # Initialize llm_output_str here
163
+ try:
164
+ response = llm.invoke(messages)
165
+ llm_output_str = response.content.strip()
166
+
167
+ if self.verbose:
168
+ logger.info(f"LLM Raw Output: {llm_output_str}")
169
+
170
+ # Attempt to parse the JSON output
171
+ # The LLM might sometimes include explanations outside the JSON block.
172
+ # Try to extract JSON from within ```json ... ``` if present.
173
+ if "```json" in llm_output_str:
174
+ json_str = llm_output_str.split("```json")[1].split("```")[0].strip()
175
+ elif "```" in llm_output_str and not llm_output_str.startswith("{"):  # fenced output without a json tag, e.g. ```{...}```
176
+ json_str = llm_output_str.split("```")[1].strip()
177
+ else: # Assume the entire string is the JSON if no triple backticks
178
+ json_str = llm_output_str
179
+
180
+ parsed_response = json.loads(json_str)
181
+
182
+ selected_method = parsed_response.get("selected_method")
183
+ justification = parsed_response.get("method_justification", "No justification provided by LLM.")
184
+ alternatives = parsed_response.get("alternative_methods", [])
185
+
186
+ if selected_method and selected_method in METHOD_ASSUMPTIONS:
187
+ logger.info(f"LLM selected method: {selected_method}")
188
+ return {
189
+ "selected_method": selected_method,
190
+ "method_justification": justification,
191
+ "method_assumptions": METHOD_ASSUMPTIONS[selected_method],
192
+ "alternative_methods": alternatives
193
+ }
194
+ else:
195
+ logger.warning(f"LLM selected an invalid or unknown method '{selected_method}' (not found in METHOD_ASSUMPTIONS). Raw response: {llm_output_str}")
196
+ fallback_justification = f"LLM output was problematic (selected: {selected_method}). Defaulting to Correlation Analysis. LLM Raw Response: {llm_output_str}"
197
+ selected_method = CORRELATION_ANALYSIS
198
+ justification = fallback_justification
199
+
200
+ except json.JSONDecodeError as e:
201
+ logger.error(f"Failed to parse JSON response from LLM: {e}. Raw response: {llm_output_str}", exc_info=True)
202
+ fallback_justification = f"LLM response was not valid JSON. Defaulting to Correlation Analysis. Error: {e}. LLM Raw Response: {llm_output_str}"
203
+ selected_method = CORRELATION_ANALYSIS
204
+ justification = fallback_justification
205
+ alternatives = []
206
+ except Exception as e:
207
+ logger.error(f"Error during LLM call for method selection: {e}. Raw response: {llm_output_str}", exc_info=True)
208
+ fallback_justification = f"An unexpected error occurred during LLM method selection. Defaulting to Correlation Analysis. Error: {e}. LLM Raw Response: {llm_output_str}"
209
+ selected_method = CORRELATION_ANALYSIS
210
+ justification = fallback_justification
211
+ alternatives = []
212
+
213
+ return {
214
+ "selected_method": selected_method,
215
+ "method_justification": justification,
216
+ "method_assumptions": METHOD_ASSUMPTIONS.get(selected_method, []),
217
+ "alternative_methods": alternatives
218
+ }
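+
+ # Illustrative usage sketch (the dictionary contents and the ChatOpenAI model
+ # below are hypothetical placeholders; any LangChain chat model works):
+ #
+ # from langchain_openai import ChatOpenAI
+ #
+ # engine = DecisionTreeLLMEngine(verbose=True)
+ # decision = engine.select_method_llm(
+ #     dataset_analysis={"temporal_structure": {"has_temporal_structure": True}},
+ #     variables={"treatment_variable": "policy", "outcome_variable": "sales"},
+ #     is_rct=False,
+ #     llm=ChatOpenAI(model="gpt-4o-mini", temperature=0),
+ # )
+ # decision["selected_method"], decision["method_justification"]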
auto_causal/components/explanation_generator.py ADDED
@@ -0,0 +1,404 @@
1
+ """
2
+ Explanation generator component for causal inference methods.
3
+
4
+ This module generates explanations for causal inference methods, including
5
+ what the method does, its assumptions, and how it will be applied to the dataset.
6
+ """
7
+
8
+ from typing import Dict, Any, List, Optional
9
+ from langchain_core.language_models import BaseChatModel # For LLM type hint
10
+
11
+
12
+ def generate_explanation(
13
+ method_info: Dict[str, Any],
14
+ validation_result: Dict[str, Any],
15
+ variables: Dict[str, Any],
16
+ results: Dict[str, Any],
17
+ dataset_analysis: Optional[Dict[str, Any]] = None,
18
+ dataset_description: Optional[str] = None,
19
+ llm: Optional[BaseChatModel] = None
20
+ ) -> Dict[str, str]:
21
+ """
22
+ Generates a comprehensive explanation text for the causal analysis.
23
+
24
+ Args:
25
+ method_info: Dictionary containing selected method details.
26
+ validation_result: Dictionary containing method validation results.
27
+ variables: Dictionary containing identified variables.
28
+ results: Dictionary containing numerical results from the method execution.
29
+ dataset_analysis: Optional dictionary with dataset analysis details.
30
+ dataset_description: Optional string describing the dataset.
31
+ llm: Optional language model instance (for potential future use in generation).
32
+
33
+ Returns:
34
+ Dictionary containing the final explanation text.
35
+ """
36
+ method = method_info.get("method_name")
37
+
38
+ # Handle potential None for validation_result
39
+ if validation_result and validation_result.get("valid") is False:
40
+ method = validation_result.get("recommended_method", method)
41
+
42
+ # Get components
43
+ method_explanation = get_method_explanation(method)
44
+ assumption_explanations = explain_assumptions(method_info.get("assumptions", []))
45
+ application_explanation = explain_application(method, variables.get("treatment_variable"),
46
+ variables.get("outcome_variable"),
47
+ variables.get("covariates", []), variables)
48
+ limitations_explanation = explain_limitations(method, validation_result.get("concerns", []) if validation_result else [])
49
+ interpretation_guide = generate_interpretation_guide(method, variables.get("treatment_variable"),
50
+ variables.get("outcome_variable"))
51
+
52
+ # --- Extract Numerical Results ---
53
+ effect_estimate = results.get("effect_estimate")
54
+ effect_se = results.get("effect_se")
55
+ ci = results.get("confidence_interval")
56
+ p_value = results.get("p_value") # Assuming method executor returns p_value
57
+
58
+ # --- Assemble Final Text ---
59
+ final_text = f"**Method Used:** {method_info.get('method_name', method)}\n\n"
60
+ final_text += f"**Method Explanation:**\n{method_explanation}\n\n"
61
+
62
+ # Add Results Section
63
+ final_text += "**Results:**\n"
64
+ if effect_estimate is not None:
65
+ final_text += f"- Estimated Causal Effect: {effect_estimate:.4f}\n"
66
+ if effect_se is not None:
67
+ final_text += f"- Standard Error: {effect_se:.4f}\n"
68
+ if ci and ci[0] is not None and ci[1] is not None:
69
+ final_text += f"- 95% Confidence Interval: [{ci[0]:.4f}, {ci[1]:.4f}]\n"
70
+ if p_value is not None:
71
+ final_text += f"- P-value: {p_value:.4f}\n"
72
+ final_text += "\n"
73
+
74
+ final_text += f"**Interpretation Guide:**\n{interpretation_guide}\n\n"
75
+ final_text += f"**Assumptions:**\n"
76
+ for item in assumption_explanations:
77
+ final_text += f"- {item['assumption']}: {item['explanation']}\n"
78
+ final_text += "\n"
79
+ final_text += f"**Limitations:**\n{limitations_explanation}\n\n"
80
+
81
+ return {
82
+ "final_explanation_text": final_text
83
+ # Return only the final text, the tool wrapper adds workflow state
84
+ }
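+
+ # Illustrative call (a sketch; the dictionaries below are minimal hypothetical
+ # inputs rather than the full structures produced by the upstream components):
+ #
+ # explanation = generate_explanation(
+ #     method_info={"method_name": "linear_regression", "assumptions": []},
+ #     validation_result={"valid": True, "concerns": []},
+ #     variables={"treatment_variable": "discount", "outcome_variable": "revenue",
+ #                "covariates": ["region", "segment"]},
+ #     results={"effect_estimate": 1.23, "effect_se": 0.45,
+ #              "confidence_interval": (0.35, 2.11), "p_value": 0.006},
+ # )
+ # print(explanation["final_explanation_text"])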
85
+
86
+
87
+ def get_method_explanation(method: str) -> str:
88
+ """
89
+ Get explanation for what the method does.
90
+
91
+ Args:
92
+ method: Causal inference method name
93
+
94
+ Returns:
95
+ String explaining what the method does
96
+ """
97
+ explanations = {
98
+ "propensity_score_matching": (
99
+ "Propensity Score Matching is a statistical technique that attempts to estimate the effect "
100
+ "of a treatment by accounting for covariates that predict receiving the treatment. "
101
+ "It creates matched sets of treated and untreated subjects who share similar characteristics, "
102
+ "allowing for a more fair comparison between groups."
103
+ ),
104
+ "regression_adjustment": (
105
+ "Regression Adjustment is a method that uses regression models to estimate causal effects "
106
+ "by controlling for covariates. It models the outcome as a function of the treatment and "
107
+ "other potential confounding variables, allowing the isolation of the treatment effect."
108
+ ),
109
+ "instrumental_variable": (
110
+ "The Instrumental Variable method addresses issues of endogeneity or unmeasured confounding "
111
+ "by using an 'instrument' - a variable that affects the treatment but not the outcome directly. "
112
+ "It effectively finds the natural experiment hidden in your data to estimate causal effects."
113
+ ),
114
+ "difference_in_differences": (
115
+ "Difference-in-Differences compares the changes in outcomes over time between a group that "
116
+ "receives a treatment and a group that does not. It controls for time-invariant unobserved "
117
+ "confounders by looking at differences in trends rather than absolute values."
118
+ ),
119
+ "regression_discontinuity": (
120
+ "Regression Discontinuity Design exploits a threshold or cutoff rule that determines treatment "
121
+ "assignment. By comparing observations just above and below this threshold, where treatment "
122
+ "status changes but other characteristics remain similar, it estimates the local causal effect."
123
+ ),
124
+ "backdoor_adjustment": (
125
+ "Backdoor Adjustment controls for confounding variables that create 'backdoor paths' between "
126
+ "treatment and outcome variables in a causal graph. By conditioning on these variables, "
127
+ "it blocks the non-causal associations, allowing for identification of the causal effect."
128
+ ),
129
+ }
130
+
131
+ return explanations.get(method,
132
+ f"The {method} method is a causal inference technique used to estimate "
133
+ f"causal effects from observational data.")
134
+
135
+
136
+ def explain_assumptions(assumptions: List[str]) -> List[Dict[str, str]]:
137
+ """
138
+ Explain each assumption of the method.
139
+
140
+ Args:
141
+ assumptions: List of assumption names
142
+
143
+ Returns:
144
+ List of dictionaries with assumption name and explanation
145
+ """
146
+ assumption_details = {
147
+ "Treatment is randomly assigned": (
148
+ "This assumes that treatment assignment is not influenced by any factors "
149
+ "related to the outcome, similar to a randomized controlled trial. "
150
+ "In observational data, this assumption rarely holds without conditioning on confounders."
151
+ ),
152
+ "No systematic differences between treatment and control groups": (
153
+ "Treatment and control groups should be balanced on all relevant characteristics "
154
+ "except for the treatment itself. Any systematic differences could bias the estimate."
155
+ ),
156
+ "No unmeasured confounders (conditional ignorability)": (
157
+ "All variables that simultaneously affect the treatment and outcome are measured and "
158
+ "included in the analysis. If important confounders are missing, the estimated causal "
159
+ "effect will be biased."
160
+ ),
161
+ "Sufficient overlap between treatment and control groups": (
162
+ "For each combination of covariate values, there should be both treated and untreated "
163
+ "units. Without overlap, the model must extrapolate, which can lead to biased estimates."
164
+ ),
165
+ "Treatment assignment is not deterministic given covariates": (
166
+ "No combination of covariates should perfectly predict treatment assignment. "
167
+ "If treatment is deterministic for some units, causal comparisons become impossible."
168
+ ),
169
+ "Instrument is correlated with treatment (relevance)": (
170
+ "The instrumental variable must have a clear and preferably strong effect on the "
171
+ "treatment variable. Weak instruments lead to imprecise and potentially biased estimates."
172
+ ),
173
+ "Instrument affects outcome only through treatment (exclusion restriction)": (
174
+ "The instrumental variable must not directly affect the outcome except through its "
175
+ "effect on the treatment. If this assumption fails, the causal estimate will be biased."
176
+ ),
177
+ "Instrument is as good as randomly assigned (exogeneity)": (
178
+ "The instrumental variable must not be correlated with any confounders of the "
179
+ "treatment-outcome relationship. It should be as good as randomly assigned."
180
+ ),
181
+ "Parallel trends between treatment and control groups": (
182
+ "In the absence of treatment, the difference between treatment and control groups "
183
+ "would have remained constant over time. This is the key identifying assumption for "
184
+ "difference-in-differences and cannot be directly tested for the post-treatment period."
185
+ ),
186
+ "No spillover effects between groups": (
187
+ "The treatment of one unit should not affect the outcomes of other units. "
188
+ "If spillovers exist, they can bias the estimated treatment effect."
189
+ ),
190
+ "No anticipation effects before treatment": (
191
+ "Units should not change their behavior in anticipation of future treatment. "
192
+ "If anticipation effects exist, the pre-treatment trends may already reflect treatment effects."
193
+ ),
194
+ "Stable composition of treatment and control groups": (
195
+ "The composition of treatment and control groups should remain stable over time. "
196
+ "If units move between groups based on outcomes, this can bias the estimates."
197
+ ),
198
+ "Units cannot precisely manipulate their position around the cutoff": (
199
+ "In regression discontinuity, units must not be able to precisely control their position "
200
+ "relative to the cutoff. If they can, the randomization-like property of the design fails."
201
+ ),
202
+ "No other variables change discontinuously at the cutoff": (
203
+ "Any discontinuity in outcomes at the cutoff should be attributable only to the change "
204
+ "in treatment status. If other relevant variables also change at the cutoff, the causal "
205
+ "interpretation is compromised."
206
+ ),
207
+ "The relationship between running variable and outcome is continuous at the cutoff": (
208
+ "In the absence of treatment, the relationship between the running variable and the "
209
+ "outcome would be continuous at the cutoff. This allows attributing any observed "
210
+ "discontinuity to the treatment effect."
211
+ ),
212
+ "The model correctly specifies the relationship between variables": (
213
+ "The functional form of the relationship between variables in the model should correctly "
214
+ "capture the true relationship in the data. Misspecification can lead to biased estimates."
215
+ ),
216
+ "No reverse causality": (
217
+ "The treatment must cause the outcome, not the other way around. If the outcome affects "
218
+ "the treatment, the estimated relationship will not have a causal interpretation."
219
+ ),
220
+ }
221
+
222
+ return [
223
+ {"assumption": assumption, "explanation": assumption_details.get(assumption,
224
+ "This is a key assumption for the selected causal inference method.")}
225
+ for assumption in assumptions
226
+ ]
227
+
228
+
229
+ def explain_application(method: str, treatment: str, outcome: str,
230
+ covariates: List[str], variables: Dict[str, Any]) -> str:
231
+ """
232
+ Explain how the method will be applied to the dataset.
233
+
234
+ Args:
235
+ method: Causal inference method name
236
+ treatment: Treatment variable name
237
+ outcome: Outcome variable name
238
+ covariates: List of covariate names
239
+ variables: Dictionary of identified variables
240
+
241
+ Returns:
242
+ String explaining the application
243
+ """
244
+ covariate_str = ", ".join(covariates[:3])
245
+ if len(covariates) > 3:
246
+ covariate_str += f", and {len(covariates) - 3} other variables"
247
+
248
+ applications = {
249
+ "propensity_score_matching": (
250
+ f"I will estimate the propensity scores (probability of receiving treatment) for each "
251
+ f"observation based on the covariates ({covariate_str}). Then, I'll match treated and "
252
+ f"untreated units with similar propensity scores to create balanced comparison groups. "
253
+ f"Finally, I'll calculate the difference in {outcome} between these matched groups to "
254
+ f"estimate the causal effect of {treatment}."
255
+ ),
256
+ "regression_adjustment": (
257
+ f"I will build a regression model with {outcome} as the dependent variable and "
258
+ f"{treatment} as the independent variable of interest, while controlling for "
259
+ f"potential confounders ({covariate_str}). The coefficient of {treatment} will "
260
+ f"represent the estimated causal effect after adjusting for these covariates."
261
+ ),
262
+ "instrumental_variable": (
263
+ f"I will use {variables.get('instrument_variable')} as an instrumental variable for "
264
+ f"{treatment}. First, I'll estimate how the instrument affects {treatment} (first stage). "
265
+ f"Then, I'll use these predictions to estimate how changes in {treatment} that are induced "
266
+ f"by the instrument affect {outcome} (second stage). This two-stage approach helps "
267
+ f"address potential unmeasured confounding."
268
+ ),
269
+ "difference_in_differences": (
270
+ f"I will compare the change in {outcome} before and after the intervention for the "
271
+ f"group receiving {treatment}, relative to the change in a control group that didn't "
272
+ f"receive the treatment. This approach controls for time-invariant confounders and "
273
+ f"common time trends that affect both groups."
274
+ ),
275
+ "regression_discontinuity": (
276
+ f"I will focus on observations close to the cutoff value "
277
+ f"({variables.get('cutoff_value')}) of the running variable "
278
+ f"({variables.get('running_variable')}), where treatment assignment changes. "
279
+ f"By comparing outcomes just above and below this threshold, I can estimate "
280
+ f"the local causal effect of {treatment} on {outcome}."
281
+ ),
282
+ "backdoor_adjustment": (
283
+ f"I will control for the identified confounding variables ({covariate_str}) to "
284
+ f"block all backdoor paths between {treatment} and {outcome}. This may involve "
285
+ f"stratification, regression adjustment, or inverse probability weighting, depending "
286
+ f"on the data characteristics."
287
+ ),
288
+ }
289
+
290
+ return applications.get(method,
291
+ f"I will apply the {method} method to estimate the causal effect of "
292
+ f"{treatment} on {outcome}, controlling for relevant confounding factors "
293
+ f"where appropriate.")
294
+
295
+
296
+ def explain_limitations(method: str, concerns: List[str]) -> str:
297
+ """
298
+ Explain the limitations of the method based on validation concerns.
299
+
300
+ Args:
301
+ method: Causal inference method name
302
+ concerns: List of concerns from validation
303
+
304
+ Returns:
305
+ String explaining the limitations
306
+ """
307
+ method_limitations = {
308
+ "propensity_score_matching": (
309
+ "Propensity Score Matching can only account for observed confounders, and its "
310
+ "effectiveness depends on having good overlap between treatment and control groups. "
311
+ "It may also be sensitive to model specification for the propensity score estimation."
312
+ ),
313
+ "regression_adjustment": (
314
+ "Regression Adjustment relies heavily on correct model specification and can only "
315
+ "control for observed confounders. Extrapolation to regions with limited data can lead "
316
+ "to unreliable estimates, and the method may be sensitive to outliers."
317
+ ),
318
+ "instrumental_variable": (
319
+ "Instrumental Variable estimation can be imprecise with weak instruments and is "
320
+ "sensitive to violations of the exclusion restriction. The estimated effect is a local "
321
+ "average treatment effect for 'compliers', which may not generalize to the entire population."
322
+ ),
323
+ "difference_in_differences": (
324
+ "Difference-in-Differences relies on the parallel trends assumption, which cannot be fully "
325
+ "tested for the post-treatment period. It may be sensitive to the choice of comparison group "
326
+ "and can be biased if there are time-varying confounders or anticipation effects."
327
+ ),
328
+ "regression_discontinuity": (
329
+ "Regression Discontinuity provides estimates that are local to the cutoff point and may not "
330
+ "generalize to units far from this threshold. It also requires sufficient data around the "
331
+ "cutoff and is sensitive to the choice of bandwidth and functional form."
332
+ ),
333
+ "backdoor_adjustment": (
334
+ "Backdoor Adjustment requires correctly identifying all confounding variables and their "
335
+ "relationships. It depends on the assumption of no unmeasured confounders and may be "
336
+ "sensitive to model misspecification in complex settings."
337
+ ),
338
+ }
339
+
340
+ base_limitation = method_limitations.get(method,
341
+ f"The {method} method has general limitations in terms of its assumptions and applicability.")
342
+
343
+ # Add specific concerns if any
344
+ if concerns:
345
+ concern_text = " Additionally, specific concerns for this analysis include: " + \
346
+ "; ".join(concerns) + "."
347
+ return base_limitation + concern_text
348
+
349
+ return base_limitation
350
+
351
+
352
+ def generate_interpretation_guide(method: str, treatment: str, outcome: str) -> str:
353
+ """
354
+ Generate guide for interpreting the results.
355
+
356
+ Args:
357
+ method: Causal inference method name
358
+ treatment: Treatment variable name
359
+ outcome: Outcome variable name
360
+
361
+ Returns:
362
+ String with interpretation guide
363
+ """
364
+ interpretation_guides = {
365
+ "propensity_score_matching": (
366
+ f"The estimated effect represents the Average Treatment Effect (ATE) or the Average "
367
+ f"Treatment Effect on the Treated (ATT), depending on the specific matching approach. "
368
+ f"It can be interpreted as the expected change in {outcome} if a unit were to receive "
369
+ f"{treatment}, compared to not receiving it, for units with similar covariate values."
370
+ ),
371
+ "regression_adjustment": (
372
+ f"The coefficient of {treatment} in the regression model represents the estimated "
373
+ f"average causal effect on {outcome}, holding all included covariates constant. "
374
+ f"For binary treatments, it's the expected difference in outcomes between treated "
375
+ f"and untreated units with the same covariate values."
376
+ ),
377
+ "instrumental_variable": (
378
+ f"The estimated effect represents the Local Average Treatment Effect (LATE) for 'compliers' "
379
+ f"- units whose treatment status is influenced by the instrument. It can be interpreted as "
380
+ f"the average effect of {treatment} on {outcome} for this specific subpopulation."
381
+ ),
382
+ "difference_in_differences": (
383
+ f"The estimated effect represents the average causal impact of {treatment} on {outcome}, "
384
+ f"under the assumption that treatment and control groups would have followed parallel "
385
+ f"trends in the absence of treatment. It accounts for both time-invariant differences "
386
+ f"between groups and common time trends."
387
+ ),
388
+ "regression_discontinuity": (
389
+ f"The estimated effect represents the local causal impact of {treatment} on {outcome} "
390
+ f"at the cutoff point. It can be interpreted as the expected difference in outcomes "
391
+ f"for units just above versus just below the threshold, where treatment status changes."
392
+ ),
393
+ "backdoor_adjustment": (
394
+ f"The estimated effect represents the average causal effect of {treatment} on {outcome} "
395
+ f"after controlling for all identified confounding variables. It can be interpreted as "
396
+ f"the expected difference in outcomes if a unit were to receive versus not receive the "
397
+ f"treatment, holding all confounding factors constant."
398
+ ),
399
+ }
400
+
401
+ return interpretation_guides.get(method,
402
+ f"The estimated effect represents the causal impact of {treatment} on {outcome}, "
403
+ f"given the assumptions of the method are met. Careful consideration of these "
404
+ f"assumptions is needed for valid causal interpretation.")
auto_causal/components/input_parser.py ADDED
@@ -0,0 +1,456 @@
1
+ """
2
+ Input parser component for extracting information from causal queries.
3
+
4
+ This module provides functionality to parse user queries and extract key
5
+ elements such as the causal question, relevant variables, and constraints.
6
+ """
7
+
8
+ import re
9
+ import os
10
+ import json
11
+ import logging # Added for better logging
12
+ from typing import Dict, List, Any, Optional, Union
13
+ import pandas as pd
14
+ from pydantic import BaseModel, Field, ValidationError
15
+ from functools import partial # Import partial
16
+
17
+ # Add dotenv import
18
+ from dotenv import load_dotenv
19
+
20
+ # LangChain Imports
21
+ from langchain_openai import ChatOpenAI # Example, replace if using another provider
22
+ from langchain_core.messages import HumanMessage, SystemMessage
23
+ from langchain_core.exceptions import OutputParserException # Correct path
24
+ from langchain_core.language_models import BaseChatModel # Import BaseChatModel
25
+
26
+ # --- Load .env file ---
27
+ load_dotenv() # Load environment variables from .env file
28
+
29
+ # --- Configure Logging ---
30
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # --- Instantiate LLM Client ---
34
+ # Ensure OPENAI_API_KEY environment variable is set
35
+ # Consider making model name configurable
36
+ try:
37
+ # Using with_structured_output later, so instantiate base model here
38
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
39
+ # Add a check or allow configuration for different providers if needed
40
+ except ImportError:
41
+ logger.error("langchain_openai not installed. Please install it to use OpenAI models.")
42
+ llm = None
43
+ except Exception as e:
44
+ logger.error(f"Error initializing LLM: {e}. Input parsing will rely on fallbacks.")
45
+ llm = None
46
+
47
+ # --- Pydantic Models for Structured Output ---
48
+ class ParsedVariables(BaseModel):
49
+ treatment: List[str] = Field(default_factory=list, description="Variable(s) representing the treatment/intervention.")
50
+ outcome: List[str] = Field(default_factory=list, description="Variable(s) representing the outcome/result.")
51
+ covariates_mentioned: Optional[List[str]] = Field(default_factory=list, description="Covariate/control variable(s) explicitly mentioned in the query.")
52
+ grouping_vars: Optional[List[str]] = Field(default_factory=list, description="Variable(s) identifying groups or units for analysis.")
53
+ instruments_mentioned: Optional[List[str]] = Field(default_factory=list, description="Potential instrumental variable(s) mentioned.")
54
+
55
+ class ParsedQueryInfo(BaseModel):
56
+ query_type: str = Field(..., description="Type of query (e.g., EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER). Required.")
57
+ variables: ParsedVariables = Field(..., description="Variables identified in the query.")
58
+ constraints: Optional[List[str]] = Field(default_factory=list, description="Constraints or conditions mentioned (e.g., 'X > 10', 'country = USA').")
59
+ dataset_path_mentioned: Optional[str] = Field(None, description="Dataset path explicitly mentioned in the query, if any.")
60
+
61
+ # Add Pydantic model for path extraction
62
+ class ExtractedPath(BaseModel):
63
+ dataset_path: Optional[str] = Field(None, description="File path or URL for the dataset mentioned in the query.")
64
+
65
+ # --- End Pydantic Models ---
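+
+ # Example of the structure these models enforce (all values hypothetical):
+ #
+ # ParsedQueryInfo(
+ #     query_type="EFFECT_ESTIMATION",
+ #     variables=ParsedVariables(treatment=["discount"], outcome=["revenue"]),
+ #     constraints=["year >= 2020"],
+ #     dataset_path_mentioned="data/sales.csv",
+ # )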
66
+
67
+ def _build_llm_prompt(query: str, dataset_info: Optional[Dict] = None) -> str:
68
+ """Builds the prompt for the LLM to extract query information."""
69
+ dataset_context = "No dataset context provided."
70
+ if dataset_info:
71
+ columns = dataset_info.get('columns', [])
72
+ column_details = "\n".join([f"- {col} (Type: {dataset_info.get('column_types', {}).get(col, 'Unknown')})" for col in columns])
73
+ sample_rows = dataset_info.get('sample_rows', 'Not available')
74
+ # Ensure sample rows are formatted reasonably
75
+ if isinstance(sample_rows, list):
76
+ sample_rows_str = json.dumps(sample_rows[:3], indent=2) # Show first 3 sample rows
77
+ elif isinstance(sample_rows, str):
78
+ sample_rows_str = sample_rows
79
+ else:
80
+ sample_rows_str = 'Not available'
81
+
82
+ dataset_context = f"""
83
+ Dataset Context:
84
+ Columns:
85
+ {column_details}
86
+ Sample Rows (first few):
87
+ {sample_rows_str}
88
+ """
89
+
90
+ prompt = f"""
91
+ Analyze the following causal query **strictly in the context of the provided dataset information (if available)**. Identify the query type, key variables (mapping query terms to actual column names when possible), constraints, and any explicitly mentioned dataset path.
92
+
93
+ User Query: "{query}"
94
+
95
+ {dataset_context}
96
+
97
98
+ Guidance for Identifying Query Type:
99
+ - EFFECT_ESTIMATION: Look for keywords like 'effect', 'impact', 'influence', 'cause', 'affect', 'consequence'. Also consider questions asking "how does X affect Y?" or comparing outcomes between groups based on an intervention.
100
+ - COUNTERFACTUAL: Look for hypothetical scenarios, often using phrases like 'what if', 'if X had been', 'would Y have changed', 'imagine if', 'counterfactual'.
101
+ - CORRELATION: Look for keywords like 'correlation', 'association', 'relationship', 'linked to', 'related to'. These queries ask about statistical relationships without necessarily implying causality.
102
+ - DESCRIPTIVE: These queries ask for summaries, descriptions, trends, or statistics about the data without investigating causal links or relationships (e.g., "Show sales over time", "What is the average age?").
103
+ - OTHER: Use this if the query does not fit any of the above categories.
104
+
105
+ Choose the most appropriate type from: EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER.
106
+
107
+ Variable Roles to Identify:
108
+ - treatment: The intervention or variable whose effect is being studied.
109
+ - outcome: The result or variable being measured.
110
+ - covariates_mentioned: Variables explicitly mentioned to control for or adjust for.
111
+ - grouping_vars: Variables identifying specific subgroups for analysis (e.g., 'for men', 'in the sales department').
112
+ - instruments_mentioned: Variables explicitly mentioned as potential instruments.
113
+
114
+ Constraints: Conditions applied to the analysis (e.g., filters on columns, specific time periods).
115
+
116
+ Dataset Path Mentioned: Extract the file path or URL if explicitly stated in the query.
117
+
118
+ **Output ONLY a valid JSON object** matching this exact schema (no explanations, notes, or surrounding text):
119
+ ```json
120
+ {{
121
+ "query_type": "<Identified Query Type>",
122
+ "variables": {{
123
+ "treatment": ["<Treatment Variable(s) Mentioned>"],
124
+ "outcome": ["<Outcome Variable(s) Mentioned>"],
125
+ "covariates_mentioned": ["<Covariate(s) Mentioned>"],
126
+ "grouping_vars": ["<Grouping Variable(s) Mentioned>"],
127
+ "instruments_mentioned": ["<Instrument(s) Mentioned>"]
128
+ }},
129
+ "constraints": ["<Constraint 1>", "<Constraint 2>"],
130
+ "dataset_path_mentioned": "<Path Mentioned or null>"
131
+ }}
132
+ ```
133
+ If Dataset Context is provided, ensure variable names in the output JSON correspond to actual column names where possible. If no context is provided, or if a mentioned variable doesn't map directly, use the phrasing from the query.
134
+ Respond with only the JSON object.
135
+ """
136
+ return prompt
137
+
138
+ def _validate_llm_output(parsed_info: ParsedQueryInfo, dataset_info: Optional[Dict] = None) -> bool:
139
+ """Perform basic assertions on the parsed LLM output."""
140
+ # 1. Check required fields exist (Pydantic handles this on parsing)
141
+ # 2. Check query type is one of the allowed types (can add enum to Pydantic later)
142
+ allowed_types = {"EFFECT_ESTIMATION", "COUNTERFACTUAL", "CORRELATION", "DESCRIPTIVE", "OTHER"}
143
+ logger.debug(f"Validating parsed LLM output: {parsed_info}")
144
+ assert parsed_info.query_type in allowed_types, f"Invalid query_type: {parsed_info.query_type}"
145
+
146
+ # 3. Check that if it's an effect query, treatment and outcome are likely present
147
+ if parsed_info.query_type == "EFFECT_ESTIMATION":
148
+ # Check that the lists are not empty
149
+ assert parsed_info.variables.treatment, "Treatment variable list is empty for effect query."
150
+ assert parsed_info.variables.outcome, "Outcome variable list is empty for effect query."
151
+
152
+ # 4. If dataset_info provided, check if extracted variables exist in columns
153
+ if dataset_info and (columns := dataset_info.get('columns')):
154
+ all_extracted_vars = set()
155
+ for var_list in parsed_info.variables.model_dump().values(): # Iterate through variable lists
156
+ if var_list: # Ensure var_list is not None or empty
157
+ all_extracted_vars.update(var_list)
158
+
159
+ unknown_vars = all_extracted_vars - set(columns)
160
+ # Allow for non-column variables if context is missing? Maybe relax this.
161
+ # For now, strict check if columns are provided.
162
+ if unknown_vars:
163
+ logger.warning(f"LLM mentioned variables potentially not in dataset columns: {unknown_vars}")
164
+ # Decide if this should be a hard failure (AssertionError) or just a warning.
165
+ # Let's make it a hard failure for now to enforce mapping.
166
+ raise AssertionError(f"LLM hallucinated variables not in dataset columns: {unknown_vars}")
167
+
168
+ logger.info("LLM output validation passed.")
169
+ return True
170
+
171
+ def _extract_query_information_with_llm(query: str, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None, max_retries: int = 3) -> Optional[ParsedQueryInfo]:
172
+ """Extracts query type, variables, and constraints using LLM with retries and validation."""
173
+ if not llm:
174
+ logger.error("LLM client not provided. Cannot perform LLM extraction.")
175
+ return None
176
+
177
+ last_error = None
178
+ # Bind the Pydantic model to the LLM for structured output
179
+ structured_llm = llm.with_structured_output(ParsedQueryInfo)
180
+
181
+ # Initial prompt construction
182
+ system_prompt_content = _build_llm_prompt(query, dataset_info)
183
+ messages = [HumanMessage(content=system_prompt_content)] # Start with just the detailed prompt as Human message
184
+
185
+ for attempt in range(max_retries):
186
+ logger.info(f"LLM Extraction Attempt {attempt + 1}/{max_retries}...")
187
+ try:
188
+ # --- Invoke LangChain LLM with structured output (using passed llm) ---
189
+ parsed_info = structured_llm.invoke(messages)
190
+ # ---------------------------------------------------
191
+ logger.debug(f"Messages sent to LLM: {messages}")
+ logger.debug(f"Structured LLM response: {parsed_info}")
194
+ # Perform custom assertions/validation
195
+ if _validate_llm_output(parsed_info, dataset_info):
196
+ return parsed_info # Success!
197
+
198
+ # Catch errors specific to structured output parsing or Pydantic validation
199
+ except (OutputParserException, ValidationError, AssertionError) as e:
200
+ logger.warning(f"Validation/Parsing Error (Attempt {attempt + 1}): {e}")
201
+ last_error = e
202
+ # Add feedback message for retry
203
+ messages.append(SystemMessage(content=f"Your previous response failed validation: {str(e)}. Please revise your response to be valid JSON conforming strictly to the schema and ensure variable names exist in the dataset context."))
204
+ continue # Go to next retry
205
+ except Exception as e: # Catch other potential LLM API errors
206
+ logger.error(f"Unexpected LLM Error (Attempt {attempt + 1}): {e}", exc_info=True)
207
+ last_error = e
208
+ break # Stop retrying on unexpected API errors
209
+
210
+ logger.error(f"LLM extraction failed after {max_retries} attempts.")
211
+ if last_error:
212
+ logger.error(f"Last error: {last_error}")
213
+ return None # Indicate failure
214
+
215
+ # Add helper function to call LLM for path - needs llm argument
216
+ def _call_llm_for_path(query: str, llm: Optional[BaseChatModel] = None, max_retries: int = 2) -> Optional[str]:
217
+ """Uses LLM as a fallback to extract just the dataset path."""
218
+ if not llm:
219
+ logger.warning("LLM client not provided. Cannot perform LLM path fallback.")
220
+ return None
221
+
222
+ logger.info("Attempting LLM fallback for dataset path extraction...")
223
+ path_extractor_llm = llm.with_structured_output(ExtractedPath)
224
+ prompt = f"Extract the dataset file path (e.g., /path/to/file.csv or https://...) mentioned in the following query. Respond ONLY with the JSON object.\nQuery: \"{query}\""
225
+ messages = [HumanMessage(content=prompt)]
226
+ last_error = None
227
+
228
+ for attempt in range(max_retries):
229
+ try:
230
+ parsed_info = path_extractor_llm.invoke(messages)
231
+ if parsed_info.dataset_path:
232
+ logger.info(f"LLM fallback extracted path: {parsed_info.dataset_path}")
233
+ return parsed_info.dataset_path
234
+ else:
235
+ logger.info("LLM fallback did not find a path.")
236
+ return None # LLM explicitly found no path
237
+ except (OutputParserException, ValidationError) as e:
238
+ logger.warning(f"LLM path extraction parsing/validation error (Attempt {attempt+1}): {e}")
239
+ last_error = e
240
+ messages.append(SystemMessage(content=f"Parsing Error: {e}. Please ensure you provide valid JSON with only the 'dataset_path' key."))
241
+ continue
242
+ except Exception as e:
243
+ logger.error(f"Unexpected LLM Error during path fallback (Attempt {attempt+1}): {e}", exc_info=True)
244
+ last_error = e
245
+ break # Don't retry on unexpected errors
246
+
247
+ logger.error(f"LLM path fallback failed after {max_retries} attempts. Last error: {last_error}")
248
+ return None
249
+
250
+ # Renamed and modified function for regex path extraction + LLM fallback - needs llm argument
251
+ def extract_dataset_path(query: str, llm: Optional[BaseChatModel] = None) -> Optional[str]:
252
+ """
253
+ Extract dataset path from the query using regex patterns, with LLM fallback.
254
+
255
+ Args:
256
+ query: The user's causal question text
257
+ llm: The shared LLM client instance for fallback.
258
+
259
+ Returns:
260
+ String with dataset path or None if not found
261
+ """
262
+ # --- Regex Part (existing logic) ---
263
+ # Check for common patterns indicating dataset paths
264
+ path_patterns = [
265
+ # More specific patterns first
266
+ r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?", # Handles subdirs in path
267
+ r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
268
+ # Simpler patterns
269
+ r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']", # Path in quotes
270
+ r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)", # More generic path-like structure ending in .csv
271
+ r"([^\"\'.,\s]+\.csv)" # Just a .csv file name (least specific)
272
+ ]
273
+
274
+ for pattern in path_patterns:
275
+ matches = re.search(pattern, query, re.IGNORECASE)
276
+ if matches:
277
+ path = matches.group(1).strip()
278
+
279
+ # Basic check if it looks like a path
280
+ if '/' in path or '\\' in path or os.path.exists(path):
281
+ # Check if this is a valid file path immediately
282
+ if os.path.exists(path):
283
+ logger.info(f"Regex found existing path: {path}")
284
+ return path
285
+
286
+ # Check if it's in common data directories
287
+ data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
288
+ for data_dir in data_dir_paths:
289
+ potential_path = os.path.join(data_dir, os.path.basename(path))
290
+ if os.path.exists(potential_path):
291
+ logger.info(f"Regex found path in {data_dir}: {potential_path}")
292
+ return potential_path
293
+
294
+ # If not found but looks like a path, return it anyway - let downstream handle non-existence
295
+ logger.info(f"Regex found potential path (existence not verified): {path}")
296
+ return path
297
+ # Else: it might just be a word ending in .csv, ignore unless it exists
298
+ elif os.path.exists(path):
299
+ logger.info(f"Regex found existing path (simple pattern): {path}")
300
+ return path
301
+
302
+ # --- LLM Fallback ---
303
+ logger.info("Regex did not find dataset path. Trying LLM fallback...")
304
+ llm_fallback_path = _call_llm_for_path(query, llm=llm)
305
+ if llm_fallback_path:
306
+ # Optional: Add existence check here too? Or let downstream handle it.
307
+ # For now, return what LLM found.
308
+ return llm_fallback_path
309
+
310
+ logger.info("No dataset path found via regex or LLM fallback.")
311
+ return None
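+
+ # Example (illustrative file name; the result depends on which paths exist locally):
+ #
+ # extract_dataset_path("Estimate the effect of training using data/jobs.csv", llm=llm)
+ # # -> "data/jobs.csv" if a regex pattern matches, otherwise whatever the LLM fallback finds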
312
+
313
+ def parse_input(query: str, dataset_path_arg: Optional[str] = None, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
314
+ """
315
+ Parse the user's causal query using LLM and regex.
316
+
317
+ Args:
318
+ query: The user's causal question text.
319
+ dataset_path_arg: Path to dataset if provided directly as an argument.
320
+ dataset_info: Dictionary with dataset context (columns, types, etc.).
321
+ llm: The shared LLM client instance.
322
+
323
+ Returns:
324
+ Dict containing parsed query information.
325
+ """
326
+ result = {
327
+ "original_query": query,
328
+ "dataset_path": dataset_path_arg, # Start with argument path
329
+ "query_type": "OTHER", # Default values
330
+ "extracted_variables": {},
331
+ "constraints": []
332
+ }
333
+
334
+ # --- 1. Use LLM for core NLP tasks ---
335
+ parsed_llm_info = _extract_query_information_with_llm(query, dataset_info, llm=llm)
336
+
337
+ if parsed_llm_info:
338
+ result["query_type"] = parsed_llm_info.query_type
339
+ result["extracted_variables"] = {k: v if v is not None else [] for k, v in parsed_llm_info.variables.model_dump().items()}
340
+ result["constraints"] = parsed_llm_info.constraints if parsed_llm_info.constraints is not None else []
341
+ llm_mentioned_path = parsed_llm_info.dataset_path_mentioned
342
+ else:
343
+ logger.warning("LLM-based query information extraction failed.")
344
+ llm_mentioned_path = None
345
+ # Consider falling back to old regex methods here if critical
346
+ # logger.info("Falling back to regex-based parsing (if implemented).")
347
+
348
+ # --- 2. Determine Dataset Path (Hybrid Approach) ---
349
+ final_dataset_path = dataset_path_arg # Priority 1: Explicit argument
350
+
351
+ # Pass llm instance to the path extractor for its fallback mechanism
352
+ path_extractor = partial(extract_dataset_path, llm=llm)
353
+
354
+ if not final_dataset_path:
355
+ # Priority 2: Path mentioned in query (extracted by main LLM call)
356
+ if llm_mentioned_path and os.path.exists(llm_mentioned_path):
357
+ logger.info(f"Using dataset path mentioned by LLM: {llm_mentioned_path}")
358
+ final_dataset_path = llm_mentioned_path
359
+ elif llm_mentioned_path: # Check data dirs if path not absolute
360
+ data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
361
+ base_name = os.path.basename(llm_mentioned_path)
362
+ for data_dir in data_dir_paths:
363
+ potential_path = os.path.join(data_dir, base_name)
364
+ if os.path.exists(potential_path):
365
+ logger.info(f"Using dataset path mentioned by LLM (found in {data_dir}): {potential_path}")
366
+ final_dataset_path = potential_path
367
+ break
368
+ if not final_dataset_path:
369
+ logger.warning(f"LLM mentioned path '{llm_mentioned_path}' but it was not found.")
370
+
371
+ if not final_dataset_path:
372
+ # Priority 3: Path extracted by dedicated Regex + LLM fallback function
373
+ logger.info("Attempting dedicated dataset path extraction (Regex + LLM Fallback)...")
374
+ extracted_path = path_extractor(query) # Call the partial function with llm bound
375
+ if extracted_path:
376
+ final_dataset_path = extracted_path
377
+
378
+ result["dataset_path"] = final_dataset_path
379
+
380
+ # Check if a path was found ultimately
381
+ if not result["dataset_path"]:
382
+ logger.warning("Could not determine dataset path from query or arguments.")
383
+ else:
384
+ logger.info(f"Final dataset path determined: {result['dataset_path']}")
385
+
386
+ return result
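+
+ # Example usage (a sketch; assumes the module-level `llm` client initialised above
+ # is available, and the column names / CSV path are placeholders):
+ #
+ # parsed = parse_input(
+ #     query="What is the effect of discount on revenue? Use data/sales.csv",
+ #     dataset_info={"columns": ["discount", "revenue", "region"],
+ #                   "column_types": {"discount": "int64", "revenue": "float64"}},
+ #     llm=llm,
+ # )
+ # parsed["query_type"], parsed["extracted_variables"], parsed["dataset_path"]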
387
+
388
+ # --- Old Regex-based functions (Commented out or removed) ---
389
+ # def determine_query_type(query: str) -> str:
390
+ # ... (implementation removed)
391
+
392
+ # def extract_variables(query: str) -> Dict[str, Any]:
393
+ # ... (implementation removed)
394
+
395
+ # def detect_constraints(query: str) -> List[str]:
396
+ # ... (implementation removed)
397
+ # --- End Old Functions ---
398
+
399
+ # Renamed function for regex path extraction
400
+ def extract_dataset_path_regex(query: str) -> Optional[str]:
401
+ """
402
+ Extract dataset path from the query using regex patterns.
403
+
404
+ Args:
405
+ query: The user's causal question text
406
+
407
+ Returns:
408
+ String with dataset path or None if not found
409
+ """
410
+ # Check for common patterns indicating dataset paths
411
+ path_patterns = [
412
+ # More specific patterns first
413
+ r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?", # Handles subdirs in path
414
+ r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
415
+ # Simpler patterns
416
+ r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']", # Path in quotes
417
+ r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)", # More generic path-like structure ending in .csv
418
+ r"([^\"\'.,\s]+\.csv)" # Just a .csv file name (least specific)
419
+ ]
420
+
421
+ for pattern in path_patterns:
422
+ matches = re.search(pattern, query, re.IGNORECASE)
423
+ if matches:
424
+ path = matches.group(1).strip()
425
+
426
+ # Basic check if it looks like a path
427
+ if '/' in path or '\\' in path or os.path.exists(path):
428
+ # Check if this is a valid file path immediately
429
+ if os.path.exists(path):
430
+ logger.info(f"Regex found existing path: {path}")
431
+ return path
432
+
433
+ # Check if it's in common data directories
434
+ data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
435
+ # Also check relative to current dir (often useful)
436
+ # base_name = os.path.basename(path)
437
+ for data_dir in data_dir_paths:
438
+ potential_path = os.path.join(data_dir, os.path.basename(path))
439
+ if os.path.exists(potential_path):
440
+ logger.info(f"Regex found path in {data_dir}: {potential_path}")
441
+ return potential_path
442
+
443
+ # If not found but looks like a path, return it anyway - let downstream handle non-existence
444
+ logger.info(f"Regex found potential path (existence not verified): {path}")
445
+ return path
446
+ # Else: it might just be a word ending in .csv, ignore unless it exists
447
+ elif os.path.exists(path):
448
+ logger.info(f"Regex found existing path (simple pattern): {path}")
449
+ return path
450
+
451
+ # TODO: Optional: Add LLM fallback call here if regex fails
452
+ # if no path found:
453
+ # llm_fallback_path = call_llm_for_path(query)
454
+ # return llm_fallback_path
455
+
456
+ return None
auto_causal/components/method_validator.py ADDED
@@ -0,0 +1,327 @@
1
+ """
2
+ Method validator component for causal inference methods.
3
+
4
+ This module validates the selected causal inference method against
5
+ dataset characteristics and available variables.
6
+ """
7
+
8
+ from typing import Dict, List, Any, Optional
9
+
10
+
11
+ def validate_method(method_info: Dict[str, Any], dataset_analysis: Dict[str, Any],
12
+ variables: Dict[str, Any]) -> Dict[str, Any]:
13
+ """
14
+ Validate the selected causal method against dataset characteristics.
15
+
16
+ Args:
17
+ method_info: Information about the selected method from decision_tree
18
+ dataset_analysis: Dataset analysis results from dataset_analyzer
19
+ variables: Identified variables from query_interpreter
20
+
21
+ Returns:
22
+ Dict with validation results:
23
+ - valid: Boolean indicating if method is valid
24
+ - concerns: List of concerns/issues with the selected method
25
+ - alternative_suggestions: Alternative methods if the selected method is problematic
26
+ - recommended_method: Updated method recommendation if issues are found
27
+ """
28
+ method = method_info.get("selected_method")
29
+ assumptions = method_info.get("method_assumptions", [])
30
+
31
+ # Get required variables
32
+ treatment = variables.get("treatment_variable")
33
+ outcome = variables.get("outcome_variable")
34
+ covariates = variables.get("covariates", [])
35
+ time_variable = variables.get("time_variable")
36
+ group_variable = variables.get("group_variable")
37
+ instrument_variable = variables.get("instrument_variable")
38
+ running_variable = variables.get("running_variable")
39
+ cutoff_value = variables.get("cutoff_value")
40
+
41
+ # Initialize validation result
42
+ validation_result = {
43
+ "valid": True,
44
+ "concerns": [],
45
+ "alternative_suggestions": [],
46
+ "recommended_method": method,
47
+ }
48
+
49
+ # Common validations for all methods
50
+ if treatment is None:
51
+ validation_result["valid"] = False
52
+ validation_result["concerns"].append("Treatment variable is not identified")
53
+
54
+ if outcome is None:
55
+ validation_result["valid"] = False
56
+ validation_result["concerns"].append("Outcome variable is not identified")
57
+
58
+ # Method-specific validations
59
+ if method == "propensity_score_matching":
60
+ validate_propensity_score_matching(validation_result, dataset_analysis, variables)
61
+
62
+ elif method == "regression_adjustment":
63
+ validate_regression_adjustment(validation_result, dataset_analysis, variables)
64
+
65
+ elif method == "instrumental_variable":
66
+ validate_instrumental_variable(validation_result, dataset_analysis, variables)
67
+
68
+ elif method == "difference_in_differences":
69
+ validate_difference_in_differences(validation_result, dataset_analysis, variables)
70
+
71
+ elif method == "regression_discontinuity_design":
72
+ validate_regression_discontinuity(validation_result, dataset_analysis, variables)
73
+
74
+ elif method == "backdoor_adjustment":
75
+ validate_backdoor_adjustment(validation_result, dataset_analysis, variables)
76
+
77
+ # If there are serious concerns, recommend alternatives
78
+ if not validation_result["valid"]:
79
+ validation_result["recommended_method"] = recommend_alternative(
80
+ method, validation_result["concerns"], method_info.get("alternatives", [])
81
+ )
82
+
83
+ # Make sure assumptions are listed in the validation result
84
+ validation_result["assumptions"] = assumptions
85
+ print("--------------------------")
86
+ print("Validation result:", validation_result)
87
+ print("--------------------------")
88
+ return validation_result
89
+
90
+
91
+ def validate_propensity_score_matching(validation_result: Dict[str, Any],
92
+ dataset_analysis: Dict[str, Any],
93
+ variables: Dict[str, Any]) -> None:
94
+ """
95
+ Validate propensity score matching method requirements.
96
+
97
+ Args:
98
+ validation_result: Current validation result to update
99
+ dataset_analysis: Dataset analysis results
100
+ variables: Identified variables
101
+ """
102
+ treatment = variables.get("treatment_variable")
103
+ covariates = variables.get("covariates", [])
104
+
105
+ # Check if treatment is binary using column_categories
106
+ is_binary = dataset_analysis.get("column_categories", {}).get(treatment) == "binary"
107
+
108
+ # Fallback to check if the column has only two unique values (0 and 1)
109
+ if not is_binary:
110
+ column_types = dataset_analysis.get("column_types", {})
111
+ if column_types.get(treatment) in ("int64", "int32"):
112
+ # Heuristic fallback: assume an integer-typed treatment is binary (the 0/1 values are not actually verified here)
113
+ is_binary = True
114
+
115
+ if not is_binary:
116
+ validation_result["valid"] = False
117
+ validation_result["concerns"].append(
118
+ "Treatment variable is not binary, which is required for propensity score matching"
119
+ )
120
+
121
+ # Check if there are sufficient covariates
122
+ if len(covariates) < 2:
123
+ validation_result["concerns"].append(
124
+ "Few covariates identified, which may limit the effectiveness of propensity score matching"
125
+ )
126
+
127
+ # Check for sufficient overlap
128
+ variable_relationships = dataset_analysis.get("variable_relationships", {})
129
+ treatment_imbalance = variable_relationships.get("treatment_imbalance", 0.5)
130
+
131
+ if treatment_imbalance < 0.1 or treatment_imbalance > 0.9:
132
+ validation_result["concerns"].append(
133
+ "Treatment groups are highly imbalanced, which may lead to poor matching quality"
134
+ )
135
+ validation_result["alternative_suggestions"].append("regression_adjustment")
136
+
137
+
138
+ def validate_regression_adjustment(validation_result: Dict[str, Any],
139
+ dataset_analysis: Dict[str, Any],
140
+ variables: Dict[str, Any]) -> None:
141
+ """
142
+ Validate regression adjustment method requirements.
143
+
144
+ Args:
145
+ validation_result: Current validation result to update
146
+ dataset_analysis: Dataset analysis results
147
+ variables: Identified variables
148
+ """
149
+ outcome = variables.get("outcome_variable")
150
+
151
+ # Check outcome type for appropriate regression model
152
+ outcome_data = dataset_analysis.get("variable_types", {}).get(outcome, {})
153
+ outcome_type = outcome_data.get("type")
154
+
155
+ if outcome_type == "categorical" and outcome_data.get("n_categories", 0) > 2:
156
+ validation_result["concerns"].append(
157
+ "Outcome is categorical with multiple categories, which may require multinomial regression"
158
+ )
159
+
160
+ # Check for potential nonlinear relationships
161
+ nonlinear_relationships = dataset_analysis.get("nonlinear_relationships", False)
162
+
163
+ if nonlinear_relationships:
164
+ validation_result["concerns"].append(
165
+ "Potential nonlinear relationships detected, which may require more flexible models"
166
+ )
167
+
168
+
169
+ def validate_instrumental_variable(validation_result: Dict[str, Any],
170
+ dataset_analysis: Dict[str, Any],
171
+ variables: Dict[str, Any]) -> None:
172
+ """
173
+ Validate instrumental variable method requirements.
174
+
175
+ Args:
176
+ validation_result: Current validation result to update
177
+ dataset_analysis: Dataset analysis results
178
+ variables: Identified variables
179
+ """
180
+ instrument_variable = variables.get("instrument_variable")
181
+ treatment = variables.get("treatment_variable")
182
+
183
+ if instrument_variable is None:
184
+ validation_result["valid"] = False
185
+ validation_result["concerns"].append(
186
+ "No instrumental variable identified, which is required for this method"
187
+ )
188
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
189
+ return
190
+
191
+ # Check for instrument strength (correlation with treatment)
192
+ variable_relationships = dataset_analysis.get("variable_relationships", {})
193
+ instrument_correlation = next(
194
+ (corr.get("correlation", 0) for corr in variable_relationships.get("correlations", [])
195
+ if (corr.get("var1") == instrument_variable and corr.get("var2") == treatment)
196
+ or (corr.get("var1") == treatment and corr.get("var2") == instrument_variable)),
197
+ 0
198
+ )
199
+
200
+ if abs(instrument_correlation) < 0.2:
201
+ validation_result["concerns"].append(
202
+ "Instrument appears weak (low correlation with treatment), which may lead to bias"
203
+ )
204
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
205
+
206
+
207
+ def validate_difference_in_differences(validation_result: Dict[str, Any],
208
+ dataset_analysis: Dict[str, Any],
209
+ variables: Dict[str, Any]) -> None:
210
+ """
211
+ Validate difference-in-differences method requirements.
212
+
213
+ Args:
214
+ validation_result: Current validation result to update
215
+ dataset_analysis: Dataset analysis results
216
+ variables: Identified variables
217
+ """
218
+ time_variable = variables.get("time_variable")
219
+ group_variable = variables.get("group_variable")
220
+
221
+ if time_variable is None:
222
+ validation_result["valid"] = False
223
+ validation_result["concerns"].append(
224
+ "No time variable identified, which is required for difference-in-differences"
225
+ )
226
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
227
+
228
+ if group_variable is None:
229
+ validation_result["valid"] = False
230
+ validation_result["concerns"].append(
231
+ "No group variable identified, which is required for difference-in-differences"
232
+ )
233
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
234
+
235
+ # Check for parallel trends
236
+ temporal_structure = dataset_analysis.get("temporal_structure", {})
237
+ parallel_trends = temporal_structure.get("parallel_trends", False)
238
+
239
+ if not parallel_trends:
240
+ validation_result["concerns"].append(
241
+ "No evidence of parallel trends, which is a key assumption for difference-in-differences"
242
+ )
243
+ validation_result["alternative_suggestions"].append("synthetic_control")
244
+
245
+
246
+ def validate_regression_discontinuity(validation_result: Dict[str, Any],
247
+ dataset_analysis: Dict[str, Any],
248
+ variables: Dict[str, Any]) -> None:
249
+ """
250
+ Validate regression discontinuity method requirements.
251
+
252
+ Args:
253
+ validation_result: Current validation result to update
254
+ dataset_analysis: Dataset analysis results
255
+ variables: Identified variables
256
+ """
257
+ running_variable = variables.get("running_variable")
258
+ cutoff_value = variables.get("cutoff_value")
259
+
260
+ if running_variable is None:
261
+ validation_result["valid"] = False
262
+ validation_result["concerns"].append(
263
+ "No running variable identified, which is required for regression discontinuity"
264
+ )
265
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
266
+
267
+ if cutoff_value is None:
268
+ validation_result["valid"] = False
269
+ validation_result["concerns"].append(
270
+ "No cutoff value identified, which is required for regression discontinuity"
271
+ )
272
+ validation_result["alternative_suggestions"].append("propensity_score_matching")
273
+
274
+ # Check for discontinuity at threshold
275
+ discontinuities = dataset_analysis.get("discontinuities", {})
276
+ has_discontinuity = discontinuities.get("has_discontinuities", False)
277
+
278
+ if not has_discontinuity:
279
+ validation_result["valid"] = False
280
+ validation_result["concerns"].append(
281
+ "No clear discontinuity detected at the threshold, which is necessary for this method"
282
+ )
283
+ validation_result["alternative_suggestions"].append("regression_adjustment")
284
+
285
+ def validate_backdoor_adjustment(validation_result: Dict[str, Any],
286
+ dataset_analysis: Dict[str, Any],
287
+ variables: Dict[str, Any]) -> None:
288
+ """
289
+ Validate backdoor adjustment method requirements.
290
+
291
+ Args:
292
+ validation_result: Current validation result to update
293
+ dataset_analysis: Dataset analysis results
294
+ variables: Identified variables
295
+ """
296
+ covariates = variables.get("covariates", [])
297
+
298
+ if len(covariates) == 0:
299
+ validation_result["valid"] = False
300
+ validation_result["concerns"].append(
301
+ "No covariates identified for backdoor adjustment"
302
+ )
303
+ validation_result["alternative_suggestions"].append("regression_adjustment")
304
+
305
+
306
+ def recommend_alternative(method: str, concerns: List[str], alternatives: List[str]) -> str:
307
+ """
308
+ Recommend an alternative method if the current one has issues.
309
+
310
+ Args:
311
+ method: Current method
312
+ concerns: List of concerns with the current method
313
+ alternatives: List of alternative methods suggested by the decision tree
314
+
315
+ Returns:
316
+ String with the recommended method
317
+ """
318
+ # If there are alternatives, recommend the first one
319
+ if alternatives:
320
+ return alternatives[0]
321
+
322
+ # If no alternatives, use regression adjustment as a fallback
323
+ if method != "regression_adjustment":
324
+ return "regression_adjustment"
325
+
326
+ # If regression adjustment is also problematic, use propensity score matching
327
+ return "propensity_score_matching"
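A quick, illustrative way to exercise the validator above (the dictionary shapes mirror the keys the functions read; the concrete column names and values are invented for the example):

from auto_causal.components.method_validator import validate_method

method_info = {
    "selected_method": "difference_in_differences",
    "method_assumptions": ["parallel trends", "no anticipation"],
    "alternatives": ["propensity_score_matching"],
}
dataset_analysis = {
    "column_categories": {"treated": "binary"},
    "temporal_structure": {"parallel_trends": False},
}
variables = {
    "treatment_variable": "treated",
    "outcome_variable": "revenue",
    "covariates": ["age", "region"],
    "time_variable": None,      # deliberately missing to trigger a concern
    "group_variable": "store_id",
}

result = validate_method(method_info, dataset_analysis, variables)
# result["valid"] is False (no time variable), and recommend_alternative() falls back to the
# first listed alternative, so result["recommended_method"] == "propensity_score_matching".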
auto_causal/components/output_formatter.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ Output formatter component for causal inference results.
3
+
4
+ This module formats the results of causal analysis into a clear,
5
+ structured output for presentation to the user.
6
+ """
7
+
8
+ from typing import Dict, List, Any, Optional
9
+ import json # Add this import at the top of the file
10
+
11
+ # Import the new model
12
+ from auto_causal.models import FormattedOutput
13
+
14
+ # Add this module-level variable, typically near imports or at the top
15
+ CURRENT_OUTPUT_LOG_FILE = None
16
+
17
+ # Revert signature and logic to handle results and structured explanation
18
+ def format_output(
19
+ query: str,
20
+ method: str,
21
+ results: Dict[str, Any],
22
+ explanation: Dict[str, Any],
23
+ dataset_analysis: Optional[Dict[str, Any]] = None,
24
+ dataset_description: Optional[str] = None
25
+ ) -> FormattedOutput:
26
+ """
27
+ Format final results including numerical estimates and explanations.
28
+
29
+ Args:
30
+ query: Original user query
31
+ method: Causal inference method used (string name)
32
+ results: Numerical results from method_executor_tool
33
+ explanation: Structured explanation object from explainer_tool
34
+ dataset_analysis: Optional dictionary of dataset analysis results
35
+ dataset_description: Optional string description of the dataset
36
+
37
+ Returns:
38
+ Dict with formatted output fields ready for presentation.
39
+ """
40
+ # Extract numerical results
41
+ effect_estimate = results.get("effect_estimate")
42
+ confidence_interval = results.get("confidence_interval")
43
+ p_value = results.get("p_value")
44
+ effect_se = results.get("standard_error") # Get SE if available
45
+
46
+ # Format method name for readability
47
+ method_name_formatted = _format_method_name(method)
48
+
49
+ # Extract explanation components (assuming explainer returns structured dict again)
50
+ # If explainer returns single string, adjust this
51
+ method_explanation_text = explanation.get("method_explanation", "")
52
+ interpretation_guide = explanation.get("interpretation_guide", "")
53
+ limitations = explanation.get("limitations", [])
54
+ assumptions_discussion = explanation.get("assumptions", "") # Assuming key is 'assumptions'
55
+ practical_implications = explanation.get("practical_implications", "")
56
+ # Add back final_explanation_text if explainer provides it
57
+ # final_explanation_text = explanation.get("final_explanation_text")
58
+
59
+ # Create summary using numerical results
60
+ ci_text = ""
61
+ if confidence_interval and confidence_interval[0] is not None and confidence_interval[1] is not None:
62
+ ci_text = f" (95% CI: [{confidence_interval[0]:.4f}, {confidence_interval[1]:.4f}])"
63
+
64
+ p_value_text = f", p={p_value:.4f}" if p_value is not None else ""
65
+ effect_text = f"{effect_estimate:.4f}" if effect_estimate is not None else "N/A"
66
+
67
+ summary = (
68
+ f"Based on {method_name_formatted}, the estimated causal effect is {effect_text}"
69
+ f"{ci_text}{p_value_text}. {_create_effect_interpretation(effect_estimate, p_value)}"
70
+ f" See details below regarding assumptions and limitations."
71
+ )
72
+
73
+ # Assemble formatted output dictionary
74
+ results_dict = {
75
+ "query": query,
76
+ "method_used": method_name_formatted,
77
+ "causal_effect": effect_estimate,
78
+ "standard_error": effect_se,
79
+ "confidence_interval": confidence_interval,
80
+ "p_value": p_value,
81
+ "summary": summary,
82
+ "method_explanation": method_explanation_text,
83
+ "interpretation_guide": interpretation_guide,
84
+ "limitations": limitations,
85
+ "assumptions": assumptions_discussion,
86
+ "practical_implications": practical_implications,
87
+ # "full_explanation_text": final_explanation_text # Optionally include combined text
88
+ }
89
+ final_results_dict = {key: results_dict[key] for key in ("query", "method_used", "causal_effect", "standard_error", "confidence_interval")}
90
+ # print(final_results_dict)
91
+
92
+ # Validate and instantiate the Pydantic model
93
+ try:
94
+ formatted_output_model = FormattedOutput(**results_dict)
95
+ except Exception as e: # Catch validation errors specifically if needed
96
+ # Handle validation error - perhaps log and return a default or raise
97
+ print(f"Error creating FormattedOutput model: {e}") # Or use logger
98
+ # Decide on error handling: raise, return None, return default?
99
+ # For now, re-raising might be simplest if the structure is expected
100
+ raise ValueError(f"Failed to create FormattedOutput from results: {e}")
101
+
102
+ return formatted_output_model # Return the Pydantic model instance
103
+
104
+
105
+ def _format_method_name(method: str) -> str:
106
+ """Format method name for readability."""
107
+ method_names = {
108
+ "propensity_score_matching": "Propensity Score Matching",
109
+ "regression_adjustment": "Regression Adjustment",
110
+ "instrumental_variable": "Instrumental Variable Analysis",
111
+ "difference_in_differences": "Difference-in-Differences",
112
+ "regression_discontinuity": "Regression Discontinuity Design",
113
+ "backdoor_adjustment": "Backdoor Adjustment",
114
+ "propensity_score_weighting": "Propensity Score Weighting"
115
+ }
116
+ return method_names.get(method, method.replace("_", " ").title())
117
+
118
+ # Reinstate helper function for interpretation
119
+ def _create_effect_interpretation(effect: Optional[float], p_value: Optional[float] = None) -> str:
120
+ """Create a basic interpretation of the effect."""
121
+ if effect is None:
122
+ return "Effect estimate not available."
123
+
124
+ significance = ""
125
+ if p_value is not None:
126
+ significance = "statistically significant" if p_value < 0.05 else "not statistically significant"
127
+
128
+ magnitude = ""
129
+ if abs(effect) < 0.01:
130
+ magnitude = "no practical effect"
131
+ elif abs(effect) < 0.1:
132
+ magnitude = "a small effect"
133
+ elif abs(effect) < 0.5:
134
+ magnitude = "a moderate effect"
135
+ else:
136
+ magnitude = "a substantial effect"
137
+
138
+ return f"This suggests {magnitude}{f' and is {significance}' if significance else ''}."
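The summary string is assembled from `_format_method_name` and `_create_effect_interpretation`; both are pure helpers, so they can be sanity-checked on their own (values below are illustrative, assuming the package is importable):

from auto_causal.components.output_formatter import (
    _create_effect_interpretation,
    _format_method_name,
)

print(_format_method_name("difference_in_differences"))
# -> Difference-in-Differences
print(_format_method_name("double_machine_learning"))
# -> Double Machine Learning  (unknown names fall back to title-casing)
print(_create_effect_interpretation(0.03, p_value=0.21))
# -> This suggests a small effect and is not statistically significant.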
auto_causal/components/query_interpreter.py ADDED
@@ -0,0 +1,580 @@
1
+ """
2
+ Query interpreter component for causal inference.
3
+
4
+ This module provides functionality to match query concepts to actual dataset variables,
5
+ identifying treatment, outcome, and covariate variables for causal inference analysis.
6
+ """
7
+
8
+ import re
9
+ from typing import Dict, List, Any, Optional, Union, Tuple
10
+ import pandas as pd
11
+ import logging
12
+ import numpy as np
13
+ from auto_causal.config import get_llm_client
14
+ # Import LLM and message types
15
+ from langchain_core.language_models import BaseChatModel
16
+ from langchain_core.messages import HumanMessage
17
+ from langchain_core.exceptions import OutputParserException
18
+ # Import base Pydantic models needed directly
19
+ from pydantic import BaseModel, ValidationError
20
+ from dowhy import CausalModel
21
+ import json
22
+
23
+ # Import shared Pydantic models from the central location
24
+ from auto_causal.models import (
25
+ LLMSelectedVariable,
26
+ LLMSelectedCovariates,
27
+ LLMIVars,
28
+ LLMRDDVars,
29
+ LLMRCTCheck,
30
+ LLMTreatmentReferenceLevel,
31
+ LLMInteractionSuggestion,
32
+ LLMEstimand,
33
+ # LLMDIDCheck,
34
+ # LLMDiDTemporalVars,
35
+ # LLMDiDGroupVars,
36
+ # LLMRDDCheck,
37
+ # LLMRDDVarsExtended
38
+ )
39
+
40
+ # Import the new prompt templates
41
+ from auto_causal.prompts.method_identification_prompts import (
42
+ IV_IDENTIFICATION_PROMPT_TEMPLATE,
43
+ RDD_IDENTIFICATION_PROMPT_TEMPLATE,
44
+ RCT_IDENTIFICATION_PROMPT_TEMPLATE,
45
+ TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE,
46
+ INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE,
47
+ TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE,
48
+ OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE,
49
+ COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE,
50
+ ESTIMAND_PROMPT_TEMPLATE,
51
+ CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE,
52
+ DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE)
53
+
54
+
55
+ # Assume central models are defined elsewhere or keep local definitions for now
56
+ # from ..models import ...
57
+
58
+ # --- Pydantic models for LLM structured output ---
59
+ # REMOVED - Now defined in causalscientist/auto_causal/models.py
60
+ # class LLMSelectedVariable(BaseModel): ...
61
+ # class LLMSelectedCovariates(BaseModel): ...
62
+ # class LLMIVars(BaseModel): ...
63
+ # class LLMRDDVars(BaseModel): ...
64
+ # class LLMRCTCheck(BaseModel): ...
65
+
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+ def infer_treatment_variable_type(treatment_variable: str, column_categories: Dict[str, str],
70
+ dataset_analysis: Dict[str, Any]) -> str:
71
+ """
72
+ Determine treatment variable type from column category and unique value count
73
+ Args:
74
+ treatment_variable: name of the treatment variable
75
+ column_categories: mapping of column names to their categories
76
+ dataset_analysis: exploratory analysis results
77
+
78
+ Returns:
79
+ str: type of the treatment variable (e.g., "binary", "continuous", "discrete_multi_value")
80
+ """
81
+
82
+ treatment_variable_type = "unknown"
83
+ if treatment_variable and treatment_variable in column_categories:
84
+ category = column_categories[treatment_variable]
85
+ logger.info(f"Category for treatment '{treatment_variable}' is '{category}'.")
86
+
87
+ if category == "continuous_numeric":
88
+ treatment_variable_type = "continuous"
89
+
90
+ elif category == "discrete_numeric":
91
+ num_unique = dataset_analysis.get("column_nunique_counts", {}).get(treatment_variable, -1)
92
+ if num_unique > 10:
93
+ logger.info(f"'{treatment_variable}' has {num_unique} unique values, treating as continuous.")
94
+ treatment_variable_type = "continuous"
95
+ elif num_unique == 2:
96
+ logger.info(f"'{treatment_variable}' has 2 unique values, treating as binary.")
97
+ treatment_variable_type = "binary"
98
+ elif num_unique > 0:
99
+ logger.info(f"'{treatment_variable}' has {num_unique} unique values, treating as discrete_multi_value.")
100
+ treatment_variable_type = "discrete_multi_value"
101
+ else:
102
+ logger.info(f"'{treatment_variable}' unique value count unknown or too few.")
103
+ treatment_variable_type = "discrete_numeric_unknown_cardinality"
104
+
105
+ elif category in ["binary", "binary_categorical"]:
106
+ treatment_variable_type = "binary"
107
+
108
+ elif category in ["categorical", "categorical_numeric"]:
109
+ num_unique = dataset_analysis.get("column_nunique_counts", {}).get(treatment_variable, -1)
110
+ if num_unique == 2:
111
+ treatment_variable_type = "binary"
112
+ elif num_unique > 0:
113
+ treatment_variable_type = "categorical_multi_value"
114
+ else:
115
+ treatment_variable_type = "categorical_unknown_cardinality"
116
+
117
+ else:
118
+ logger.warning(f"Unmapped category '{category}' for '{treatment_variable}', setting as 'other'.")
119
+ treatment_variable_type = "other"
120
+
121
+ elif treatment_variable:
122
+ logger.warning(f"'{treatment_variable}' not found in column_categories.")
123
+ else:
124
+ logger.info("No treatment variable identified.")
125
+
126
+ logger.info(f"Final Determined Treatment Variable Type: {treatment_variable_type}")
127
+ return treatment_variable_type
128
+
129
+ def determine_treatment_reference_level(is_rct: Optional[bool], llm: Optional[BaseChatModel], treatment_variable: Optional[str],
130
+ query_text: str, dataset_description: Optional[str], file_path: Optional[str],
131
+ columns: List[str]) -> Optional[str]:
132
+ """
133
+ Determines the treatment reference level
134
+ """
135
+
136
+ # If LLM didn't explicitly say RCT, default to False or keep None?
137
+ # Let's default to False if LLM didn't provide a boolean value.
138
+ if is_rct is None: is_rct = False
139
+ treatment_reference_level = None
140
+
141
+ if llm and treatment_variable and treatment_variable in columns:
142
+ treatment_values_sample = []
143
+ if file_path:
144
+ try:
145
+ df = pd.read_csv(file_path)
146
+ if treatment_variable in df.columns:
147
+ unique_vals = df[treatment_variable].unique()
148
+ treatment_values_sample = [item.item() if hasattr(item, 'item') else item for item in unique_vals][:10]
149
+ if treatment_values_sample:
150
+ logger.info(f"Successfully read treatment values sample from dataset at '{file_path}' for variable '{treatment_variable}'.")
151
+ else:
152
+ logger.info(f"'{treatment_variable}' in '{file_path}' has no unique values or is empty.")
153
+ else:
154
+ logger.warning(f"'{treatment_variable}' not found in dataset columns at '{file_path}'.")
155
+ except FileNotFoundError:
156
+ logger.warning(f"File not found at: {file_path}")
157
+ except pd.errors.EmptyDataError:
158
+ logger.warning(f"Empty file at: {file_path}")
159
+ except Exception as e:
160
+ logger.warning(f"Error reading dataset at '{file_path}' for '{treatment_variable}': {e}")
161
+
162
+ if not treatment_values_sample:
163
+ logger.warning(f"No unique values found for treatment '{treatment_variable}'. LLM prompt will receive empty list.")
164
+ else:
165
+ logger.info(f"Final treatment values sample: {treatment_values_sample}")
166
+
167
+ try:
168
+ prompt = TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description or 'N/A', treatment_variable=treatment_variable, treatment_variable_values=treatment_values_sample)
169
+ ref_result = _call_llm_for_var(llm, prompt, LLMTreatmentReferenceLevel)
170
+ if ref_result and ref_result.reference_level:
171
+ if treatment_values_sample and ref_result.reference_level not in treatment_values_sample:
172
+ logger.warning(f"LLM reference level '{ref_result.reference_level}' not in sampled values for '{treatment_variable}'.")
173
+ treatment_reference_level = ref_result.reference_level
174
+ logger.info(f"LLM identified reference level: {treatment_reference_level} (Reason: {ref_result.reasoning})")
175
+ elif ref_result:
176
+ logger.info(f"LLM returned no reference level (Reason: {ref_result.reasoning})")
177
+ except Exception as e:
178
+ logger.error(f"LLM error for treatment reference level: {e}")
179
+
180
+ return treatment_reference_level
181
+
182
+ def identify_interaction_term(llm: Optional[BaseChatModel], treatment_variable: Optional[str], covariates: List[str],
183
+ column_categories: Dict[str, str], query_text: str,
184
+ dataset_description: Optional[str]) -> Tuple[bool, Optional[str]]:
185
+ """
186
+ Identifies the interaction term based on the query and the dataset information
187
+ """
188
+
189
+ interaction_term_suggested, interaction_variable_candidate = False, None
190
+
191
+ if llm and treatment_variable and covariates:
192
+ try:
193
+ covariates_list_str = "\n".join([f"- {cov}: {column_categories.get(cov, 'Unknown')}" for cov in covariates]) or "No covariates identified or available."
194
+ prompt = INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description or 'N/A', treatment_variable=treatment_variable, covariates_list_with_types=covariates_list_str)
195
+ result = _call_llm_for_var(llm, prompt, LLMInteractionSuggestion)
196
+ if result:
197
+ interaction_term_suggested = result.interaction_needed if result.interaction_needed is not None else False
198
+ if interaction_term_suggested and result.interaction_variable:
199
+ if result.interaction_variable in covariates:
200
+ interaction_variable_candidate = result.interaction_variable
201
+ logger.info(f"LLM suggested interaction: needed={interaction_term_suggested}, variable='{interaction_variable_candidate}' (Reason: {result.reasoning})")
202
+ else:
203
+ logger.warning(f"LLM suggested variable '{result.interaction_variable}' not in covariates {covariates}. Ignoring.")
204
+ interaction_term_suggested = False
205
+ elif interaction_term_suggested:
206
+ logger.info(f"LLM suggested interaction is needed but no variable provided (Reason: {result.reasoning})")
207
+ else:
208
+ logger.info(f"LLM suggested no interaction is needed (Reason: {result.reasoning})")
209
+ else:
210
+ logger.warning("LLM returned no result for interaction term suggestion.")
211
+ except Exception as e:
212
+ logger.error(f"LLM error during interaction term check: {e}")
213
+
214
+ return interaction_term_suggested, interaction_variable_candidate
215
+
216
+
217
+ def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any],
218
+ dataset_description: Optional[str] = None) -> Dict[str, Any]:
219
+ """
220
+ Interpret query using hybrid heuristic/LLM approach to identify variables.
221
+
222
+ Args:
223
+ query_info: Information extracted from the user's query (text, hints).
224
+ dataset_analysis: Information about the dataset structure (columns, types, etc.).
225
+ dataset_description: Optional textual description of the dataset.
226
+ Note: the LLM client is obtained internally via get_llm_client(); it is not passed as an argument.
227
+
228
+ Returns:
229
+ Dict containing identified variables (treatment, outcome, covariates, etc., and is_rct).
230
+ """
231
+
232
+ logger.info("Interpreting query with hybrid approach...")
233
+ llm = get_llm_client()
234
+
235
+ query_text = query_info.get("query_text", "")
236
+ columns = dataset_analysis.get("columns", [])
237
+ column_categories = dataset_analysis.get("column_categories", {})
238
+ file_path = dataset_analysis["dataset_info"]["file_path"]
239
+
240
+
241
+ # --- Identify Treatment ---
242
+ treatment_hints = query_info.get("potential_treatments", [])
243
+ dataset_treatments = dataset_analysis.get("potential_treatments", [])
244
+ treatment_variable = _identify_variable_hybrid(role="treatment", query_hints=treatment_hints,
245
+ dataset_suggestions=dataset_treatments, columns=columns,
246
+ column_categories=column_categories,
247
+ prioritize_types=["binary", "binary_categorical", "discrete_numeric","continuous_numeric"], # Prioritize binary/discrete
248
+ query_text=query_text, dataset_description=dataset_description,llm=llm)
249
+ logger.info(f"Identified Treatment: {treatment_variable}")
250
+ treatment_variable_type = infer_treatment_variable_type(treatment_variable, column_categories, dataset_analysis)
251
+
252
+
253
+ # --- Identify Outcome ---
254
+ outcome_hints = query_info.get("outcome_hints", [])
255
+ dataset_outcomes = dataset_analysis.get("potential_outcomes", [])
256
+ outcome_variable = _identify_variable_hybrid(role="outcome", query_hints=outcome_hints, dataset_suggestions=dataset_outcomes,
257
+ columns=columns, column_categories=column_categories,
258
+ prioritize_types=["continuous_numeric", "discrete_numeric"], # Prioritize numeric
259
+ exclude_vars=[treatment_variable], # Exclude treatment
260
+ query_text=query_text, dataset_description=dataset_description, llm=llm)
261
+ logger.info(f"Identified Outcome: {outcome_variable}")
262
+
263
+ # --- Identify Covariates ---
264
+ covariate_hints = query_info.get("covariates_hints", [])
265
+ covariates = _identify_covariates_hybrid("covars", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
266
+ columns=columns, column_categories=column_categories, query_hints=covariate_hints,
267
+ query_text=query_text, dataset_description=dataset_description, llm=llm)
268
+ logger.info(f"Identified Covariates: {covariates}")
269
+
270
+ # --- Identify Confounders ---
271
+ confounder_hints = query_info.get("covariates_hints", [])
272
+ confounders = _identify_covariates_hybrid("confounders", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
273
+ columns=columns, column_categories=column_categories, query_hints=confounder_hints,
274
+ query_text=query_text, dataset_description=dataset_description, llm=llm)
275
+ logger.info(f"Identified Confounders: {confounders}")
276
+
277
+ # --- Identify Time/Group (from dataset analysis) ---
278
+ time_variable = None
279
+ group_variable = None
280
+ has_temporal = dataset_analysis.get("temporal_structure", {}).get("has_temporal_structure", False)
281
+ temporal_structure = dataset_analysis.get("temporal_structure", {})
282
+ if temporal_structure.get("has_temporal_structure", False):
283
+ time_variable = temporal_structure.get("time_column") or temporal_structure.get("temporal_columns", [None])[0]
284
+ if temporal_structure.get("is_panel_data", False):
285
+ group_variable = temporal_structure.get("id_column")
286
+ logger.info(f"Identified Time Var: {time_variable}, Group Var: {group_variable}, temporal structure: {temporal_structure}")
287
+
288
+ # --- Identify IV/RDD/RCT using LLM ---
289
+ instrument_variable = None
290
+ running_variable = None
291
+ cutoff_value = None
292
+ is_rct = None
293
+ smd_score = None
+ did_term_result = None
294
+
295
+ if llm:
296
+ try:
297
+ # Check for RCT
298
+ prompt_rct = _create_identify_prompt("whether data is from RCT", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
299
+ rct_result = _call_llm_for_var(llm, prompt_rct, LLMRCTCheck)
300
+ is_rct = rct_result.is_rct if rct_result else None
301
+ logger.info(f"LLM identified RCT: {is_rct}")
302
+
303
+ # Check for IV
304
+ prompt_iv = _create_identify_prompt("instrumental variable", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
305
+ iv_result = _call_llm_for_var(llm, prompt_iv, LLMIVars)
306
+ instrument_variable = iv_result.instrument_variable if iv_result else None
307
+ if instrument_variable not in columns:
308
+ instrument_variable = None
309
+ logger.info(f"LLM identified IV: {instrument_variable}")
310
+
311
+ # Check for RDD
312
+ prompt_rdd = _create_identify_prompt("regression discontinuity (running variable and cutoff)", query_text, dataset_description, columns, column_categories, treatment_variable, outcome_variable)
313
+ rdd_result = _call_llm_for_var(llm, prompt_rdd, LLMRDDVars)
314
+ if rdd_result:
315
+ running_variable = rdd_result.running_variable
316
+ cutoff_value = rdd_result.cutoff_value
317
+ if running_variable not in columns or cutoff_value is None:
318
+ running_variable = None
319
+ cutoff_value = None
320
+ logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}")
321
+
322
+ ## For graph based methods
323
+ exclude_cols = [treatment_variable, outcome_variable]
324
+ potential_covariates = [col for col in columns if col not in exclude_cols and col is not None]
325
+ usable_covariates = [col for col in potential_covariates if column_categories.get(col) not in ["text_or_other"]]
326
+ logger.info(f"Usable covariates for graph: {usable_covariates}")
327
+
328
+ estimand_prompt = ESTIMAND_PROMPT_TEMPLATE.format(query=query_text,dataset_description=dataset_description,
329
+ dataset_columns=usable_covariates,
330
+ treatment=treatment_variable, outcome=outcome_variable)
331
+
332
+ estimand_result = _call_llm_for_var(llm, estimand_prompt, LLMEstimand)
333
+ estimand = "ate" if "ate" in estimand_result.estimand.strip().lower() else "att"
334
+ logger.info(f"LLM identified estimand: {estimand}")
335
+
336
+ ## Did Term
337
+ did_term_prompt = DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
338
+ column_info=columns, time_variable=time_variable,
339
+ group_variable=group_variable, column_types=column_categories)
340
+ did_term_result = _call_llm_for_var(llm, did_term_prompt, LLMRDDVars)
341
+ did_term_result = did_term_result.did_term if (did_term_result and getattr(did_term_result, "did_term", None) in columns) else None
342
+ logger.info(f"LLM identified DiD term: {did_term_result}")
343
+
344
+
345
+
346
+ #smd_score_all = compute_smd(dataset_analysis.get("data", pd.DataFrame()), treatment_variable, usable_covariates)
347
+ #smd_score = smd_score_all.get("ate", 0.0) if smd_score_all else 0.0
348
+ #logger.info(f"Computed SMD score: {smd_score}")
349
+
350
+ #logger.debug(f"Computed SMD score for {estimand}: {smd_score}")
351
+
352
+
353
+ except Exception as e:
354
+ logger.error(f"Error during LLM checks for IV/RDD/RCT: {e}")
355
+
356
+
357
+
358
+ # --- Identify Treatment Reference Level ---
359
+ treatment_reference_level = determine_treatment_reference_level(is_rct=is_rct, llm=llm, treatment_variable=treatment_variable,
360
+ query_text=query_text, dataset_description=dataset_description,
361
+ file_path=file_path, columns=columns)
362
+
363
+ # --- Identify Interaction Term Suggestion ---
364
+ interaction_term_suggested, interaction_variable_candidate = identify_interaction_term(llm=llm, treatment_variable=treatment_variable,
365
+ covariates=covariates,
366
+ column_categories=column_categories, query_text=query_text,
367
+ dataset_description=dataset_description)
368
+
369
+
370
+ # --- Consolidate ---
371
+ return {
372
+ "treatment_variable": treatment_variable,
373
+ "treatment_variable_type": treatment_variable_type,
374
+ "outcome_variable": outcome_variable,
375
+ "covariates": covariates,
376
+ "time_variable": time_variable,
377
+ "group_variable": group_variable,
378
+ "instrument_variable": instrument_variable,
379
+ "running_variable": running_variable,
380
+ "cutoff_value": cutoff_value,
381
+ "is_rct": is_rct,
382
+ "treatment_reference_level": treatment_reference_level,
383
+ "interaction_term_suggested": interaction_term_suggested,
384
+ "interaction_variable_candidate": interaction_variable_candidate,
385
+ "confounders": confounders,
386
+ "did_term": did_term_result
387
+ }
388
+
389
+ def compute_smd(df: pd.DataFrame, treat, covars_list) -> Dict[str, float]:
390
+ """
391
+ Computes the standardized mean differences (SMD) between treated and control groups
392
+ Args:
393
+ df (pd.DataFrame): The dataset.
394
+ treat (str): Name of the binary treatment column (0/1).
395
+ covars_list (List[str]): List of covariate names to consider for SMD calculation
396
+
397
+ Returns:
398
+ Dict[str, float]: average absolute SMD under "ate" (pooled SD) and "att" (treated-group SD) scaling
399
+ """
400
+ logger.info(f"Computing SMD for treatment variable '{treat}' with covariates: {covars_list}")
401
+ df_t = df[df[treat] == 1]
402
+ df_c = df[df[treat] == 0]
403
+
404
+ covariates = covars_list if covars_list else df.columns.tolist()
405
+ smd_ate = np.full(len(covariates), np.nan)
406
+ smd_att = np.full(len(covariates), np.nan)
407
+
408
+ for i, col in enumerate(covariates):
409
+ try:
410
+ m_t, m_c = df_t[col].mean(), df_c[col].mean()
411
+ s_t, s_c = df_t[col].std(ddof=0), df_c[col].std(ddof=0)
412
+ pooled = np.sqrt((s_t**2 + s_c**2) / 2)
413
+
414
+ ate_val = 0.0 if pooled == 0 else (m_t - m_c) / pooled
415
+ att_val = 0.0 if s_t == 0 else (m_t - m_c) / s_t
416
+
417
+ smd_ate[i] = ate_val
418
+ smd_att[i] = att_val
419
+ except Exception as e:
420
+ logger.warning(f"SMD computation failed for column '{col}': {e}")
421
+ continue
422
+
423
+ avg_ate = np.nanmean(np.abs(smd_ate))
424
+ avg_att = np.nanmean(np.abs(smd_att))
425
+
426
+ return {"ate":avg_ate, "att":avg_att}
427
+
428
+
429
+
430
+ # --- Helper Functions for Hybrid Identification ---
431
+ def _identify_variable_hybrid(role: str, query_hints: List[str], dataset_suggestions: List[str],
432
+ columns: List[str], column_categories: Dict[str, str],
433
+ prioritize_types: List[str], query_text: str,
434
+ dataset_description: Optional[str],llm: Optional[BaseChatModel],
435
+ exclude_vars: Optional[List[str]] = None) -> Optional[str]:
436
+ """
437
+ Identifies a variable from the available information by prompting the LLM. In case of failure,
438
+ it falls back to a programmatic selection (heuristics).
439
+
440
+ Args:
441
+ role: variable type (treatment or outcome)
442
+ query_hints: hints from the query for this variable
443
+ dataset_suggestions: dataset-specific suggestions for this variable
444
+ columns: list of available columns in the dataset
445
+ column_categories: mapping of column names to their categories
446
+ prioritize_types: types to prioritize for this variable
447
+ query_text: the original query text
448
+ dataset_description: description of the dataset
449
+ llm: language model
450
+ exclude_vars: list of variables to exclude from selection (e.g., treatment for outcome)
451
+ Returns:
452
+ str: name of the identified variable, or None if not found
453
+ """
454
+
455
+ candidates = set()
456
+ available_columns = [c for c in columns if c not in (exclude_vars or [])]
457
+ if not available_columns: return None
458
+
459
+ # 1. Exact matches from hints
460
+ for hint in query_hints:
461
+ if hint in available_columns:
462
+ candidates.add(hint)
463
+ # 2. Add dataset suggestions
464
+ for sugg in dataset_suggestions:
465
+ if sugg in available_columns:
466
+ candidates.add(sugg)
467
+
468
+ # 3. Programmatic Filtering based on type
469
+ plausible_candidates = [c for c in candidates if column_categories.get(c) in prioritize_types]
470
+
471
+ if llm:
472
+ if role == "treatment":
473
+ prompt_template = TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE
474
+ elif role == "outcome":
475
+ prompt_template = OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE
476
+ else:
477
+ raise ValueError(f"Unsupported role for LLM variable identification: {role}")
478
+
479
+ prompt = prompt_template.format(query=query_text, description=dataset_description,
480
+ column_info=available_columns)
481
+ llm_choice = _call_llm_for_var(llm, prompt, LLMSelectedVariable)
482
+
483
+ if llm_choice and llm_choice.variable_name in available_columns:
484
+ logger.info(f"LLM selected {role}: {llm_choice.variable_name}")
485
+ return llm_choice.variable_name
486
+ else:
487
+ fallback = plausible_candidates[0] if plausible_candidates else None
488
+ logger.warning(f"LLM failed to select valid {role}. Falling back to: {fallback}")
489
+ return fallback
490
+
491
+ if plausible_candidates:
492
+ logger.info(f"No LLM provided. Using first plausible {role}: {plausible_candidates[0]}")
493
+ return plausible_candidates[0]
494
+
495
+ logger.warning(f"No plausible candidates for {role}. Cannot identify variable.")
496
+ return None
497
+
498
+
499
+ def _identify_covariates_hybrid(role, treatment_variable: Optional[str], outcome_variable: Optional[str],
500
+ columns: List[str], column_categories: Dict[str, str], query_hints: List[str],
501
+ query_text: str, dataset_description: Optional[str], llm: Optional[BaseChatModel]) -> List[str]:
502
+ """
503
+ Prompts an LLM to identify the covariates
504
+ """
505
+
506
+ # 1. Initial Programmatic Filtering
507
+ exclude_cols = [treatment_variable, outcome_variable]
508
+ potential_covariates = [col for col in columns if col not in exclude_cols and col is not None]
509
+
510
+ # Filter out unusable types
511
+ usable_covariates = [col for col in potential_covariates if column_categories.get(col) not in ["text_or_other"]]
512
+ logger.debug(f"Initial usable covariates: {usable_covariates}")
513
+
514
+ # 2. LLM Refinement (if LLM available)
515
+ if llm:
516
+ logger.info("Using LLM to refine covariate list...")
517
+ prompt = ""
518
+ if role == "covars":
519
+ prompt = COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
520
+ column_info=", ".join(usable_covariates),
521
+ treatment=treatment_variable, outcome=outcome_variable)
522
+ elif role == "confounders":
523
+ prompt = CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
524
+ column_info=", ".join(usable_covariates),
525
+ treatment=treatment_variable, outcome=outcome_variable)
526
+ llm_selection = _call_llm_for_var(llm, prompt, LLMSelectedCovariates)
527
+
528
+ if llm_selection and llm_selection.covariates:
529
+ # Validate LLM output against available columns
530
+ valid_llm_covs = [c for c in llm_selection.covariates if c in usable_covariates]
531
+ if len(valid_llm_covs) < len(llm_selection.covariates):
532
+ logger.warning("LLM suggested covariates not found in initial usable list.")
533
+ if valid_llm_covs: # Use LLM selection if it's valid and non-empty
534
+ logger.info(f"LLM refined covariates to: {valid_llm_covs}")
535
+ return valid_llm_covs[:10] # Cap at 10
536
+ else:
537
+ logger.warning("LLM refinement failed or returned empty/invalid list. Falling back.")
538
+ else:
539
+ logger.warning("LLM refinement call failed or returned no covariates. Falling back.")
540
+
541
+ # 3. Fallback to Programmatic List (Capped)
542
+ logger.info(f"Using programmatically determined covariates (capped at 10): {usable_covariates[:10]}")
543
+ return usable_covariates[:10]
544
+
545
+ def _create_identify_prompt(target: str, query: str, description: Optional[str], columns: List[str],
546
+ categories: Dict[str,str], treatment: Optional[str], outcome: Optional[str]) -> str:
547
+ """
548
+ Creates a prompt to ask LLM to identify specific roles like IV, RDD, or RCT by selecting and formatting a specific template
549
+ """
550
+ column_info = "\n".join([f"- '{c}' (Type: {categories.get(c, 'Unknown')})" for c in columns])
551
+
552
+ # Select the appropriate detailed prompt template based on the target
553
+ if "instrumental variable" in target.lower():
554
+ template = IV_IDENTIFICATION_PROMPT_TEMPLATE
555
+ elif "regression discontinuity" in target.lower():
556
+ template = RDD_IDENTIFICATION_PROMPT_TEMPLATE
557
+ elif "rct" in target.lower():
558
+ template = RCT_IDENTIFICATION_PROMPT_TEMPLATE
559
+ else:
560
+ # Fallback or error? For now, let's raise an error if target is unexpected.
561
+ logger.error(f"Unsupported target for _create_identify_prompt: {target}")
562
+ raise ValueError(f"Unsupported target for specific identification prompt: {target}")
563
+
564
+ # Format the selected template with the provided context
565
+ prompt = template.format(query=query, description=description or 'N/A', column_info=column_info,
566
+ treatment=treatment or 'N/A', outcome=outcome or 'N/A')
567
+ return prompt
568
+
569
+ def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
570
+ """Helper to call LLM with structured output and handle errors."""
571
+ try:
572
+ messages = [HumanMessage(content=prompt)]
573
+ structured_llm = llm.with_structured_output(pydantic_model)
574
+ parsed_result = structured_llm.invoke(messages)
575
+ return parsed_result
576
+ except (OutputParserException, ValidationError) as e:
577
+ logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {e}")
578
+ except Exception as e:
579
+ logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {e}", exc_info=True)
580
+ return None
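For reference, the quantity compute_smd averages is the standardized mean difference per covariate; a self-contained toy check of the same formula (pandas/numpy only, synthetic data, independent of the module above):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "treat": rng.integers(0, 2, size=500),
    "age": rng.normal(40, 10, size=500),
})
df.loc[df["treat"] == 1, "age"] += 3   # induce a mild imbalance on age

t = df.loc[df["treat"] == 1, "age"]
c = df.loc[df["treat"] == 0, "age"]
pooled_sd = np.sqrt((t.std(ddof=0) ** 2 + c.std(ddof=0) ** 2) / 2)

smd_ate = (t.mean() - c.mean()) / pooled_sd      # "ate" scaling: pooled SD
smd_att = (t.mean() - c.mean()) / t.std(ddof=0)  # "att" scaling: treated-group SD
print(round(smd_ate, 2), round(smd_att, 2))      # both come out around 0.3 here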
auto_causal/components/state_manager.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ State management utilities for the auto_causal workflow.
3
+
4
+ This module provides utility functions to create standardized state updates
5
+ for passing between tools in the auto_causal agent workflow.
6
+ """
7
+
8
+ from typing import Dict, Any, Optional
9
+
10
+ def create_workflow_state_update(
11
+ current_step: str,
12
+ step_completed_flag: bool,
13
+ next_tool: str,
14
+ next_step_reason: str,
15
+ error: Optional[str] = None
16
+ ) -> Dict[str, Any]:
17
+ """
18
+ Create a standardized workflow state update dictionary.
19
+
20
+ Args:
21
+ current_step: Current step in the workflow (e.g., "input_processing")
22
+ step_completed_flag: Boolean flag indicating whether the current step completed successfully
23
+ next_tool: Name of the next tool to call
24
+ next_step_reason: Reason message for the next step
25
+ error: Optional error message if the step failed
26
+
27
+ Returns:
28
+ Dictionary containing the workflow_state sub-dictionary
29
+ """
30
+ state_update = {
31
+ "workflow_state": {
32
+ "current_step": current_step,
33
+ current_step + "_completed": step_completed_flag,
34
+ "next_tool": next_tool,
35
+ "next_step_reason": next_step_reason
36
+ }
37
+ }
38
+ if error:
39
+ state_update["workflow_state"]["error_message"] = error
40
+ return state_update
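A call such as the following (argument values are illustrative) shows the shape of the returned update, including the dynamically named `<current_step>_completed` key:

from auto_causal.components.state_manager import create_workflow_state_update

update = create_workflow_state_update(
    current_step="input_processing",
    step_completed_flag=True,
    next_tool="dataset_analyzer_tool",
    next_step_reason="Query parsed; dataset needs profiling next",
)
# update == {
#     "workflow_state": {
#         "current_step": "input_processing",
#         "input_processing_completed": True,
#         "next_tool": "dataset_analyzer_tool",
#         "next_step_reason": "Query parsed; dataset needs profiling next",
#     }
# }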
auto_causal/config.py ADDED
@@ -0,0 +1,97 @@
1
+ # auto_causal/config.py
2
+ """Central configuration for AutoCausal, including LLM client setup."""
3
+
4
+ import os
5
+ import logging
6
+ from typing import Optional
7
+
8
+ # Langchain imports
9
+ from langchain_core.language_models import BaseChatModel
10
+ from langchain_openai import ChatOpenAI # Default
11
+ from langchain_anthropic import ChatAnthropic # Example
12
+ from langchain_google_genai import ChatGoogleGenerativeAI
13
+ # Add other providers if needed, e.g.:
14
+ # from langchain_community.chat_models import ChatOllama
15
+ from dotenv import load_dotenv
16
+ from langchain_deepseek import ChatDeepSeek
17
+ # Create a disk-backed SQLite cache:
18
+ # Import Together provider
19
+ from langchain_together import ChatTogether
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Load .env file when this module is loaded
24
+ load_dotenv()
25
+
26
+ def get_llm_client(provider: Optional[str] = None, model_name: Optional[str] = None, **kwargs) -> BaseChatModel:
27
+ """Initializes and returns the chosen LLM client based on provider.
28
+
29
+ Reads provider, model, and API keys from environment variables if not passed directly.
30
+ Defaults to OpenAI GPT-4o-mini if no provider/model specified.
31
+ """
32
+ # Prioritize arguments, then environment variables, then defaults
33
+ provider = provider or os.getenv("LLM_PROVIDER", "openai")
34
+ provider = provider.lower()
35
+
36
+ # Default model depends on provider
37
+ default_models = {
38
+ "openai": "gpt-4o-mini",
39
+ "anthropic": "claude-3-5-sonnet-latest",
40
+ "together": "deepseek-ai/DeepSeek-V3", # Default Together model
41
+ "gemini" : "gemini-2.5-flash",
42
+ "deepseek" : "deepseek-chat"
43
+ }
44
+
45
+ model_name = model_name or os.getenv("LLM_MODEL", default_models.get(provider, default_models["openai"]))
46
+
47
+ api_key = None
48
+ if model_name not in ['o3-mini', 'o3', 'o4-mini']:
49
+ kwargs.setdefault("temperature", 0) # Default temperature if not provided
50
+
51
+ logger.info(f"Initializing LLM client: Provider='{provider}', Model='{model_name}'")
52
+
53
+ try:
54
+ if provider == "openai":
55
+ api_key = os.getenv("OPENAI_API_KEY")
56
+ if not api_key:
57
+ raise ValueError("OPENAI_API_KEY not found in environment.")
58
+ return ChatOpenAI(model=model_name, api_key=api_key, **kwargs)
59
+
60
+ elif provider == "anthropic":
61
+ api_key = os.getenv("ANTHROPIC_API_KEY")
62
+ if not api_key:
63
+ raise ValueError("ANTHROPIC_API_KEY not found in environment.")
64
+ return ChatAnthropic(model=model_name, api_key=api_key, **kwargs, streaming=False)
65
+
66
+ elif provider == "together":
67
+ api_key = os.getenv("TOGETHER_API_KEY")
68
+ if not api_key:
69
+ raise ValueError("TOGETHER_API_KEY not found in environment.")
70
+ return ChatTogether(model=model_name, api_key=api_key, **kwargs)
71
+
72
+ elif provider == "gemini":
73
+ api_key = os.getenv("GEMINI_API_KEY")
74
+ if not api_key:
75
+ raise ValueError("GEMINI_API_KEY not found in environment.")
76
+ return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, **kwargs, function_calling="auto")
77
+
78
+ elif provider == "deepseek":
79
+ api_key = os.getenv("DEEPSEEK_API_KEY")
80
+ if not api_key:
81
+ raise ValueError("DEEPSEEK_API_KEY not found in environment.")
82
+ return ChatDeepSeek(model=model_name, **kwargs)
83
+
84
+ # Example for Ollama (ensure langchain_community is installed)
85
+ # elif provider == "ollama":
86
+ # try:
87
+ # from langchain_community.chat_models import ChatOllama
88
+ # return ChatOllama(model=model_name, **kwargs)
89
+ # except ImportError:
90
+ # raise ValueError("langchain_community needed for Ollama. Run `pip install langchain-community`")
91
+
92
+ else:
93
+ raise ValueError(f"Unsupported LLM provider: {provider}")
94
+
95
+ except Exception as e:
96
+ logger.error(f"Failed to initialize LLM (Provider: {provider}, Model: {model_name}): {e}")
97
+ raise # Re-raise the exception
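A minimal way to exercise the factory, assuming the matching API key is already set in the environment (the provider and model shown are just examples):

import os

from auto_causal.config import get_llm_client

os.environ.setdefault("LLM_PROVIDER", "openai")
os.environ.setdefault("LLM_MODEL", "gpt-4o-mini")
# OPENAI_API_KEY must already be set; otherwise get_llm_client raises ValueError.

llm = get_llm_client()  # reads provider/model from the environment
# llm = get_llm_client(provider="anthropic", model_name="claude-3-5-sonnet-latest")  # explicit override
print(llm.invoke("Reply with the single word OK").content)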
auto_causal/methods/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ """
2
+ Causal inference methods for the auto_causal module.
3
+
4
+ This package contains implementations of various causal inference methods
5
+ that can be selected and applied by the auto_causal pipeline.
6
+ """
7
+
8
+ from .causal_method import CausalMethod
9
+ from .propensity_score.matching import estimate_effect as psm_estimate_effect
10
+ from .propensity_score.weighting import estimate_effect as psw_estimate_effect
11
+ from .instrumental_variable.estimator import estimate_effect as iv_estimate_effect
12
+ from .difference_in_differences.estimator import estimate_effect as did_estimate_effect
13
+ from .diff_in_means.estimator import estimate_effect as dim_estimate_effect
14
+ from .linear_regression.estimator import estimate_effect as lr_estimate_effect
15
+ from .backdoor_adjustment.estimator import estimate_effect as ba_estimate_effect
16
+ from .regression_discontinuity.estimator import estimate_effect as rdd_estimate_effect
17
+ from .generalized_propensity_score.estimator import estimate_effect_gps
18
+
19
+ # Mapping of method names to their implementation functions
20
+ METHOD_MAPPING = {
21
+ "propensity_score_matching": psm_estimate_effect,
22
+ "propensity_score_weighting": psw_estimate_effect,
23
+ "instrumental_variable": iv_estimate_effect,
24
+ "difference_in_differences": did_estimate_effect,
25
+ "regression_discontinuity_design": rdd_estimate_effect,
26
+ "backdoor_adjustment": ba_estimate_effect,
27
+ "linear_regression": lr_estimate_effect,
28
+ "diff_in_means": dim_estimate_effect,
29
+ "generalized_propensity_score": estimate_effect_gps,
30
+ }
31
+
32
+ __all__ = [
33
+ "CausalMethod",
34
+ "psm_estimate_effect",
35
+ "psw_estimate_effect",
36
+ "iv_estimate_effect",
37
+ "did_estimate_effect",
38
+ "rdd_estimate_effect",
39
+ "dim_estimate_effect",
40
+ "lr_estimate_effect",
41
+ "ba_estimate_effect",
42
+ "METHOD_MAPPING",
43
+ "estimate_effect_gps",
44
+ ]
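METHOD_MAPPING is what lets downstream tools dispatch on a method name; a hedged sketch of that dispatch follows (the run_method wrapper and the example column names are illustrative, and the keyword arguments assume the df/treatment/outcome/covariates signature used by the estimators in this package):

import pandas as pd

from auto_causal.methods import METHOD_MAPPING


def run_method(name: str, df: pd.DataFrame, **kwargs):
    """Look up an estimator by name and run it, failing loudly on unknown names."""
    try:
        estimator = METHOD_MAPPING[name]
    except KeyError:
        raise ValueError(
            f"Unknown causal method: {name}. Available: {sorted(METHOD_MAPPING)}"
        )
    return estimator(df=df, **kwargs)

# e.g. run_method("backdoor_adjustment", df, treatment="treated",
#                 outcome="outcome", covariates=["age", "income"])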
auto_causal/methods/backdoor_adjustment/__init__.py ADDED
File without changes
auto_causal/methods/backdoor_adjustment/diagnostics.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ Diagnostic checks for Backdoor Adjustment models (typically OLS).
3
+ """
4
+
5
+ from typing import Dict, Any, List
6
+ import statsmodels.api as sm
7
+ from statsmodels.stats.diagnostic import het_breuschpagan
8
+ from statsmodels.stats.stattools import jarque_bera, durbin_watson
9
+ from statsmodels.regression.linear_model import RegressionResultsWrapper
10
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
11
+ import pandas as pd
12
+ import numpy as np
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ def run_backdoor_diagnostics(results: RegressionResultsWrapper, X: pd.DataFrame) -> Dict[str, Any]:
18
+ """
19
+ Runs diagnostic checks on a fitted OLS model used for backdoor adjustment.
20
+
21
+ Args:
22
+ results: A fitted statsmodels OLS results object.
23
+ X: The design matrix (including constant and all predictors) used.
24
+
25
+ Returns:
26
+ Dictionary containing diagnostic metrics.
27
+ """
28
+ diagnostics = {}
29
+ details = {}
30
+
31
+ try:
32
+ details['r_squared'] = results.rsquared
33
+ details['adj_r_squared'] = results.rsquared_adj
34
+ details['f_statistic'] = results.fvalue
35
+ details['f_p_value'] = results.f_pvalue
36
+ details['n_observations'] = int(results.nobs)
37
+ details['degrees_of_freedom_resid'] = int(results.df_resid)
38
+ details['durbin_watson'] = durbin_watson(results.resid) if results.nobs > 5 else 'N/A (Too few obs)' # Autocorrelation
39
+
40
+ # --- Normality of Residuals (Jarque-Bera) ---
41
+ try:
42
+ if results.nobs >= 2:
43
+ jb_value, jb_p_value, skew, kurtosis = jarque_bera(results.resid)
44
+ details['residuals_normality_jb_stat'] = jb_value
45
+ details['residuals_normality_jb_p_value'] = jb_p_value
46
+ details['residuals_skewness'] = skew
47
+ details['residuals_kurtosis'] = kurtosis
48
+ details['residuals_normality_status'] = "Normal" if jb_p_value > 0.05 else "Non-Normal"
49
+ else:
50
+ details['residuals_normality_status'] = "N/A (Too few obs)"
51
+ except Exception as e:
52
+ logger.warning(f"Could not run Jarque-Bera test: {e}")
53
+ details['residuals_normality_status'] = "Test Failed"
54
+
55
+ # --- Homoscedasticity (Breusch-Pagan) ---
56
+ try:
57
+ if X.shape[0] > X.shape[1]: # Needs more observations than predictors
58
+ lm_stat, lm_p_value, f_stat, f_p_value = het_breuschpagan(results.resid, X)
59
+ details['homoscedasticity_bp_lm_stat'] = lm_stat
60
+ details['homoscedasticity_bp_lm_p_value'] = lm_p_value
61
+ details['homoscedasticity_status'] = "Homoscedastic" if lm_p_value > 0.05 else "Heteroscedastic"
62
+ else:
63
+ details['homoscedasticity_status'] = "N/A (Too few obs or too many predictors)"
64
+ except Exception as e:
65
+ logger.warning(f"Could not run Breusch-Pagan test: {e}")
66
+ details['homoscedasticity_status'] = "Test Failed"
67
+
68
+ # --- Multicollinearity (VIF - Placeholder/Basic) ---
69
+ # Full VIF requires calculating for each predictor vs others.
70
+ # Providing a basic status based on condition number as a proxy.
71
+ try:
72
+ cond_no = np.linalg.cond(results.model.exog)
73
+ details['model_condition_number'] = cond_no
74
+ if cond_no > 30:
75
+ details['multicollinearity_status'] = "High (Cond. No. > 30)"
76
+ elif cond_no > 10:
77
+ details['multicollinearity_status'] = "Moderate (Cond. No. > 10)"
78
+ else:
79
+ details['multicollinearity_status'] = "Low"
80
+ except Exception as e:
81
+ logger.warning(f"Could not calculate condition number: {e}")
82
+ details['multicollinearity_status'] = "Check Failed"
83
+ # details['VIF'] = "Not Fully Implemented"
84
+
85
+ # --- Linearity (Still requires visual inspection) ---
86
+ details['linearity_check'] = "Requires visual inspection (e.g., residual vs fitted plot)"
87
+
88
+ return {"status": "Success", "details": details}
89
+
90
+ except Exception as e:
91
+ logger.error(f"Error running Backdoor Adjustment diagnostics: {e}")
92
+ return {"status": "Failed", "error": str(e), "details": details}
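A rough usage sketch of run_backdoor_diagnostics: it expects a fitted statsmodels OLS result plus the design matrix that produced it. The data and column names below are assumptions for illustration only.

import numpy as np
import pandas as pd
import statsmodels.api as sm
from auto_causal.methods.backdoor_adjustment.diagnostics import run_backdoor_diagnostics

rng = np.random.default_rng(0)
df = pd.DataFrame({"t": rng.integers(0, 2, 200), "x": rng.normal(size=200)})
df["y"] = 2.0 * df["t"] + 0.5 * df["x"] + rng.normal(size=200)

X = sm.add_constant(df[["t", "x"]])           # design matrix with intercept
results = sm.OLS(df["y"], X).fit()
diag = run_backdoor_diagnostics(results, X)   # {"status": ..., "details": {...}}
print(diag["status"], diag["details"]["homoscedasticity_status"])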
auto_causal/methods/backdoor_adjustment/estimator.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ Backdoor Adjustment Estimator using Regression.
3
+
4
+ Estimates the Average Treatment Effect (ATE) by regressing the outcome on the
5
+ treatment and a set of covariates assumed to satisfy the backdoor criterion.
6
+ """
7
+ import pandas as pd
8
+ import numpy as np
9
+ import statsmodels.api as sm
10
+ from typing import Dict, Any, List, Optional
11
+ import logging
12
+ from langchain.chat_models.base import BaseChatModel # For type hinting llm
13
+
14
+ # Import diagnostics and llm assist (placeholders for now)
15
+ from .diagnostics import run_backdoor_diagnostics
16
+ from .llm_assist import interpret_backdoor_results, identify_backdoor_set
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ def estimate_effect(
21
+ df: pd.DataFrame,
22
+ treatment: str,
23
+ outcome: str,
24
+ covariates: List[str], # Backdoor set - Required for this method
25
+ query: Optional[str] = None, # For potential LLM use
26
+ llm: Optional[BaseChatModel] = None, # For potential LLM use
27
+ **kwargs # To capture any other potential arguments
28
+ ) -> Dict[str, Any]:
29
+ """
30
+ Estimates the causal effect using Backdoor Adjustment (via OLS regression).
31
+
32
+ Assumes the provided `covariates` list satisfies the backdoor criterion.
33
+
34
+ Args:
35
+ df: Input DataFrame.
36
+ treatment: Name of the treatment variable column.
37
+ outcome: Name of the outcome variable column.
38
+ covariates: List of covariate names forming the backdoor adjustment set.
39
+ query: Optional user query for context (e.g., for LLM).
40
+ llm: Optional Language Model instance.
41
+ **kwargs: Additional keyword arguments.
42
+
43
+ Returns:
44
+ Dictionary containing estimation results:
45
+ - 'effect_estimate': The estimated coefficient for the treatment variable.
46
+ - 'p_value': The p-value associated with the treatment coefficient.
47
+ - 'confidence_interval': The 95% confidence interval for the effect.
48
+ - 'standard_error': The standard error of the treatment coefficient.
49
+ - 'formula': The regression formula used.
50
+ - 'model_summary': Summary object from statsmodels.
51
+ - 'diagnostics': Placeholder for diagnostic results.
52
+ - 'interpretation': LLM interpretation.
53
+ """
54
+ if not covariates: # Check if the list is empty or None
55
+ raise ValueError("Backdoor Adjustment requires a non-empty list of covariates (adjustment set).")
56
+
57
+ required_cols = [treatment, outcome] + covariates
58
+ missing_cols = [col for col in required_cols if col not in df.columns]
59
+ if missing_cols:
60
+ raise ValueError(f"Missing required columns for Backdoor Adjustment: {missing_cols}")
61
+
62
+ # Prepare data for statsmodels (add constant, handle potential NaNs)
63
+ df_analysis = df[required_cols].dropna()
64
+ if df_analysis.empty:
65
+ raise ValueError("No data remaining after dropping NaNs for required columns.")
66
+
67
+ X = df_analysis[[treatment] + covariates]
68
+ X = sm.add_constant(X) # Add intercept
69
+ y = df_analysis[outcome]
70
+
71
+ # Build the formula string for reporting
72
+ formula = f"{outcome} ~ {treatment} + " + " + ".join(covariates) + " + const"
73
+ logger.info(f"Running Backdoor Adjustment regression: {formula}")
74
+
75
+ try:
76
+ model = sm.OLS(y, X)
77
+ results = model.fit()
78
+
79
+ effect_estimate = results.params[treatment]
80
+ p_value = results.pvalues[treatment]
81
+ conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
82
+ std_err = results.bse[treatment]
83
+
84
+ # Run diagnostics (Placeholders)
85
+ # Pass the full design matrix X for potential VIF checks etc.
86
+ diag_results = run_backdoor_diagnostics(results, X)
87
+
88
+ # Get interpretation
89
+ interpretation = interpret_backdoor_results(results, diag_results, treatment, covariates, llm=llm)
90
+
91
+ return {
92
+ 'effect_estimate': effect_estimate,
93
+ 'p_value': p_value,
94
+ 'confidence_interval': conf_int,
95
+ 'standard_error': std_err,
96
+ 'formula': formula,
97
+ 'model_summary': results.summary(),
98
+ 'diagnostics': diag_results,
99
+ 'interpretation': interpretation,
100
+ 'method_used': 'Backdoor Adjustment (OLS)'
101
+ }
102
+
103
+ except Exception as e:
104
+ logger.error(f"Backdoor Adjustment failed: {e}")
105
+ raise
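A hedged end-to-end sketch of calling this estimator on simulated confounded data. The column names and effect size are made up; with llm omitted, the interpretation string falls back to its default.

import numpy as np
import pandas as pd
from auto_causal.methods.backdoor_adjustment.estimator import estimate_effect

rng = np.random.default_rng(1)
n = 500
conf = rng.normal(size=n)                            # single confounder
treat = (conf + rng.normal(size=n) > 0).astype(int)  # treatment depends on the confounder
y = 1.5 * treat + 2.0 * conf + rng.normal(size=n)
df = pd.DataFrame({"treat": treat, "conf": conf, "y": y})

res = estimate_effect(df, treatment="treat", outcome="y", covariates=["conf"])
print(res["effect_estimate"], res["confidence_interval"])   # should recover roughly 1.5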
auto_causal/methods/backdoor_adjustment/llm_assist.py ADDED
@@ -0,0 +1,176 @@
1
+ """
2
+ LLM assistance functions for Backdoor Adjustment analysis.
3
+ """
4
+
5
+ from typing import List, Dict, Any, Optional
6
+ import logging
7
+
8
+ # Imported for type hinting
9
+ from langchain.chat_models.base import BaseChatModel
10
+ from statsmodels.regression.linear_model import RegressionResultsWrapper
11
+
12
+ # Import shared LLM helpers
13
+ from auto_causal.utils.llm_helpers import call_llm_with_json_output
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ def identify_backdoor_set(
18
+ df_cols: List[str],
19
+ treatment: str,
20
+ outcome: str,
21
+ query: Optional[str] = None,
22
+ existing_covariates: Optional[List[str]] = None, # Allow user to provide some
23
+ llm: Optional[BaseChatModel] = None
24
+ ) -> List[str]:
25
+ """
26
+ Use LLM to suggest a potential backdoor adjustment set (confounders).
27
+
28
+ Tries to identify variables that affect both treatment and outcome.
29
+
30
+ Args:
31
+ df_cols: List of available column names in the dataset.
32
+ treatment: Treatment variable name.
33
+ outcome: Outcome variable name.
34
+ query: User's causal query text (provides context).
35
+ existing_covariates: Covariates already considered/provided by user.
36
+ llm: Optional LLM model instance.
37
+
38
+ Returns:
39
+ List of suggested variable names for the backdoor adjustment set.
40
+ """
41
+ if llm is None:
42
+ logger.warning("No LLM provided for backdoor set identification.")
43
+ return existing_covariates or []
44
+
45
+ # Exclude treatment and outcome from potential confounders
46
+ potential_confounders = [c for c in df_cols if c not in [treatment, outcome]]
47
+ if not potential_confounders:
48
+ return existing_covariates or []
49
+
50
+ prompt = f"""
51
+ You are assisting with identifying a backdoor adjustment set for causal inference.
52
+ The goal is to find observed variables that confound the relationship between the treatment and outcome.
53
+ Assume the causal effect of '{treatment}' on '{outcome}' is of interest.
54
+
55
+ User query context (optional): {query}
56
+ Available variables in the dataset (excluding treatment and outcome): {potential_confounders}
57
+ Variables already specified as covariates by user (if any): {existing_covariates}
58
+
59
+ Based *only* on the variable names and the query context, identify which of the available variables are likely to be common causes (confounders) of both '{treatment}' and '{outcome}'.
60
+ These variables should be included in the backdoor adjustment set.
61
+ Consider variables that likely occurred *before* or *at the same time as* the treatment.
62
+
63
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
64
+ {{
65
+ "suggested_backdoor_set": ["confounder1", "confounder2", ...]
66
+ }}
67
+ Include variables from the user-provided list if they seem appropriate as confounders.
68
+ If no plausible confounders are identified among the available variables, return an empty list.
69
+ """
70
+
71
+ response = call_llm_with_json_output(llm, prompt)
72
+
73
+ suggested_set = []
74
+ if response and "suggested_backdoor_set" in response and isinstance(response["suggested_backdoor_set"], list):
75
+ # Basic validation
76
+ valid_vars = [item for item in response["suggested_backdoor_set"] if isinstance(item, str)]
77
+ if len(valid_vars) != len(response["suggested_backdoor_set"]):
78
+ logger.warning("LLM returned non-string items in suggested_backdoor_set list.")
79
+ suggested_set = valid_vars
80
+ else:
81
+ logger.warning(f"Failed to get valid backdoor set recommendations from LLM. Response: {response}")
82
+
83
+ # Combine with existing covariates, removing duplicates
84
+ final_set = list(dict.fromkeys((existing_covariates or []) + suggested_set))
85
+ return final_set
86
+
87
+ def interpret_backdoor_results(
88
+ results: RegressionResultsWrapper,
89
+ diagnostics: Dict[str, Any],
90
+ treatment_var: str,
91
+ covariates: List[str],
92
+ llm: Optional[BaseChatModel] = None
93
+ ) -> str:
94
+ """
95
+ Use LLM to interpret Backdoor Adjustment results.
96
+
97
+ Args:
98
+ results: Fitted statsmodels OLS results object.
99
+ diagnostics: Dictionary of diagnostic results.
100
+ treatment_var: Name of the treatment variable.
101
+ covariates: List of covariates used in the adjustment set.
102
+ llm: Optional LLM model instance.
103
+
104
+ Returns:
105
+ String containing natural language interpretation.
106
+ """
107
+ default_interpretation = "LLM interpretation not available for Backdoor Adjustment."
108
+ if llm is None:
109
+ logger.info("LLM not provided for Backdoor Adjustment interpretation.")
110
+ return default_interpretation
111
+
112
+ try:
113
+ # --- Prepare summary for LLM ---
114
+ results_summary = {}
115
+ diag_details = diagnostics.get('details', {})
116
+
117
+ effect = results.params.get(treatment_var)
118
+ pval = results.pvalues.get(treatment_var)
119
+
120
+ results_summary['Treatment Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
121
+ results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
122
+ try:
123
+ conf_int = results.conf_int().loc[treatment_var]
124
+ results_summary['95% Confidence Interval'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
125
+ except KeyError:
126
+ results_summary['95% Confidence Interval'] = "Not Found"
127
+ except Exception as ci_e:
128
+ results_summary['95% Confidence Interval'] = f"Error ({ci_e})"
129
+
130
+ results_summary['Adjustment Set (Covariates Used)'] = covariates
131
+ r_sq = diag_details.get('r_squared')
+ results_summary['Model R-squared'] = f"{r_sq:.3f}" if isinstance(r_sq, (int, float)) else "N/A"
132
+
133
+ diag_summary = {}
134
+ if diagnostics.get("status") == "Success":
135
+ diag_summary['Residuals Normality Status'] = diag_details.get('residuals_normality_status', 'N/A')
136
+ diag_summary['Homoscedasticity Status'] = diag_details.get('homoscedasticity_status', 'N/A')
137
+ diag_summary['Multicollinearity Status'] = diag_details.get('multicollinearity_status', 'N/A')
138
+ else:
139
+ diag_summary['Status'] = diagnostics.get("status", "Unknown")
140
+
141
+ # --- Construct Prompt ---
142
+ prompt = f"""
143
+ You are assisting with interpreting Backdoor Adjustment (Regression) results.
144
+ The key assumption is that the specified adjustment set (covariates) blocks all confounding paths between the treatment ('{treatment_var}') and outcome.
145
+
146
+ Results Summary:
147
+ {results_summary}
148
+
149
+ Diagnostics Summary (OLS model checks):
150
+ {diag_summary}
151
+
152
+ Explain these results in 2-4 concise sentences. Focus on:
153
+ 1. The estimated average treatment effect after adjusting for the specified covariates (magnitude, direction, statistical significance based on p-value < 0.05).
154
+ 2. **Crucially, mention that this estimate relies heavily on the assumption that the included covariates ('{str(covariates)[:100]}...') are sufficient to control for confounding (i.e., satisfy the backdoor criterion).**
155
+ 3. Briefly mention any major OLS diagnostic issues noted (e.g., non-normal residuals, heteroscedasticity, high multicollinearity).
156
+
157
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
158
+ {{
159
+ "interpretation": "<your concise interpretation text>"
160
+ }}
161
+ """
162
+
163
+ # --- Call LLM ---
164
+ response = call_llm_with_json_output(llm, prompt)
165
+
166
+ # --- Process Response ---
167
+ if response and isinstance(response, dict) and \
168
+ "interpretation" in response and isinstance(response["interpretation"], str):
169
+ return response["interpretation"]
170
+ else:
171
+ logger.warning(f"Failed to get valid interpretation from LLM for Backdoor Adj. Response: {response}")
172
+ return default_interpretation
173
+
174
+ except Exception as e:
175
+ logger.error(f"Error during LLM interpretation for Backdoor Adj: {e}")
176
+ return f"Error generating interpretation: {e}"
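Worth noting from the code above: when no LLM instance is supplied, identify_backdoor_set degrades gracefully and simply echoes the user-supplied covariates. A tiny illustration with hypothetical column names:

from auto_causal.methods.backdoor_adjustment.llm_assist import identify_backdoor_set

cols = ["age", "income", "treat", "y"]
suggested = identify_backdoor_set(cols, treatment="treat", outcome="y",
                                  existing_covariates=["age"], llm=None)
print(suggested)   # -> ["age"]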
auto_causal/methods/causal_method.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Abstract base class for all causal inference methods.
3
+
4
+ This module defines the interface that all causal inference methods
5
+ must implement, ensuring consistent behavior across different methods.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Dict, List, Any
10
+ import pandas as pd
11
+
12
+
13
+ class CausalMethod(ABC):
14
+ """Base class for all causal inference methods.
15
+
16
+ This abstract class defines the required methods that all causal
17
+ inference implementations must provide. It ensures a consistent
18
+ interface across different methods like propensity score matching,
19
+ instrumental variables, etc.
20
+
21
+ Each implementation should handle the specifics of the causal
22
+ inference method while conforming to this interface.
23
+ """
24
+
25
+ @abstractmethod
26
+ def validate_assumptions(self, df: pd.DataFrame, treatment: str,
27
+ outcome: str, covariates: List[str]) -> Dict[str, Any]:
28
+ """Validate method assumptions against the dataset.
29
+
30
+ Args:
31
+ df: DataFrame containing the dataset
32
+ treatment: Name of the treatment variable column
33
+ outcome: Name of the outcome variable column
34
+ covariates: List of covariate column names
35
+
36
+ Returns:
37
+ Dict containing validation results with keys:
38
+ - assumptions_valid (bool): Whether all assumptions are met
39
+ - failed_assumptions (List[str]): List of failed assumptions
40
+ - warnings (List[str]): List of warnings
41
+ - suggestions (List[str]): Suggestions for addressing issues
42
+ """
43
+ pass
44
+
45
+ @abstractmethod
46
+ def estimate_effect(self, df: pd.DataFrame, treatment: str,
47
+ outcome: str, covariates: List[str]) -> Dict[str, Any]:
48
+ """Estimate causal effect using this method.
49
+
50
+ Args:
51
+ df: DataFrame containing the dataset
52
+ treatment: Name of the treatment variable column
53
+ outcome: Name of the outcome variable column
54
+ covariates: List of covariate column names
55
+
56
+ Returns:
57
+ Dict containing estimation results with keys:
58
+ - effect_estimate (float): Estimated causal effect
59
+ - confidence_interval (tuple): Confidence interval (lower, upper)
60
+ - p_value (float): P-value of the estimate
61
+ - additional_metrics (Dict): Any method-specific metrics
62
+ """
63
+ pass
64
+
65
+ @abstractmethod
66
+ def generate_code(self, dataset_path: str, treatment: str,
67
+ outcome: str, covariates: List[str]) -> str:
68
+ """Generate executable code for this causal method.
69
+
70
+ Args:
71
+ dataset_path: Path to the dataset file
72
+ treatment: Name of the treatment variable column
73
+ outcome: Name of the outcome variable column
74
+ covariates: List of covariate column names
75
+
76
+ Returns:
77
+ String containing executable Python code implementing this method
78
+ """
79
+ pass
80
+
81
+ @abstractmethod
82
+ def explain(self) -> str:
83
+ """Explain this causal method, its assumptions, and when to use it.
84
+
85
+ Returns:
86
+ String with detailed explanation of the method
87
+ """
88
+ pass
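To make the interface concrete, here is a hypothetical minimal subclass, not part of this commit, showing what an implementation of the four abstract methods might look like under the return-value contracts documented above.

from typing import Any, Dict, List
import pandas as pd
from auto_causal.methods.causal_method import CausalMethod

class NaiveDiffInMeans(CausalMethod):
    """Illustrative only: naive difference in group means, no adjustment."""

    def validate_assumptions(self, df: pd.DataFrame, treatment: str,
                             outcome: str, covariates: List[str]) -> Dict[str, Any]:
        is_binary = set(df[treatment].dropna().unique()).issubset({0, 1})
        return {"assumptions_valid": is_binary,
                "failed_assumptions": [] if is_binary else ["treatment is not binary 0/1"],
                "warnings": [], "suggestions": []}

    def estimate_effect(self, df: pd.DataFrame, treatment: str,
                        outcome: str, covariates: List[str]) -> Dict[str, Any]:
        means = df.groupby(treatment)[outcome].mean()
        effect = float(means.get(1, float("nan")) - means.get(0, float("nan")))
        return {"effect_estimate": effect, "confidence_interval": (None, None),
                "p_value": None, "additional_metrics": {}}

    def generate_code(self, dataset_path: str, treatment: str,
                      outcome: str, covariates: List[str]) -> str:
        return (f"import pandas as pd\n"
                f"df = pd.read_csv({dataset_path!r})\n"
                f"print(df.groupby({treatment!r})[{outcome!r}].mean())\n")

    def explain(self) -> str:
        return "Compares mean outcomes between treated and control; valid only absent confounding."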
auto_causal/methods/diff_in_means/__init__.py ADDED
File without changes
auto_causal/methods/diff_in_means/diagnostics.py ADDED
@@ -0,0 +1,60 @@
1
+ """
2
+ Basic descriptive statistics for Difference in Means.
3
+ """
4
+
5
+ from typing import Dict, Any
6
+ import pandas as pd
7
+ import numpy as np
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def run_dim_diagnostics(df: pd.DataFrame, treatment: str, outcome: str) -> Dict[str, Any]:
13
+ """
14
+ Calculates basic descriptive statistics for treatment and control groups.
15
+
16
+ Args:
17
+ df: Input DataFrame (should already be filtered for NaNs in treatment/outcome).
18
+ treatment: Name of the binary treatment variable column.
19
+ outcome: Name of the outcome variable column.
20
+
21
+ Returns:
22
+ Dictionary containing group means, standard deviations, and counts.
23
+ """
24
+ details = {}
25
+ try:
26
+ grouped = df.groupby(treatment)[outcome]
27
+ stats = grouped.agg(['mean', 'std', 'count'])
28
+
29
+ # Ensure both groups (0 and 1) are present if possible
30
+ control_stats = stats.loc[0].to_dict() if 0 in stats.index else {'mean': np.nan, 'std': np.nan, 'count': 0}
31
+ treated_stats = stats.loc[1].to_dict() if 1 in stats.index else {'mean': np.nan, 'std': np.nan, 'count': 0}
32
+
33
+ details['control_group_stats'] = control_stats
34
+ details['treated_group_stats'] = treated_stats
35
+
36
+ if control_stats['count'] == 0 or treated_stats['count'] == 0:
37
+ logger.warning("One or both treatment groups have zero observations.")
38
+ return {"status": "Warning - Empty Group(s)", "details": details}
39
+
40
+ # Simple check for variance difference (Levene's test could be added)
41
+ control_std = control_stats.get('std', 0)
42
+ treated_std = treated_stats.get('std', 0)
43
+ if control_std > 0 and treated_std > 0:
44
+ ratio = (control_std**2) / (treated_std**2)
45
+ details['variance_ratio_control_div_treated'] = ratio
46
+ if ratio > 4 or ratio < 0.25: # Rule of thumb
47
+ details['variance_homogeneity_status'] = "Potentially Unequal (ratio > 4 or < 0.25)"
48
+ else:
49
+ details['variance_homogeneity_status'] = "Likely Similar"
50
+ else:
51
+ details['variance_homogeneity_status'] = "Could not calculate (zero variance in a group)"
52
+
53
+ return {"status": "Success", "details": details}
54
+
55
+ except KeyError as ke:
56
+ logger.error(f"KeyError during diagnostics: {ke}. Treatment levels might not be 0/1.")
57
+ return {"status": "Failed", "error": f"Treatment levels might not be 0/1: {ke}", "details": details}
58
+ except Exception as e:
59
+ logger.error(f"Error running Difference in Means diagnostics: {e}")
60
+ return {"status": "Failed", "error": str(e), "details": details}
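A quick sketch of the diagnostic output shape on simulated data (names and values are illustrative assumptions):

import numpy as np
import pandas as pd
from auto_causal.methods.diff_in_means.diagnostics import run_dim_diagnostics

rng = np.random.default_rng(4)
df = pd.DataFrame({"treat": rng.integers(0, 2, 100)})
df["y"] = 0.5 * df["treat"] + rng.normal(size=100)

report = run_dim_diagnostics(df, "treat", "y")
print(report["status"], report["details"]["treated_group_stats"])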
auto_causal/methods/diff_in_means/estimator.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Difference in Means / Simple Linear Regression Estimator.
3
+
4
+ Estimates the Average Treatment Effect (ATE) by comparing the mean outcome
5
+ between the treated and control groups. This is equivalent to a simple OLS
6
+ regression of the outcome on the treatment indicator.
7
+
8
+ Assumes no confounding (e.g., suitable for RCT data).
9
+ """
10
+ import pandas as pd
11
+ import statsmodels.api as sm
12
+ import numpy as np
13
+ import warnings
14
+ from typing import Dict, Any, Optional
15
+ import logging
16
+ from langchain.chat_models.base import BaseChatModel # For type hinting llm
17
+
18
+ from .diagnostics import run_dim_diagnostics
19
+ from .llm_assist import interpret_dim_results
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ def estimate_effect(
24
+ df: pd.DataFrame,
25
+ treatment: str,
26
+ outcome: str,
27
+ query: Optional[str] = None, # For potential LLM use
28
+ llm: Optional[BaseChatModel] = None, # For potential LLM use
29
+ **kwargs # To capture any other potential arguments (e.g., covariates - which are ignored)
30
+ ) -> Dict[str, Any]:
31
+ """
32
+ Estimates the causal effect using Difference in Means (via OLS).
33
+
34
+ Ignores any provided covariates.
35
+
36
+ Args:
37
+ df: Input DataFrame.
38
+ treatment: Name of the binary treatment variable column (should be 0 or 1).
39
+ outcome: Name of the outcome variable column.
40
+ query: Optional user query for context.
41
+ llm: Optional Language Model instance.
42
+ **kwargs: Additional keyword arguments (ignored).
43
+
44
+ Returns:
45
+ Dictionary containing estimation results:
46
+ - 'effect_estimate': The difference in means (treatment coefficient).
47
+ - 'p_value': The p-value associated with the difference.
48
+ - 'confidence_interval': The 95% confidence interval for the difference.
49
+ - 'standard_error': The standard error of the difference.
50
+ - 'formula': The regression formula used.
51
+ - 'model_summary': Summary object from statsmodels.
52
+ - 'diagnostics': Basic group statistics.
53
+ - 'interpretation': LLM interpretation.
54
+ """
55
+ required_cols = [treatment, outcome]
56
+ missing_cols = [col for col in required_cols if col not in df.columns]
57
+ if missing_cols:
58
+ raise ValueError(f"Missing required columns: {missing_cols}")
59
+
60
+ # Validate treatment is binary (or close to it)
61
+ treat_vals = df[treatment].dropna().unique()
62
+ if not np.all(np.isin(treat_vals, [0, 1])):
63
+ warnings.warn(f"Treatment column '{treatment}' contains values other than 0 and 1: {treat_vals}. Proceeding, but results may be unreliable.", UserWarning)
64
+ # Optional: could raise ValueError here if strict binary is required
65
+
66
+ # Prepare data for statsmodels (add constant, handle potential NaNs)
67
+ df_analysis = df[required_cols].dropna()
68
+ if df_analysis.empty:
69
+ raise ValueError("No data remaining after dropping NaNs for required columns.")
70
+
71
+ X = df_analysis[[treatment]]
72
+ X = sm.add_constant(X) # Add intercept
73
+ y = df_analysis[outcome]
74
+
75
+ formula = f"{outcome} ~ {treatment} + const"
76
+ logger.info(f"Running Difference in Means regression: {formula}")
77
+
78
+ try:
79
+ model = sm.OLS(y, X)
80
+ results = model.fit()
81
+
82
+ effect_estimate = results.params[treatment]
83
+ p_value = results.pvalues[treatment]
84
+ conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
85
+ std_err = results.bse[treatment]
86
+
87
+ # Run basic diagnostics (group means, stds, counts)
88
+ diag_results = run_dim_diagnostics(df_analysis, treatment, outcome)
89
+
90
+ # Get interpretation
91
+ interpretation = interpret_dim_results(results, diag_results, treatment, llm=llm)
92
+
93
+ return {
94
+ 'effect_estimate': effect_estimate,
95
+ 'p_value': p_value,
96
+ 'confidence_interval': conf_int,
97
+ 'standard_error': std_err,
98
+ 'formula': formula,
99
+ 'model_summary': results.summary(),
100
+ 'diagnostics': diag_results,
101
+ 'interpretation': interpretation,
102
+ 'method_used': 'Difference in Means (OLS)'
103
+ }
104
+
105
+ except Exception as e:
106
+ logger.error(f"Difference in Means failed: {e}")
107
+ raise
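A minimal usage sketch on simulated RCT-style data; the column names and the true effect of 0.8 are assumptions chosen for illustration.

import numpy as np
import pandas as pd
from auto_causal.methods.diff_in_means.estimator import estimate_effect

rng = np.random.default_rng(2)
n = 400
treat = rng.integers(0, 2, n)                  # randomized 0/1 assignment
y = 0.8 * treat + rng.normal(size=n)
df = pd.DataFrame({"treat": treat, "y": y})

res = estimate_effect(df, treatment="treat", outcome="y")
print(res["effect_estimate"], res["p_value"])  # estimate should land close to 0.8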
auto_causal/methods/diff_in_means/llm_assist.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ LLM assistance functions for Difference in Means analysis.
3
+ """
4
+
5
+ from typing import Dict, Any, Optional
6
+ import logging
7
+
8
+ # Imported for type hinting
9
+ from langchain.chat_models.base import BaseChatModel
10
+ from statsmodels.regression.linear_model import RegressionResultsWrapper
11
+
12
+ # Import shared LLM helpers
13
+ from auto_causal.utils.llm_helpers import call_llm_with_json_output
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ def interpret_dim_results(
18
+ results: RegressionResultsWrapper,
19
+ diagnostics: Dict[str, Any],
20
+ treatment_var: str,
21
+ llm: Optional[BaseChatModel] = None
22
+ ) -> str:
23
+ """
24
+ Use LLM to interpret Difference in Means results.
25
+
26
+ Args:
27
+ results: Fitted statsmodels OLS results object (from outcome ~ treatment).
28
+ diagnostics: Dictionary of diagnostic results (group stats).
29
+ treatment_var: Name of the treatment variable.
30
+ llm: Optional LLM model instance.
31
+
32
+ Returns:
33
+ String containing natural language interpretation.
34
+ """
35
+ default_interpretation = "LLM interpretation not available for Difference in Means."
36
+ if llm is None:
37
+ logger.info("LLM not provided for Difference in Means interpretation.")
38
+ return default_interpretation
39
+
40
+ try:
41
+ # --- Prepare summary for LLM ---
42
+ results_summary = {}
43
+ diag_details = diagnostics.get('details', {})
44
+ control_stats = diag_details.get('control_group_stats', {})
45
+ treated_stats = diag_details.get('treated_group_stats', {})
46
+
47
+ effect = results.params.get(treatment_var)
48
+ pval = results.pvalues.get(treatment_var)
49
+
50
+ results_summary['Effect Estimate (Difference in Means)'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
51
+ results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
52
+ try:
53
+ conf_int = results.conf_int().loc[treatment_var]
54
+ results_summary['95% Confidence Interval'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
55
+ except KeyError:
56
+ results_summary['95% Confidence Interval'] = "Not Found"
57
+ except Exception as ci_e:
58
+ results_summary['95% Confidence Interval'] = f"Error ({ci_e})"
59
+
60
+ results_summary['Control Group Mean Outcome'] = f"{control_stats.get('mean', 'N/A'):.3f}" if isinstance(control_stats.get('mean'), (int, float)) else str(control_stats.get('mean'))
61
+ results_summary['Treated Group Mean Outcome'] = f"{treated_stats.get('mean', 'N/A'):.3f}" if isinstance(treated_stats.get('mean'), (int, float)) else str(treated_stats.get('mean'))
62
+ results_summary['Control Group Size'] = control_stats.get('count', 'N/A')
63
+ results_summary['Treated Group Size'] = treated_stats.get('count', 'N/A')
64
+
65
+ # --- Construct Prompt ---
66
+ prompt = f"""
67
+ You are assisting with interpreting Difference in Means results, likely from an RCT.
68
+
69
+ Results Summary:
70
+ {results_summary}
71
+
72
+ Explain these results in 1-3 concise sentences. Focus on:
73
+ 1. The estimated average treatment effect (magnitude, direction, statistical significance based on p-value < 0.05).
74
+ 2. Compare the mean outcomes between the treated and control groups.
75
+
76
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
77
+ {{
78
+ "interpretation": "<your concise interpretation text>"
79
+ }}
80
+ """
81
+
82
+ # --- Call LLM ---
83
+ response = call_llm_with_json_output(llm, prompt)
84
+
85
+ # --- Process Response ---
86
+ if response and isinstance(response, dict) and \
87
+ "interpretation" in response and isinstance(response["interpretation"], str):
88
+ return response["interpretation"]
89
+ else:
90
+ logger.warning(f"Failed to get valid interpretation from LLM for Difference in Means. Response: {response}")
91
+ return default_interpretation
92
+
93
+ except Exception as e:
94
+ logger.error(f"Error during LLM interpretation for Difference in Means: {e}")
95
+ return f"Error generating interpretation: {e}"
auto_causal/methods/difference_in_differences/diagnostics.py ADDED
@@ -0,0 +1,345 @@
1
+ """Diagnostic functions for Difference-in-Differences method."""
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+ from typing import Dict, Any, Optional, List
6
+ import logging
7
+ import statsmodels.formula.api as smf # Import statsmodels
8
+ from patsy import PatsyError # To catch formula errors
9
+
10
+ # Import helper function from estimator -> Change to utils
11
+ from .utils import create_post_indicator
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ def validate_parallel_trends(df: pd.DataFrame, time_var: str, outcome: str,
16
+ group_indicator_col: str, treatment_period_start: Any,
17
+ dataset_description: Optional[str] = None,
18
+ time_varying_covariates: Optional[List[str]] = None) -> Dict[str, Any]:
19
+ """Validates the parallel trends assumption using pre-treatment data.
20
+
21
+ Regresses the outcome on group-specific time trends before the treatment period.
22
+ Tests if the interaction terms between group and pre-treatment time periods are jointly significant.
23
+
24
+ Args:
25
+ df: DataFrame containing the data.
26
+ time_var: Name of the time variable column.
27
+ outcome: Name of the outcome variable column.
28
+ group_indicator_col: Name of the binary treatment group indicator column (0/1).
29
+ treatment_period_start: The time period value when treatment starts.
30
+ dataset_description: Optional free-text description of the dataset.
31
+ time_varying_covariates: Optional list of time-varying covariates to include.
32
+
33
+ Returns:
34
+ Dictionary with validation results.
35
+ """
36
+ logger.info("Validating parallel trends...")
37
+ validation_result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
38
+
39
+ try:
40
+ # Filter pre-treatment data
41
+ pre_df = df[df[time_var] < treatment_period_start].copy()
42
+
43
+ if len(pre_df) < 20 or pre_df[group_indicator_col].nunique() < 2 or pre_df[time_var].nunique() < 2:
44
+ validation_result["details"] = "Insufficient pre-treatment data or variation to perform test."
45
+ logger.warning(validation_result["details"])
46
+ # Assume valid if cannot test? Or invalid? Let's default to True if we can't test
47
+ validation_result["valid"] = True
48
+ validation_result["details"] += " Defaulting to assuming parallel trends (unable to test)."
49
+ return validation_result
50
+
51
+ # Check if group indicator is binary
52
+ if pre_df[group_indicator_col].nunique() > 2:
53
+ validation_result["details"] = f"Group indicator '{group_indicator_col}' has more than 2 unique values. Using simple visual assessment."
54
+ logger.warning(validation_result["details"])
55
+ # Use visual assessment method instead (check if trends look roughly parallel)
56
+ validation_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
57
+ # Ensure p_value is set
58
+ if validation_result["p_value"] is None:
59
+ validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
60
+ return validation_result
61
+
62
+ # Use a robust approach first - test for pre-trend differences using a simpler model
63
+ try:
64
+ # Create a linear time trend
65
+ pre_df['time_trend'] = pre_df[time_var].astype(float)
66
+
67
+ # Create interaction between trend and group
68
+ pre_df['group_trend'] = pre_df['time_trend'] * pre_df[group_indicator_col].astype(float)
69
+
70
+ # Simple regression with linear trend interaction
71
+ simple_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}') + time_trend + group_trend"
72
+ simple_model = smf.ols(simple_formula, data=pre_df)
73
+ simple_results = simple_model.fit()
74
+
75
+ # Check if trend interaction coefficient is significant
76
+ group_trend_pvalue = simple_results.pvalues['group_trend']
77
+
78
+ # If p > 0.05, trends are not significantly different
79
+ validation_result["valid"] = group_trend_pvalue > 0.05
80
+ validation_result["p_value"] = group_trend_pvalue
81
+ validation_result["details"] = f"Simple linear trend test: p-value for group-trend interaction: {group_trend_pvalue:.4f}. Parallel trends: {validation_result['valid']}."
82
+ logger.info(validation_result["details"])
83
+
84
+ # If we've successfully validated with the simple approach, return
85
+ return validation_result
86
+
87
+ except Exception as e:
88
+ logger.warning(f"Simple trend test failed: {e}. Trying alternative approach.")
89
+ # Continue to more complex method if simple method fails
90
+
91
+ # Try more complex approach with period-specific interactions
92
+ try:
93
+ # Create period dummies to avoid issues with categorical variables
94
+ time_periods = sorted(pre_df[time_var].unique())
95
+
96
+ # Create dummy variables for time periods (except first)
97
+ for period in time_periods[1:]:
98
+ period_col = f'period_{period}'
99
+ pre_df[period_col] = (pre_df[time_var] == period).astype(int)
100
+
101
+ # Create interaction with group
102
+ pre_df[f'group_x_{period_col}'] = pre_df[period_col] * pre_df[group_indicator_col].astype(float)
103
+
104
+ # Construct formula with manual dummies
105
+ interaction_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}')"
106
+
107
+ # Add period dummies except first (reference)
108
+ for period in time_periods[1:]:
109
+ period_col = f'period_{period}'
110
+ interaction_formula += f" + {period_col}"
111
+
112
+ # Add interactions
113
+ interaction_terms = []
114
+ for period in time_periods[1:]:
115
+ interaction_col = f'group_x_period_{period}'
116
+ interaction_formula += f" + {interaction_col}"
117
+ interaction_terms.append(interaction_col)
118
+
119
+ # Add covariates if provided
120
+ if time_varying_covariates:
121
+ for cov in time_varying_covariates:
122
+ interaction_formula += f" + Q('{cov}')"
123
+
124
+ # Fit model
125
+ complex_model = smf.ols(interaction_formula, data=pre_df)
126
+ complex_results = complex_model.fit()
127
+
128
+ # Test joint significance of interaction terms
129
+ if interaction_terms:
130
+ from statsmodels.formula.api import ols
131
+ from statsmodels.stats.anova import anova_lm
132
+
133
+ # Create models with and without interactions
134
+ formula_with = interaction_formula
135
+ formula_without = interaction_formula
136
+ for term in interaction_terms:
137
+ formula_without = formula_without.replace(f" + {term}", "")
138
+
139
+ model_with = smf.ols(formula_with, data=pre_df).fit()
140
+ model_without = smf.ols(formula_without, data=pre_df).fit()
141
+
142
+ # Compare models
143
+ try:
144
+ from scipy import stats
145
+ df_model = len(interaction_terms)
146
+ df_residual = model_with.df_resid
147
+ f_value = ((model_without.ssr - model_with.ssr) / df_model) / (model_with.ssr / df_residual)
148
+ p_value = 1 - stats.f.cdf(f_value, df_model, df_residual)
149
+
150
+ validation_result["valid"] = p_value > 0.05
151
+ validation_result["p_value"] = p_value
152
+ validation_result["details"] = f"Manual F-test for pre-treatment interactions: F({df_model}, {df_residual})={f_value:.4f}, p={p_value:.4f}. Parallel trends: {validation_result['valid']}."
153
+ logger.info(validation_result["details"])
154
+
155
+ except Exception as e:
156
+ logger.warning(f"Manual F-test failed: {e}. Using individual coefficient significance.")
157
+
158
+ # If F-test fails, check individual coefficients
159
+ significant_interactions = 0
160
+ for term in interaction_terms:
161
+ if term in complex_results.pvalues and complex_results.pvalues[term] < 0.05:
162
+ significant_interactions += 1
163
+
164
+ validation_result["valid"] = significant_interactions == 0
165
+ # Set a dummy p-value based on proportion of significant interactions
166
+ if len(interaction_terms) > 0:
167
+ validation_result["p_value"] = 1.0 - (significant_interactions / len(interaction_terms))
168
+ else:
169
+ validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
170
+ validation_result["details"] = f"{significant_interactions} out of {len(interaction_terms)} pre-treatment interactions are significant at p<0.05. Parallel trends: {validation_result['valid']}."
171
+ logger.info(validation_result["details"])
172
+ else:
173
+ validation_result["valid"] = True
174
+ validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
175
+ validation_result["details"] = "No pre-treatment interaction terms could be tested. Defaulting to assuming parallel trends."
176
+ logger.warning(validation_result["details"])
177
+
178
+ except Exception as e:
179
+ logger.warning(f"Complex trend test failed: {e}. Falling back to visual assessment.")
180
+ tmp_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
181
+ # Copy over values from visual assessment ensuring p_value is set
182
+ validation_result.update(tmp_result)
183
+ # Ensure p_value is set
184
+ if validation_result["p_value"] is None:
185
+ validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
186
+
187
+ except Exception as e:
188
+ error_msg = f"Error during parallel trends validation: {e}"
189
+ logger.error(error_msg, exc_info=True)
190
+ validation_result["details"] = error_msg
191
+ validation_result["error"] = str(e)
192
+ # Default to assuming valid if test fails completely
193
+ validation_result["valid"] = True
194
+ validation_result["p_value"] = 1.0 # Default to 1.0 if test fails
195
+ validation_result["details"] += " Defaulting to assuming parallel trends (test failed)."
196
+
197
+ return validation_result
198
+
199
+ def assess_trends_visually(df: pd.DataFrame, time_var: str, outcome: str,
200
+ group_indicator_col: str) -> Dict[str, Any]:
201
+ """Simple visual assessment of parallel trends by comparing group means over time.
202
+
203
+ This is a fallback method when statistical tests fail.
204
+ """
205
+ result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
206
+
207
+ try:
208
+ # Group by time and treatment group, calculate means
209
+ grouped = df.groupby([time_var, group_indicator_col])[outcome].mean().reset_index()
210
+
211
+ # Pivot to get time series for each group
212
+ if df[group_indicator_col].nunique() <= 10: # Only if reasonable number of groups
213
+ pivot = grouped.pivot(index=time_var, columns=group_indicator_col, values=outcome)
214
+
215
+ # Calculate slopes between consecutive periods for each group
216
+ slopes = {}
217
+ time_values = sorted(df[time_var].unique())
218
+
219
+ if len(time_values) >= 3: # Need at least 3 periods to compare slopes
220
+ for group in pivot.columns:
221
+ group_slopes = []
222
+ for i in range(len(time_values) - 1):
223
+ t1, t2 = time_values[i], time_values[i+1]
224
+ if t1 in pivot.index and t2 in pivot.index:
225
+ slope = (pivot.loc[t2, group] - pivot.loc[t1, group]) / (t2 - t1)
226
+ group_slopes.append(slope)
227
+ if group_slopes:
228
+ slopes[group] = group_slopes
229
+
230
+ # Compare slopes between groups
231
+ if len(slopes) >= 2:
232
+ slope_diffs = []
233
+ groups = list(slopes.keys())
234
+ for i in range(len(slopes[groups[0]])):
235
+ if i < len(slopes[groups[1]]):
236
+ slope_diffs.append(abs(slopes[groups[0]][i] - slopes[groups[1]][i]))
237
+
238
+ # If average slope difference is small relative to outcome scale
239
+ outcome_scale = df[outcome].std()
240
+ avg_slope_diff = sum(slope_diffs) / len(slope_diffs) if slope_diffs else 0
241
+ relative_diff = avg_slope_diff / outcome_scale if outcome_scale > 0 else 0
242
+
243
+ result["valid"] = relative_diff < 0.2 # Threshold for "parallel enough"
244
+ # Set p-value based on relative difference
245
+ result["p_value"] = 1.0 - (relative_diff * 5) if relative_diff < 0.2 else 0.04
246
+ result["details"] = f"Visual assessment: relative slope difference = {relative_diff:.4f}. Parallel trends: {result['valid']}."
247
+ else:
248
+ result["valid"] = True
249
+ result["p_value"] = 1.0
250
+ result["details"] = "Visual assessment: insufficient group data for comparison. Defaulting to assuming parallel trends."
251
+ else:
252
+ result["valid"] = True
253
+ result["p_value"] = 1.0
254
+ result["details"] = "Visual assessment: insufficient time periods for comparison. Defaulting to assuming parallel trends."
255
+ else:
256
+ result["valid"] = True
257
+ result["p_value"] = 1.0
258
+ result["details"] = f"Visual assessment: too many groups ({df[group_indicator_col].nunique()}) for visual comparison. Defaulting to assuming parallel trends."
259
+
260
+ except Exception as e:
261
+ result["error"] = str(e)
262
+ result["valid"] = True
263
+ result["p_value"] = 1.0
264
+ result["details"] = f"Visual assessment failed: {e}. Defaulting to assuming parallel trends."
265
+
266
+ logger.info(result["details"])
267
+ return result
268
+
269
+ def run_placebo_test(df: pd.DataFrame, time_var: str, group_var: str, outcome: str,
270
+ treated_unit_indicator: str, covariates: List[str],
271
+ treatment_period_start: Any,
272
+ placebo_period_start: Any) -> Dict[str, Any]:
273
+ """Runs a placebo test for DiD by assigning a fake earlier treatment period.
274
+
275
+ Re-runs the DiD estimation using the placebo period and checks if the effect is non-significant.
276
+
277
+ Args:
278
+ df: Original DataFrame.
279
+ time_var: Name of the time variable column.
280
+ group_var: Name of the unit/group ID column (for clustering SE).
281
+ outcome: Name of the outcome variable column.
282
+ treated_unit_indicator: Name of the binary treatment group indicator column (0/1).
283
+ covariates: List of covariate names.
284
+ treatment_period_start: The actual treatment start period.
285
+ placebo_period_start: The fake treatment start period (must be before actual start).
286
+
287
+ Returns:
288
+ Dictionary with placebo test results.
289
+ """
290
+ logger.info(f"Running placebo test assigning treatment start at {placebo_period_start}...")
291
+ placebo_result = {"passed": False, "effect_estimate": None, "p_value": None, "details": "", "error": None}
292
+
293
+ if placebo_period_start >= treatment_period_start:
294
+ error_msg = "Placebo period must be before the actual treatment period."
295
+ logger.error(error_msg)
296
+ placebo_result["error"] = error_msg
297
+ placebo_result["details"] = error_msg
298
+ return placebo_result
299
+
300
+ try:
301
+ df_placebo = df.copy()
302
+ # Create placebo post and interaction terms
303
+ post_placebo_col = 'post_placebo'
304
+ interaction_placebo_col = 'did_interaction_placebo'
305
+
306
+ df_placebo[post_placebo_col] = create_post_indicator(df_placebo, time_var, placebo_period_start)
307
+ df_placebo[interaction_placebo_col] = df_placebo[treated_unit_indicator] * df_placebo[post_placebo_col]
308
+
309
+ # Construct formula for placebo regression
310
+ formula = f"Q('{outcome}') ~ Q('{treated_unit_indicator}') + Q('{post_placebo_col}') + Q('{interaction_placebo_col}')"  # Q() quoting matches the patsy syntax used elsewhere in this module
311
+ if covariates:
312
+ formula += " + " + " + ".join([f"Q('{c}')" for c in covariates])
313
+ formula += f" + C(Q('{group_var}')) + C(Q('{time_var}'))" # Include unit and time fixed effects
314
+
315
+ logger.debug(f"Placebo test formula: {formula}")
316
+
317
+ # Fit the placebo model with clustered SE
318
+ ols_model = smf.ols(formula=formula, data=df_placebo)
319
+ results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_placebo[group_var]})
320
+
321
+ # Check the significance of the placebo interaction term
322
+ placebo_effect = float(results.params[interaction_placebo_col])
323
+ placebo_p_value = float(results.pvalues[interaction_placebo_col])
324
+
325
+ # Test passes if the placebo effect is not statistically significant (e.g., p > 0.1)
326
+ passed_test = placebo_p_value > 0.10
327
+
328
+ placebo_result["passed"] = passed_test
329
+ placebo_result["effect_estimate"] = placebo_effect
330
+ placebo_result["p_value"] = placebo_p_value
331
+ placebo_result["details"] = f"Placebo treatment effect estimated at {placebo_effect:.4f} (p={placebo_p_value:.4f}). Test passed: {passed_test}."
332
+ logger.info(placebo_result["details"])
333
+
334
+ except (KeyError, PatsyError, ValueError, Exception) as e:
335
+ error_msg = f"Error during placebo test execution: {e}"
336
+ logger.error(error_msg, exc_info=True)
337
+ placebo_result["details"] = error_msg
338
+ placebo_result["error"] = str(e)
339
+
340
+ return placebo_result
341
+
342
+ # TODO: Add function for Event Study plot (plot_event_study)
343
+ # This would involve estimating effects for leads and lags around the treatment period.
344
+
345
+ # Add other diagnostic functions as needed (e.g., plot_event_study)
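To illustrate what validate_parallel_trends consumes, a sketch on a toy panel with parallel pre-treatment trends; the unit/time/treated/y column names and the period-4 treatment start are hypothetical.

import numpy as np
import pandas as pd
from auto_causal.methods.difference_in_differences.diagnostics import validate_parallel_trends

rng = np.random.default_rng(3)
rows = []
for unit in range(30):
    treated = int(unit < 15)
    for t in range(6):                      # treatment begins at period 4
        post = int(t >= 4)
        y = 1.0 * t + 0.5 * treated + 2.0 * treated * post + rng.normal()
        rows.append({"unit": unit, "time": t, "treated": treated, "y": y})
panel = pd.DataFrame(rows)

check = validate_parallel_trends(panel, time_var="time", outcome="y",
                                 group_indicator_col="treated",
                                 treatment_period_start=4)
print(check["valid"], round(check["p_value"], 3))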
auto_causal/methods/difference_in_differences/estimator.py ADDED
@@ -0,0 +1,463 @@
1
+ """
2
+ Difference-in-Differences estimator implemented with statsmodels OLS (the DoWhy estimation path is currently disabled).
3
+ """
4
+
5
+ import logging
6
+ import pandas as pd
7
+ import numpy as np
8
+ from typing import Dict, List, Optional, Any, Tuple
9
+ from auto_causal.config import get_llm_client # IMPORT LLM Client Factory
10
+
11
+ # DoWhy imports (Commented out for simplification)
12
+ # from dowhy import CausalModel
13
+ # from dowhy.causal_estimators import CausalEstimator
14
+ # from dowhy.causal_estimator import CausalEstimate
15
+ # Statsmodels import for estimation
16
+ import statsmodels.formula.api as smf
17
+
18
+ # Local imports
19
+ from .llm_assist import (
20
+ identify_time_variable,
21
+ determine_treatment_period,
22
+ identify_treatment_group,
23
+ interpret_did_results
24
+ )
25
+ from .diagnostics import validate_parallel_trends # Import diagnostics
26
+ # Import from the new utils module
27
+ from .utils import create_post_indicator
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # --- Helper functions moved from old file ---
32
+ def format_did_results(statsmodels_results: Any, interaction_term_key: str,
33
+ validation_results: Dict[str, Any],
34
+ method_details: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
35
+ '''Formats the DiD results from statsmodels results into a standard dictionary.'''
36
+
37
+ try:
38
+ # Use the interaction_term_key passed directly
39
+ effect = float(statsmodels_results.params[interaction_term_key])
40
+ stderr = float(statsmodels_results.bse[interaction_term_key])
41
+ pval = float(statsmodels_results.pvalues[interaction_term_key])
42
+ ci = statsmodels_results.conf_int().loc[interaction_term_key].values.tolist()
43
+ ci_lower, ci_upper = float(ci[0]), float(ci[1])
44
+ logger.info(f"Extracted effect for '{interaction_term_key}'")
45
+
46
+ except KeyError:
47
+ logger.error(f"Interaction term '{interaction_term_key}' not found in statsmodels results. Available params: {statsmodels_results.params.index.tolist()}")
48
+ # Fallback to NaN if term not found
49
+ effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
50
+ except Exception as e:
51
+ logger.error(f"Error extracting results from statsmodels object: {e}")
52
+ effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
53
+
54
+ # Create a standardized results dictionary
55
+ results = {
56
+ "effect_estimate": effect,
57
+ "standard_error": stderr,
58
+ "p_value": pval,
59
+ "confidence_interval": [ci_lower, ci_upper],
60
+ "diagnostics": validation_results,
61
+ "parameters": parameters,
62
+ "details": str(statsmodels_results.summary())
63
+ }
64
+
65
+ return results
66
+
67
+ # Comment out unused DoWhy result formatter
68
+ # def format_dowhy_results(estimate: CausalEstimate,
69
+ # validation_results: Dict[str, Any],
70
+ # parameters: Dict[str, Any]) -> Dict[str, Any]:
71
+ # '''Formats the DiD results from DoWhy causal estimate into a standard dictionary.'''
72
+
73
+ # try:
74
+ # # Extract values from DoWhy estimate
75
+ # effect = float(estimate.value)
76
+ # stderr = float(estimate.get_standard_error()) if hasattr(estimate, 'get_standard_error') else np.nan
77
+ # ci_lower, ci_upper = estimate.get_confidence_intervals() if hasattr(estimate, 'get_confidence_intervals') else (np.nan, np.nan)
78
+ # # Extract p-value if available, otherwise use NaN
79
+ # pval = estimate.get_significance_test_results().get('p_value', np.nan) if hasattr(estimate, 'get_significance_test_results') else np.nan
80
+
81
+ # # Get available details from estimate
82
+ # details = str(estimate)
83
+ # if hasattr(estimate, 'summary'):
84
+ # details = str(estimate.summary())
85
+
86
+ # logger.info(f"Extracted effect from DoWhy estimate: {effect}")
87
+
88
+ # except Exception as e:
89
+ # logger.error(f"Error extracting results from DoWhy estimate: {e}")
90
+ # effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan
91
+ # details = f"Error extracting DoWhy results: {e}"
92
+
93
+ # # Create a standardized results dictionary
94
+ # results = {
95
+ # "effect_estimate": effect,
96
+ # "effect_se": stderr,
97
+ # "p_value": pval,
98
+ # "confidence_interval": [ci_lower, ci_upper],
99
+ # "diagnostics": validation_results,
100
+ # "parameters": parameters,
101
+ # "details": details,
102
+ # "estimator": "dowhy"
103
+ # }
104
+
105
+ # return results
106
+
107
+ # --- Main `estimate_effect` function ---
108
+
109
+ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
110
+ covariates: List[str],
111
+ dataset_description: Optional[str] = None,
112
+ query: Optional[str] = None,
113
+ **kwargs) -> Dict[str, Any]:
114
+ """Difference-in-Differences estimation via statsmodels OLS (the DoWhy path is currently disabled).
115
+
116
+ Args:
117
+ df: Dataset containing causal variables
118
+ treatment: Name of treatment variable (or variable indicating treated group)
119
+ outcome: Name of outcome variable
120
+ covariates: List of covariate names
121
+ dataset_description: Optional free-text description of the dataset
122
+ **kwargs: Method-specific parameters (e.g., time_var, group_var, query, llm instance if needed)
123
+
124
+ Returns:
125
+ Dictionary with effect estimate and diagnostics
126
+ """
127
+ query = kwargs.get('query_str', query)  # prefer an explicit 'query_str' kwarg, otherwise keep the query argument
128
+ # llm_instance = kwargs.get('llm') # Pass llm if helpers need it
129
+ df_processed = df.copy() # Work on a copy
130
+
131
+ logger.info("Starting DiD estimation via statsmodels OLS...")
132
+
133
+ # --- Step 1: Identify Key Variables (using LLM Assist placeholders) ---
134
+ # Pass llm_instance to helpers if they are implemented to use it
135
+ llm_instance = get_llm_client() # Obtain LLM client from config for the identification helpers
136
+ time_var = kwargs.get('time_variable', identify_time_variable(df_processed, query, dataset_description, llm=llm_instance))
137
+ if time_var is None:
138
+ raise ValueError("Time variable could not be identified for DiD.")
139
+ if time_var not in df_processed.columns:
140
+ raise ValueError(f"Identified time variable '{time_var}' not found in DataFrame.")
141
+
142
+ # Determine the variable that identifies the panel unit (for grouping/FE)
143
+ group_var = kwargs.get('group_variable', identify_treatment_group(df_processed, treatment, query, dataset_description, llm=llm_instance))
144
+ if group_var is None:
145
+ raise ValueError("Group/Unit variable could not be identified for DiD.")
146
+ if group_var not in df_processed.columns:
147
+ raise ValueError(f"Identified group/unit variable '{group_var}' not found in DataFrame.")
148
+
149
+ # Check outcome exists before proceeding further
150
+ if outcome not in df_processed.columns:
151
+ raise ValueError(f"Outcome variable '{outcome}' not found in DataFrame.")
152
+
153
+ # Determine treatment period start
154
+ treatment_period = kwargs.get('treatment_period_start', kwargs.get('treatment_period',
155
+ determine_treatment_period(df_processed, time_var, treatment, query, dataset_description, llm=llm_instance)))
156
+
157
+ # --- Identify the TRUE binary treatment group indicator column ---
158
+ treated_group_col_for_formula = None
159
+
160
+ # Priority 1: Check if the 'treatment' argument itself is a valid binary indicator
161
+ if treatment in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed[treatment]):
162
+ unique_treat_vals = set(df_processed[treatment].dropna().unique())
163
+ if unique_treat_vals.issubset({0, 1}):
164
+ treated_group_col_for_formula = treatment
165
+ logger.info(f"Using the provided 'treatment' argument '{treatment}' as binary group indicator.")
166
+
167
+ # Priority 2: Check if a column explicitly named 'group' exists and is binary
168
+ if treated_group_col_for_formula is None and 'group' in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed['group']):
169
+ unique_group_vals = set(df_processed['group'].dropna().unique())
170
+ if unique_group_vals.issubset({0, 1}):
171
+ treated_group_col_for_formula = 'group'
172
+ logger.info(f"Using column 'group' as binary group indicator.")
173
+
174
+ # Priority 3: Fallback - Search other columns (excluding known roles and time-related ones)
175
+ if treated_group_col_for_formula is None:
176
+ logger.warning(f"Provided 'treatment' arg '{treatment}' is not binary 0/1 and no 'group' column found. Searching other columns...")
177
+ potential_group_cols = []
178
+ # Exclude outcome, time var, unit ID var, and common time indicators like 'post'
179
+ excluded_cols = [outcome, time_var, group_var, 'post', 'is_post_treatment', 'did_interaction']
180
+ for col_name in df_processed.columns:
181
+ if col_name in excluded_cols:
182
+ continue
183
+ try:
184
+ col_data = df_processed[col_name]
185
+ # Ensure we are working with a Series
186
+ if isinstance(col_data, pd.DataFrame):
187
+ if col_data.shape[1] == 1:
188
+ col_data = col_data.iloc[:, 0] # Extract the Series
189
+ else:
190
+ logger.warning(f"Skipping multi-column DataFrame slice for '{col_name}'.")
191
+ continue
192
+
193
+ # Check if the Series can be interpreted as binary 0/1
194
+ if not pd.api.types.is_numeric_dtype(col_data) and not pd.api.types.is_bool_dtype(col_data):
195
+ continue # Skip non-numeric/non-boolean columns
196
+
197
+ unique_vals = set(col_data.dropna().unique())
198
+ # Simplified check: directly test if unique values are a subset of {0, 1}
199
+ if unique_vals.issubset({0, 1}):
200
+ logger.info(f" Found potential binary indicator: {col_name}")
201
+ potential_group_cols.append(col_name)
202
+
203
+ except AttributeError as ae:
204
+ # Catch attribute errors likely due to unexpected types
205
+ logger.warning(f"Attribute error checking column '{col_name}': {ae}. Skipping.")
206
+ except Exception as e:
207
+ logger.warning(f"Unexpected error checking column '{col_name}' during group ID search: {e}")
208
+
209
+ if potential_group_cols:
210
+ treated_group_col_for_formula = potential_group_cols[0] # Take the first suitable one found
211
+ logger.info(f"Using column '{treated_group_col_for_formula}' found during search as binary group indicator.")
212
+ else:
213
+ # Final fallback: Use the originally identified group_var, but warn heavily
214
+ treated_group_col_for_formula = group_var
215
+ logger.error(f"CRITICAL WARNING: Could not find suitable binary treatment group indicator. Using '{group_var}', but this is likely incorrect and will produce invalid DiD estimates.")
216
+
217
+ # --- Final Check ---
218
+ if treated_group_col_for_formula not in df_processed.columns:
219
+ # This case should ideally not happen with the logic above but added defensively
220
+ raise ValueError(f"Determined treatment group column '{treated_group_col_for_formula}' not found in DataFrame.")
221
+ if df_processed[treated_group_col_for_formula].nunique(dropna=True) > 2:
222
+ logger.warning(f"Selected treatment group column '{treated_group_col_for_formula}' is not binary (has {df_processed[treated_group_col_for_formula].nunique()} unique values). DiD requires binary treatment group.")
223
+
224
+ # --- Step 2: Create Indicator Variables ---
225
+ post_indicator_col = 'post'
226
+ if post_indicator_col not in df_processed.columns:
227
+ # Create the post indicator if it doesn't exist
228
+ df_processed[post_indicator_col] = create_post_indicator(df_processed, time_var, treatment_period)
229
+
230
+ # Interaction term is treatment group * post
231
+ interaction_term_col = 'did_interaction' # Keep explicit interaction term
232
+ df_processed[interaction_term_col] = df_processed[treated_group_col_for_formula] * df_processed[post_indicator_col]
233
+
234
+ # --- Step 3: Validate Parallel Trends (using the group column) ---
235
+ parallel_trends_validation = validate_parallel_trends(df_processed, time_var, outcome,
236
+ treated_group_col_for_formula, treatment_period, dataset_description)
237
+ # Note: The validation result is currently just a placeholder
238
+ if not parallel_trends_validation.get('valid', False):
239
+ logger.warning("Parallel trends assumption potentially violated (based on placeholder check). Proceeding with estimation, but results may be biased.")
240
+ # Add this info to the final results diagnostics
241
+
242
+ # --- Step 4: Prepare for Statsmodels Estimation ---
243
+ # (DoWhy section commented out for simplicity)
244
+ # all_common_causes = covariates + [time_var, group_var] # group_var is unit ID
245
+ # use_dowhy_estimate = False
246
+ # dowhy_estimate = None
247
+
248
+ # try:
249
+ # # Create DoWhy CausalModel
250
+ # model = CausalModel(
251
+ # data=df_processed,
252
+ # treatment=treated_group_col_for_formula, # Use group indicator here
253
+ # outcome=outcome,
254
+ # common_causes=all_common_causes,
255
+ # )
256
+ # logger.info("DoWhy CausalModel created for DiD estimation.")
257
+
258
+ # # Identify estimand
259
+ # identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
260
+ # logger.info(f"DoWhy identified estimand: {identified_estimand.estimand_type}")
261
+
262
+ # # Try to estimate using DiD estimator if available in DoWhy
263
+ # try:
264
+ # logger.info("Attempting to use DoWhy's DiD estimator...")
265
+
266
+ # # Debug info - print DataFrame info to help diagnose possible issues
267
+ # logger.debug(f"DataFrame shape before DoWhy DiD: {df_processed.shape}")
268
+ # # ... (rest of DoWhy debug logs commented out) ...
269
+
270
+ # # Create params dictionary for DoWhy DiD estimator
271
+ # did_params = {
272
+ # 'time_var': time_var,
273
+ # 'treatment_period': treatment_period,
274
+ # 'unit_var': group_var
275
+ # }
276
+
277
+ # # Add control variables if available
278
+ # if covariates:
279
+ # did_params['control_vars'] = covariates
280
+
281
+ # logger.debug(f"DoWhy DiD params: {did_params}")
282
+
283
+ # # Try to use DiD estimator from DoWhy (requires recent version of DoWhy)
284
+ # if hasattr(model, 'estimate_effect'):
285
+ # try:
286
+ # # First check if difference_in_differences method is available
287
+ # available_methods = model.get_available_effect_estimators() if hasattr(model, 'get_available_effect_estimators') else []
288
+ # logger.debug(f"Available DoWhy estimators: {available_methods}")
289
+
290
+ # if "difference_in_differences" not in str(available_methods):
291
+ # logger.warning("'difference_in_differences' estimator not found in available DoWhy estimators. Falling back to statsmodels.")
292
+ # else:
293
+ # # Try the estimation with more error handling
294
+ # logger.info("Calling DoWhy DiD estimator...")
295
+ # estimate = model.estimate_effect(
296
+ # identified_estimand,
297
+ # method_name="difference_in_differences",
298
+ # method_params=did_params
299
+ # )
300
+
301
+ # if estimate:
302
+ # # Extra check to verify estimate has expected attributes
303
+ # if hasattr(estimate, 'value') and not pd.isna(estimate.value):
304
+ # dowhy_estimate = estimate
305
+ # use_dowhy_estimate = True
306
+ # logger.info(f"Successfully used DoWhy's DiD estimator. Effect estimate: {estimate.value}")
307
+ # else:
308
+ # logger.warning(f"DoWhy's DiD estimator returned invalid estimate: {estimate}. Falling back to statsmodels.")
309
+ # else:
310
+ # logger.warning("DoWhy's DiD estimator returned None. Falling back to statsmodels.")
311
+ # except IndexError as idx_err:
312
+ # # Handle specific IndexError that's occurring
313
+ # logger.error(f"IndexError in DoWhy DiD estimator: {idx_err}. Check input data structure.")
314
+ # # Trace more details about the error
315
+ # import traceback
316
+ # logger.error(f"Error traceback: {traceback.format_exc()}")
317
+ # logger.warning("Falling back to statsmodels due to IndexError in DoWhy.")
318
+ # else:
319
+ # logger.warning("DoWhy model does not have estimate_effect method. Falling back to statsmodels.")
320
+
321
+ # except (ImportError, AttributeError) as e:
322
+ # logger.warning(f"DoWhy DiD estimator not available or not implemented: {e}. Falling back to statsmodels.")
323
+ # except ValueError as ve:
324
+ # logger.error(f"ValueError in DoWhy DiD estimator: {ve}. Likely issue with data formatting. Falling back to statsmodels.")
325
+ # except Exception as e:
326
+ # logger.error(f"Error using DoWhy's DiD estimator: {e}. Falling back to statsmodels.")
327
+ # # Add traceback for better debugging
328
+ # import traceback
329
+ # logger.error(f"Full error traceback: {traceback.format_exc()}")
330
+
331
+ # except Exception as e:
332
+ # logger.error(f"Failed to create DoWhy CausalModel: {e}", exc_info=True)
333
+ # # model = None # Set model to None if creation fails
334
+
335
+ # Create parameters dictionary for formatting results
336
+ parameters = {
337
+ "time_var": time_var,
338
+ "group_var": group_var, # Unit ID
339
+ "treatment_indicator": treated_group_col_for_formula, # Group indicator used in formula basis
340
+ "post_indicator": post_indicator_col,
341
+ "treatment_period_start": treatment_period,
342
+ "covariates": covariates,
343
+ }
344
+
345
+ # Group diagnostics for formatting
346
+ did_diagnostics = {
347
+ "parallel_trends": parallel_trends_validation,
348
+ # "placebo_test": run_placebo_test(...)
349
+ }
350
+
351
+ # If DoWhy estimation was successful, use those results (Section Commented Out)
352
+ # if use_dowhy_estimate and dowhy_estimate:
353
+ # logger.info("Using DoWhy DiD estimation results.")
354
+ # parameters["estimation_method"] = "DoWhy Difference-in-Differences"
355
+
356
+ # # Format the results
357
+ # formatted_results = format_dowhy_results(dowhy_estimate, did_diagnostics, parameters)
358
+ # else:
359
+
360
+ # --- Step 5: Use Statsmodels OLS ---
361
+ logger.info("Determining Statsmodels OLS formula based on number of time periods...")
362
+
363
+ num_time_periods = df_processed[time_var].nunique()
364
+
365
+ interaction_term_key_for_results: str
366
+ method_details_str: str
367
+ formula: str
368
+
369
+ if num_time_periods == 2:
370
+ logger.info(
371
+ f"Number of unique time periods is 2. Using 2x2 DiD formula: "
372
+ f"{outcome} ~ {treated_group_col_for_formula} * {post_indicator_col}"
373
+ )
374
+ # For 2x2 DiD: outcome ~ group * post_indicator
375
+ # The interaction term A:B in statsmodels gives the DiD estimate.
376
+ formula_core = f"{treated_group_col_for_formula} * {post_indicator_col}"
377
+ interaction_term_key_for_results = f"{treated_group_col_for_formula}:{post_indicator_col}"
378
+
379
+ formula_parts = [formula_core]
380
+ main_model_terms = {outcome, treated_group_col_for_formula, post_indicator_col}
381
+
382
+ if covariates:
383
+ filtered_covs = [
384
+ c for c in covariates if c not in main_model_terms
385
+ ]
386
+ if filtered_covs:
387
+ formula_parts.extend(filtered_covs)
388
+
389
+ formula = f"{outcome} ~ {' + '.join(formula_parts)}"
390
+ parameters["estimation_method"] = "Statsmodels OLS for 2x2 DiD (Group * Post interaction)"
391
+ method_details_str = "DiD via Statsmodels 2x2 (Group * Post interaction)"
392
+
393
+ else: # num_time_periods > 2
394
+ logger.info(
395
+ f"Number of unique time periods is {num_time_periods} (>2). "
396
+ f"Using TWFE DiD formula: {outcome} ~ {interaction_term_col} + C({group_var}) + C({time_var})"
397
+ )
398
+ # For TWFE: outcome ~ actual_treatment_variable + UnitFE + TimeFE
399
+ # actual_treatment_variable is interaction_term_col (e.g., treated_group * post_indicator)
400
+ # UnitFE is C(group_var), TimeFE is C(time_var)
401
+ formula_parts = [
402
+ interaction_term_col,
403
+ f"C({group_var})",
404
+ f"C({time_var})"
405
+ ]
406
+ interaction_term_key_for_results = interaction_term_col
407
+ main_model_terms = {outcome, interaction_term_col, group_var, time_var}
408
+
409
+ if covariates:
410
+ filtered_covs = [
411
+ c for c in covariates if c not in main_model_terms
412
+ ]
413
+ if filtered_covs:
414
+ formula_parts.extend(filtered_covs)
415
+
416
+ formula = f"{outcome} ~ {' + '.join(formula_parts)}"
417
+ parameters["estimation_method"] = "Statsmodels OLS with TWFE (C() Notation)"
418
+ method_details_str = "DiD via Statsmodels TWFE (C() Notation)"
419
+
420
+ try:
421
+ logger.info(f"Using formula: {formula}")
422
+ logger.debug(f"Data head for statsmodels:\n{df_processed.head().to_string()}")
423
+ logger.debug(f"Regression DataFrame shape: {df_processed.shape}, Columns: {df_processed.columns.tolist()}")
424
+
425
+ ols_model = smf.ols(formula=formula, data=df_processed)
426
+ if group_var not in df_processed.columns:
427
+ # This check is mainly for clustering but good to ensure group_var exists.
428
+ # For 2x2, group_var (unit ID) might not be in formula but needed for clustering.
429
+ raise ValueError(f"Clustering variable '{group_var}' (panel unit ID) not found in regression data.")
430
+ logger.debug(f"Clustering standard errors by: {group_var}")
431
+ results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_processed[group_var]})
432
+
433
+ logger.info("Statsmodels estimation complete.")
434
+ logger.info(f"Statsmodels Results Summary:\n{results.summary()}")
435
+
436
+ logger.debug(f"Extracting results using interaction term key: {interaction_term_key_for_results}")
437
+
438
+ parameters["final_formula"] = formula
439
+ parameters["interaction_term_coefficient_name"] = interaction_term_key_for_results
440
+
441
+ formatted_results = format_did_results(results, interaction_term_key_for_results,
442
+ did_diagnostics,
443
+ method_details=method_details_str,
444
+ parameters=parameters)
445
+ formatted_results["estimator"] = "statsmodels"
446
+
447
+ except Exception as e:
448
+ logger.error(f"Statsmodels OLS estimation failed: {e}", exc_info=True)
449
+ raise ValueError(f"DiD estimation failed (both DoWhy and Statsmodels): {e}")
450
+
451
+
452
+
453
+
454
+ # --- Add Interpretation --- (Now add interpretation to the formatted results)
455
+ try:
456
+ # Use the llm_instance fetched earlier
457
+ interpretation = interpret_did_results(formatted_results, did_diagnostics, dataset_description, llm=llm_instance)
458
+ formatted_results['interpretation'] = interpretation
459
+ except Exception as interp_e:
460
+ logger.error(f"DiD Interpretation failed: {interp_e}")
461
+ formatted_results['interpretation'] = "Interpretation failed."
462
+
463
+ return formatted_results
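For orientation, here is a minimal, self-contained sketch (not part of the committed file) of the two OLS specifications the estimator above chooses between, run on the same toy panel just to show the formula shapes; the column names and numbers are illustrative only.

import pandas as pd
import statsmodels.formula.api as smf

# Toy panel: 4 units over 4 years; units 1-2 form the treated group, treatment starts in 2002.
df = pd.DataFrame({
    "unit_id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
    "year": [2000, 2001, 2002, 2003] * 4,
    "group": [1] * 8 + [0] * 8,
    "outcome": [1.0, 1.1, 2.0, 2.2, 0.9, 1.0, 1.9, 2.1,
                1.0, 1.1, 1.2, 1.3, 0.8, 0.9, 1.0, 1.1],
})
df["post"] = (df["year"] >= 2002).astype(int)
df["did_interaction"] = df["group"] * df["post"]

# 2x2 case (exactly two periods): the group:post interaction coefficient is the DiD estimate.
fit_2x2 = smf.ols("outcome ~ group * post", data=df).fit(
    cov_type="cluster", cov_kwds={"groups": df["unit_id"]})

# More than two periods: explicit interaction term plus unit and time fixed effects (TWFE).
fit_twfe = smf.ols("outcome ~ did_interaction + C(unit_id) + C(year)", data=df).fit(
    cov_type="cluster", cov_kwds={"groups": df["unit_id"]})

print(fit_2x2.params["group:post"], fit_twfe.params["did_interaction"])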
auto_causal/methods/difference_in_differences/llm_assist.py ADDED
@@ -0,0 +1,362 @@
1
+ """LLM Assist functions for Difference-in-Differences method."""
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+ from typing import Optional, Any, Dict, Union
6
+ import logging
7
+ from pydantic import BaseModel, Field, ValidationError
8
+ from langchain_core.messages import HumanMessage
9
+ from langchain_core.exceptions import OutputParserException
10
+
11
+ # Import shared types if needed
12
+ from langchain_core.language_models import BaseChatModel
13
+
14
+ # Import shared LLM helpers
15
+ from auto_causal.utils.llm_helpers import call_llm_with_json_output
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Placeholder LLM/Helper Functions
20
+
21
+ # --- Pydantic model for LLM time variable extraction ---
22
+ class LLMTimeVar(BaseModel):
23
+ time_variable_name: Optional[str] = Field(None, description="The column name identified as the primary time variable.")
24
+
25
+
26
+ def identify_time_variable(df: pd.DataFrame,
27
+ query: Optional[str] = None,
28
+ dataset_description: Optional[str] = None,
29
+ llm: Optional[BaseChatModel] = None) -> Optional[str]:
30
+ '''Identifies the most likely time variable.
31
+
32
+ Current Implementation: Heuristic based on column names, with LLM fallback.
33
+ Future: Refine LLM prompt and parsing.
34
+ '''
35
+ # 1. Heuristic based on common time-related keywords
36
+ time_patterns = ['time', 'year', 'date', 'period', 'month', 'day']
37
+ columns = df.columns.tolist()
38
+ for col in columns:
39
+ if any(pattern in col.lower() for pattern in time_patterns):
40
+ logger.info(f"Identified '{col}' as time variable (heuristic).")
41
+ return col
42
+
43
+ # 2. LLM Fallback if heuristic fails and LLM is provided
44
+ if llm and query:
45
+ logger.warning("Heuristic failed for time variable. Trying LLM fallback...")
46
+ # --- Example: Add dataset description context ---
47
+ context_str = ""
48
+ if dataset_description:
49
+ # col_types = dataset_description.get('column_types', {}) # Description is now a string
50
+ context_str += f"\nDataset Description: {dataset_description}"
51
+ # Add other relevant info like sample values if available
52
+ # ------------------------------------------------
53
+ prompt = f"""Given the user query and the available data columns, identify the single most likely column representing the primary time dimension (e.g., year, date, period).
54
+
55
+ User Query: "{query}"
56
+ Available Columns: {columns}{context_str}
57
+
58
+ Respond ONLY with a JSON object containing the identified column name using the key 'time_variable_name'. If no suitable time variable is found, return null for the value.
59
+ Example: {{"time_variable_name": "Year"}} or {{"time_variable_name": null}}"""
60
+
61
+ messages = [HumanMessage(content=prompt)]
62
+ structured_llm = llm.with_structured_output(LLMTimeVar)
63
+
64
+ try:
65
+ parsed_result = structured_llm.invoke(messages)
66
+ llm_identified_col = parsed_result.time_variable_name
67
+
68
+ if llm_identified_col and llm_identified_col in columns:
69
+ logger.info(f"Identified '{llm_identified_col}' as time variable (LLM fallback).")
70
+ return llm_identified_col
71
+ elif llm_identified_col:
72
+ logger.warning(f"LLM fallback identified '{llm_identified_col}' but it's not in the columns. Ignoring.")
73
+ else:
74
+ logger.info("LLM fallback did not identify a time variable.")
75
+
76
+ except (OutputParserException, ValidationError) as e:
77
+ logger.error(f"LLM fallback for time variable failed parsing/validation: {e}")
78
+ except Exception as e:
79
+ logger.error(f"LLM fallback for time variable failed unexpectedly: {e}", exc_info=True)
80
+
81
+ logger.warning("Could not identify time variable using heuristics or LLM fallback.")
82
+ return None
83
+
84
+ # --- Pydantic model for LLM treatment period extraction ---
85
+ class LLMTreatmentPeriod(BaseModel):
86
+ treatment_start_period: Optional[Union[str, int, float]] = Field(None, description="The time period value (as string) when treatment is believed to start based on the query.")
87
+
88
+ def determine_treatment_period(df: pd.DataFrame, time_var: str, treatment: str,
89
+ query: Optional[str] = None,
90
+ dataset_description: Optional[str] = None,
91
+ llm: Optional[BaseChatModel] = None) -> Any:
92
+ '''Determines the period when treatment starts.
93
+
94
+ Tries LLM first if available, then falls back to heuristic.
95
+ '''
96
+ if time_var not in df.columns:
97
+ raise ValueError(f"Time variable '{time_var}' not found in DataFrame.")
98
+
99
+ unique_times_sorted = np.sort(df[time_var].dropna().unique())
100
+ if len(unique_times_sorted) < 2:
101
+ raise ValueError("Need at least two time periods for DiD")
102
+
103
+ # --- Try LLM First (if available) ---
104
+ llm_period = None
105
+ if llm and query:
106
+ logger.info("Attempting LLM call to determine treatment period start...")
107
+ # Provide sorted unique times for context
108
+ times_str = ", ".join(map(str, unique_times_sorted)) if len(unique_times_sorted) < 20 else f"{unique_times_sorted[0]}...{unique_times_sorted[-1]}"
109
+ # --- Example: Add dataset description context ---
110
+ context_str = ""
111
+ if dataset_description:
112
+ # Example: Show summary stats for time var if helpful
113
+ # time_stats = dataset_description.get('summary_stats', {}).get(time_var) # Cannot get from string
114
+ context_str += f"\nDataset Description: {dataset_description}"
115
+ # ------------------------------------------------
116
+ prompt = f"""Based on the user query and the observed time periods, determine the specific period value when the treatment ('{treatment}') likely started.
117
+
118
+ User Query: "{query}"
119
+ Time Variable Name: '{time_var}'
120
+ Observed Time Periods (sorted): [{times_str}]{context_str}
121
+
122
+ Respond ONLY with a JSON object containing the identified start period using the key 'treatment_start_period'. The value should be one of the observed periods if possible. If the query doesn't specify a start period, return null.
123
+ Example: {{"treatment_start_period": 2015}} or {{"treatment_start_period": null}}"""
124
+
125
+ messages = [HumanMessage(content=prompt)]
126
+ structured_llm = llm.with_structured_output(LLMTreatmentPeriod)
127
+
128
+ try:
129
+ parsed_result = structured_llm.invoke(messages)
130
+ potential_period = parsed_result.treatment_start_period
131
+
132
+ # Validate if the period exists in the data (might need type conversion)
133
+ if potential_period is not None:
134
+ # Try converting LLM output type to match data type if needed
135
+ try:
136
+ series_dtype = df[time_var].dtype
137
+ converted_period = pd.Series([potential_period]).astype(series_dtype).iloc[0]
138
+ except Exception:
139
+ converted_period = potential_period # Use raw if conversion fails
140
+
141
+ if converted_period in unique_times_sorted:
142
+ llm_period = converted_period
143
+ logger.info(f"LLM identified treatment period start: {llm_period}")
144
+ else:
145
+ logger.warning(f"LLM identified period '{potential_period}' (converted: '{converted_period}'), but it's not in the observed time periods. Ignoring LLM result.")
146
+ else:
147
+ logger.info("LLM did not identify a specific treatment start period from the query.")
148
+
149
+ except (OutputParserException, ValidationError) as e:
150
+ logger.error(f"LLM fallback for treatment period failed parsing/validation: {e}")
151
+ except Exception as e:
152
+ logger.error(f"LLM fallback for treatment period failed unexpectedly: {e}", exc_info=True)
153
+
154
+ if llm_period is not None:
155
+ return llm_period
156
+
157
+ # --- Fallback to Heuristic ---
158
+ logger.warning("Using heuristic (median time) to determine treatment period start.")
159
+ treatment_period_start = None
160
+ try:
161
+ if pd.api.types.is_numeric_dtype(df[time_var]):
162
+ median_time = np.median(unique_times_sorted)
163
+ possible_starts = unique_times_sorted[unique_times_sorted > median_time]
164
+ if len(possible_starts) > 0:
165
+ treatment_period_start = possible_starts[0]
166
+ else:
167
+ treatment_period_start = unique_times_sorted[-1]
168
+ logger.warning(f"Could not determine treatment start > median time. Defaulting to last period: {treatment_period_start}")
169
+ else: # Assume sortable categories or dates
170
+ median_idx = len(unique_times_sorted) // 2
171
+ if median_idx < len(unique_times_sorted):
172
+ treatment_period_start = unique_times_sorted[median_idx]
173
+ else:
174
+ treatment_period_start = unique_times_sorted[0]
175
+
176
+ if treatment_period_start is not None:
177
+ logger.info(f"Determined treatment period start: {treatment_period_start} (heuristic: median time).")
178
+ return treatment_period_start
179
+ else:
180
+ raise ValueError("Could not determine treatment start period using heuristic.")
181
+
182
+ except Exception as e:
183
+ logger.error(f"Error in heuristic for treatment period: {e}")
184
+ raise ValueError(f"Could not determine treatment start period using heuristic: {e}")
185
+
186
+ # --- Pydantic model for LLM group variable extraction ---
187
+ class LLMGroupVar(BaseModel):
188
+ group_variable_name: Optional[str] = Field(None, description="The column name identifying the panel unit (e.g., state, individual, firm).")
189
+
190
+ def identify_treatment_group(df: pd.DataFrame, treatment_var: str,
191
+ query: Optional[str] = None,
192
+ dataset_description: Optional[str] = None,
193
+ llm: Optional[BaseChatModel] = None) -> Optional[str]:
194
+ '''Identifies the variable indicating the treated group/unit ID.
195
+
196
+ Tries heuristic check for non-binary treatment_var first, then LLM,
197
+ then falls back to assuming treatment_var is the group/unit identifier.
198
+ '''
199
+ columns = df.columns.tolist()
200
+ if treatment_var not in columns:
201
+ logger.error(f"Treatment variable '{treatment_var}' provided to identify_treatment_group not found in DataFrame.")
202
+ # Fallback: Look for common ID names if specified treatment is missing
203
+ id_keywords = ['id', 'unit', 'group', 'entity', 'state', 'firm']
204
+ for col in columns:
205
+ if any(keyword in col.lower() for keyword in id_keywords):
206
+ logger.warning(f"Specified treatment '{treatment_var}' not found. Falling back to potential ID column '{col}' as group identifier.")
207
+ return col
208
+ return None # Give up if no likely ID column found
209
+
210
+ # --- Heuristic: Check if treatment_var is non-binary, if so, look for ID columns ---
211
+ is_potentially_binary = False
212
+ if pd.api.types.is_numeric_dtype(df[treatment_var]):
213
+ unique_vals = set(df[treatment_var].dropna().unique())
214
+ if unique_vals.issubset({0, 1}):
215
+ is_potentially_binary = True
216
+
217
+ if not is_potentially_binary:
218
+ logger.info(f"Provided treatment variable '{treatment_var}' is not binary (0/1). Searching for a separate group/unit ID column heuristically.")
219
+ id_keywords = ['id', 'unit', 'group', 'entity', 'state', 'firm']
220
+ # Prioritize 'group' or 'unit' if available
221
+ for keyword in ['group', 'unit']:
222
+ for col in columns:
223
+ if keyword == col.lower():
224
+ logger.info(f"Heuristically identified '{col}' as group/unit ID (treatment '{treatment_var}' was non-binary)." )
225
+ return col
226
+ # Then check other keywords
227
+ for col in columns:
228
+ if col != treatment_var and any(keyword in col.lower() for keyword in id_keywords):
229
+ logger.info(f"Heuristically identified '{col}' as group/unit ID (treatment '{treatment_var}' was non-binary)." )
230
+ return col
231
+ logger.warning("Heuristic search for group/unit ID failed when treatment was non-binary.")
232
+
233
+ # --- LLM Attempt (if heuristic didn't find an alternative or wasn't needed) ---
234
+ # Useful if query context helps disambiguate (e.g., "effect across states")
235
+ if llm and query:
236
+ logger.info("Attempting LLM call to identify group/unit variable...")
237
+ # --- Example: Add dataset description context ---
238
+ context_str = ""
239
+ if dataset_description:
240
+ # col_types = dataset_description.get('column_types', {}) # Description is now a string
241
+ context_str += f"\nDataset Description: {dataset_description}"
242
+ # ------------------------------------------------
243
+ prompt = f"""Given the user query and data columns, identify the single column that most likely represents the unique identifier for the panel units (e.g., state, individual, firm, unit ID), distinct from the treatment status indicator ('{treatment_var}').
244
+
245
+ User Query: "{query}"
246
+ Treatment Variable Mentioned: '{treatment_var}'
247
+ Available Columns: {columns}{context_str}
248
+
249
+ Respond ONLY with a JSON object containing the identified unit identifier column name using the key 'group_variable_name'. If the best identifier seems to be the treatment variable itself or none is suitable, return null.
250
+ Example: {{"group_variable_name": "state_id"}} or {{"group_variable_name": null}}"""
251
+
252
+ messages = [HumanMessage(content=prompt)]
253
+ structured_llm = llm.with_structured_output(LLMGroupVar)
254
+
255
+ try:
256
+ parsed_result = structured_llm.invoke(messages)
257
+ llm_identified_col = parsed_result.group_variable_name
258
+
259
+ if llm_identified_col and llm_identified_col in columns:
260
+ logger.info(f"Identified '{llm_identified_col}' as group/unit variable (LLM).")
261
+ return llm_identified_col
262
+ elif llm_identified_col:
263
+ logger.warning(f"LLM identified '{llm_identified_col}' but it's not in the columns. Ignoring.")
264
+ else:
265
+ logger.info("LLM did not identify a separate group/unit variable.")
266
+
267
+ except (OutputParserException, ValidationError) as e:
268
+ logger.error(f"LLM call for group/unit variable failed parsing/validation: {e}")
269
+ except Exception as e:
270
+ logger.error(f"LLM call for group/unit variable failed unexpectedly: {e}", exc_info=True)
271
+
272
+ # --- Final Fallback ---
273
+ logger.info(f"Defaulting to using provided treatment variable '{treatment_var}' as the group/unit identifier.")
274
+ return treatment_var
275
+
276
+ # --- Add interpret_did_results function ---
277
+
278
+ def interpret_did_results(
279
+ results: Dict[str, Any],
280
+ diagnostics: Optional[Dict[str, Any]],
281
+ dataset_description: Optional[str] = None,
282
+ llm: Optional[BaseChatModel] = None
283
+ ) -> str:
284
+ """Use LLM to interpret Difference-in-Differences results."""
285
+ default_interpretation = "LLM interpretation not available for DiD."
286
+ if llm is None:
287
+ logger.info("LLM not provided for DiD interpretation.")
288
+ return default_interpretation
289
+
290
+ try:
291
+ # --- Prepare summary for LLM ---
292
+ results_summary = {}
293
+ params = results.get('parameters', {})
294
+ diag_details = (diagnostics or {})  # did_diagnostics stores 'parallel_trends' at the top level, not under 'details'
295
+ parallel_trends = diag_details.get('parallel_trends', {})
296
+
297
+ effect = results.get('effect_estimate')
298
+ pval = results.get('p_value')
299
+ ci = results.get('confidence_interval')
300
+
301
+ results_summary['Method Used'] = results.get('method_details', 'Difference-in-Differences')
302
+ results_summary['Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
303
+ results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
304
+ if isinstance(ci, (list, tuple)) and len(ci) == 2:
305
+ results_summary['Confidence Interval'] = f"[{ci[0]:.3f}, {ci[1]:.3f}]"
306
+ else:
307
+ results_summary['Confidence Interval'] = str(ci) if ci is not None else "N/A"
308
+
309
+ results_summary['Time Variable'] = params.get('time_var', 'N/A')
310
+ results_summary['Group/Unit Variable'] = params.get('group_var', 'N/A')
311
+ results_summary['Treatment Indicator Used'] = params.get('treatment_indicator', 'N/A')
312
+ results_summary['Treatment Start Period'] = params.get('treatment_period_start', 'N/A')
313
+ results_summary['Covariates Included'] = params.get('covariates', [])
314
+
315
+ diag_summary = {}
316
+ diag_summary['Parallel Trends Assumption Status'] = "Passed (Placeholder)" if parallel_trends.get('valid', False) else "Failed/Unknown (Placeholder)"
317
+ if not parallel_trends.get('valid', False) and parallel_trends.get('details') != "Placeholder validation":
318
+ diag_summary['Parallel Trends Details'] = parallel_trends.get('details', 'N/A')
319
+
320
+ # --- Example: Add dataset description context ---
321
+ context_str = ""
322
+ if dataset_description:
323
+ # context_str += f"\nDataset Context: {dataset_description.get('summary', 'N/A')}" # Use string directly
324
+ context_str += f"\n\nDataset Context Provided:\n{dataset_description}"
325
+ # ------------------------------------------------
326
+
327
+ # --- Construct Prompt ---
328
+ prompt = f"""
329
+ You are assisting with interpreting Difference-in-Differences (DiD) results.
330
+ {context_str}
331
+
332
+ Estimation Results Summary:
333
+ {results_summary}
334
+
335
+ Diagnostics Summary:
336
+ {diag_summary}
337
+
338
+ Explain these DiD results in 2-4 concise sentences. Focus on:
339
+ 1. The estimated average treatment effect on the treated (magnitude, direction, statistical significance based on p-value < 0.05).
340
+ 2. The status of the parallel trends assumption (mentioning it's a key assumption for DiD).
341
+ 3. Note that the estimation used either a 2x2 group-by-post interaction or unit and time fixed effects (for panels with more than two periods), plus any covariates {results_summary['Covariates Included']}
342
+
343
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
344
+ {{
345
+ "interpretation": "<your concise interpretation text>"
346
+ }}
347
+ """
348
+
349
+ # --- Call LLM ---
350
+ response = call_llm_with_json_output(llm, prompt)
351
+
352
+ # --- Process Response ---
353
+ if response and isinstance(response, dict) and \
354
+ "interpretation" in response and isinstance(response["interpretation"], str):
355
+ return response["interpretation"]
356
+ else:
357
+ logger.warning(f"Failed to get valid interpretation from LLM for DiD. Response: {response}")
358
+ return default_interpretation
359
+
360
+ except Exception as e:
361
+ logger.error(f"Error during LLM interpretation for DiD: {e}")
362
+ return f"Error generating interpretation: {e}"
auto_causal/methods/difference_in_differences/utils.py ADDED
@@ -0,0 +1,65 @@
1
+ # Utility functions for Difference-in-Differences
2
+ import pandas as pd
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ def create_post_indicator(df: pd.DataFrame, time_var: str, treatment_period_start: any) -> pd.Series:
8
+ """Creates the post-treatment indicator variable.
9
+ Checks if time_var is already a 0/1 indicator; otherwise, compares to treatment_period_start.
10
+ """
11
+ try:
12
+ time_var_series = df[time_var]
13
+ # Ensure numeric for checks and direct comparison
14
+ if pd.api.types.is_bool_dtype(time_var_series):
15
+ time_var_series = time_var_series.astype(int)
16
+
17
+ # Check if it's already a binary 0/1 indicator
18
+ if pd.api.types.is_numeric_dtype(time_var_series):
19
+ unique_vals = set(time_var_series.dropna().unique())
20
+ if unique_vals == {0, 1}:
21
+ logger.info(f"Time variable '{time_var}' is already a binary 0/1 indicator. Using it directly as post indicator.")
22
+ return time_var_series.astype(int)
23
+ else:
24
+ # Numeric, but not 0/1, so compare with treatment_period_start
25
+ logger.info(f"Time variable '{time_var}' is numeric. Comparing with treatment_period_start: {treatment_period_start}")
26
+ return (time_var_series >= treatment_period_start).astype(int)
27
+ else:
28
+ # Non-numeric and not boolean, will likely fall into TypeError for datetime conversion
29
+ # This else block might not be strictly necessary if TypeError is caught below
30
+ # but added for logical completeness before attempting datetime conversion.
31
+ pass # Let it fall through to TypeError if not numeric here
32
+
33
+ # If we reached here, it means it wasn't numeric or bool, try direct comparison which will likely raise TypeError
34
+ # and be caught by the except block for datetime conversion if applicable.
35
+ # This line is kept to ensure non-numeric non-datetime-like strings also trigger the except.
36
+ return (df[time_var] >= treatment_period_start).astype(int)
37
+
38
+ except TypeError:
39
+ # If direct comparison fails (e.g., comparing datetime with int/str, or non-numeric string with number),
40
+ # attempt to convert both to datetime objects for comparison.
41
+ logger.info(f"Direct comparison/numeric check failed for time_var '{time_var}'. Attempting datetime conversion.")
42
+ try:
43
+ time_series_dt = pd.to_datetime(df[time_var], errors='coerce')
44
+ # Try to convert treatment_period_start to datetime if it's not already
45
+ # This handles cases where treatment_period_start might be a date string
46
+ try:
47
+ treatment_start_dt = pd.to_datetime(treatment_period_start)
48
+ except Exception as e_conv:
49
+ logger.error(f"Could not convert treatment_period_start '{treatment_period_start}' to datetime: {e_conv}")
50
+ raise TypeError(f"treatment_period_start '{treatment_period_start}' could not be converted to a comparable datetime format.")
51
+
52
+ if time_series_dt.isna().all(): # if all values are NaT after conversion
53
+ raise ValueError(f"Time variable '{time_var}' could not be converted to datetime (all values NaT).")
54
+ if pd.isna(treatment_start_dt):
55
+ raise ValueError(f"Treatment start period '{treatment_period_start}' converted to NaT.")
56
+
57
+ logger.info(f"Comparing time_var '{time_var}' (as datetime) with treatment_start_dt '{treatment_start_dt}' (as datetime).")
58
+ return (time_series_dt >= treatment_start_dt).astype(int)
59
+ except Exception as e:
60
+ logger.error(f"Failed to compare time variable '{time_var}' with treatment start '{treatment_period_start}' using datetime logic: {e}", exc_info=True)
61
+ raise TypeError(f"Could not compare time variable '{time_var}' with treatment start '{treatment_period_start}'. Ensure they are comparable or convertible to datetime. Error: {e}")
62
+ except Exception as ex:
63
+ # Catch any other unexpected errors during the initial numeric processing
64
+ logger.error(f"Unexpected error processing time_var '{time_var}' for post indicator: {ex}", exc_info=True)
65
+ raise TypeError(f"Unexpected error processing time_var '{time_var}': {ex}")
auto_causal/methods/generalized_propensity_score/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ Generalized Propensity Score (GPS) method for continuous treatments.
3
+ """
auto_causal/methods/generalized_propensity_score/diagnostics.py ADDED
@@ -0,0 +1,196 @@
1
+ """
2
+ Diagnostic checks for the Generalized Propensity Score (GPS) method.
3
+ """
4
+ from typing import Dict, List, Any
5
+ import pandas as pd
6
+ import logging
7
+ import numpy as np
8
+ import statsmodels.api as sm
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def assess_gps_balance(
13
+ df_with_gps: pd.DataFrame,
14
+ treatment_var: str,
15
+ covariate_vars: List[str],
16
+ gps_col_name: str,
17
+ **kwargs: Any
18
+ ) -> Dict[str, Any]:
19
+ """
20
+ Assesses the balance of covariates conditional on the estimated GPS.
21
+
22
+ This function is typically called after GPS estimation to validate the
23
+ assumption that covariates are independent of treatment conditional on GPS.
24
+
25
+ Args:
26
+ df_with_gps: DataFrame containing the original data plus the estimated GPS column.
27
+ treatment_var: The name of the continuous treatment variable column.
28
+ covariate_vars: A list of covariate column names to check for balance.
29
+ gps_col_name: The name of the column containing the estimated GPS values.
30
+ **kwargs: Additional arguments (e.g., number of strata for checking balance).
31
+
32
+ Returns:
33
+ A dictionary containing balance statistics and summaries. For example:
34
+ {
35
+ "overall_balance_metric": 0.05,
36
+ "covariate_balance": {
37
+ "cov1": {"statistic": 0.03, "p_value": 0.5, "balanced": True},
38
+ "cov2": {"statistic": 0.12, "p_value": 0.02, "balanced": False}
39
+ },
40
+ "summary": "Balance assessment complete."
41
+ }
42
+ """
43
+ logger.info(f"Assessing GPS balance for covariates: {covariate_vars}")
44
+
45
+ # Default to 5 strata (quintiles) if not specified
46
+ num_strata = kwargs.get('num_strata', 5)
47
+ if not isinstance(num_strata, int) or num_strata <= 1:
48
+ logger.warning(f"Invalid num_strata ({num_strata}), defaulting to 5.")
49
+ num_strata = 5
50
+
51
+ balance_results = {}
52
+ overall_summary = {
53
+ "num_strata_used": num_strata,
54
+ "covariates_tested": len(covariate_vars),
55
+ "warnings": [],
56
+ "all_strata_coefficients": {cov: [] for cov in covariate_vars},
57
+ "all_strata_p_values": {cov: [] for cov in covariate_vars}
58
+ }
59
+
60
+ if df_with_gps[gps_col_name].isnull().all():
61
+ logger.error(f"All GPS scores in column '{gps_col_name}' are NaN. Cannot perform balance assessment.")
62
+ overall_summary["error"] = "All GPS scores are NaN."
63
+ return {
64
+ "error": "All GPS scores are NaN.",
65
+ "summary": "Balance assessment failed."
66
+ }
67
+
68
+ try:
69
+ # Create GPS strata (e.g., quintiles)
70
+ # Ensure unique bin edges for qcut, duplicates='drop' will handle cases with sparse GPS values
71
+ # but might result in fewer than num_strata if GPS distribution is highly skewed or has few unique values.
72
+ try:
73
+ df_with_gps['gps_stratum'] = pd.qcut(df_with_gps[gps_col_name], num_strata, labels=False, duplicates='drop')
74
+ actual_num_strata = df_with_gps['gps_stratum'].nunique()
75
+ if actual_num_strata < num_strata and actual_num_strata > 0:
76
+ logger.warning(f"Requested {num_strata} strata, but due to GPS distribution, only {actual_num_strata} could be formed.")
77
+ overall_summary["warnings"].append(f"Only {actual_num_strata} strata formed out of {num_strata} requested.")
78
+ overall_summary["actual_num_strata_formed"] = actual_num_strata
79
+ except ValueError as ve:
80
+ logger.error(f"Could not create strata using pd.qcut due to: {ve}. This might happen if GPS has too few unique values.")
81
+ logger.info("Attempting to use unique GPS values as strata if count is low.")
82
+ unique_gps_count = df_with_gps[gps_col_name].nunique()
83
+ if unique_gps_count <= num_strata * 2 and unique_gps_count > 1: # Arbitrary threshold to try unique values as strata
84
+ strata_map = {val: i for i, val in enumerate(df_with_gps[gps_col_name].unique())}
85
+ df_with_gps['gps_stratum'] = df_with_gps[gps_col_name].map(strata_map)
86
+ actual_num_strata = df_with_gps['gps_stratum'].nunique()
87
+ overall_summary["actual_num_strata_formed"] = actual_num_strata
88
+ overall_summary["warnings"].append(f"Used {actual_num_strata} unique GPS values as strata due to qcut error.")
89
+ else:
90
+ overall_summary["error"] = f"Failed to create GPS strata: {ve}. GPS may have too few unique values."
91
+ return {
92
+ "error": overall_summary["error"],
93
+ "summary": "Balance assessment failed due to strata creation issues."
94
+ }
95
+
96
+ if df_with_gps['gps_stratum'].isnull().all():
97
+ logger.error("Stratum assignment resulted in all NaNs.")
98
+ overall_summary["error"] = "Stratum assignment resulted in all NaNs."
99
+ return {"error": overall_summary["error"], "summary": "Balance assessment failed."}
100
+
101
+
102
+ for cov in covariate_vars:
103
+ balance_results[cov] = {
104
+ "strata_details": [],
105
+ "mean_abs_coefficient": None,
106
+ "num_significant_strata_p005": 0,
107
+ "balanced_heuristic": True # Assume balanced until proven otherwise
108
+ }
109
+ coeffs_for_cov = []
110
+ p_values_for_cov = []
111
+
112
+ for stratum_idx in sorted(df_with_gps['gps_stratum'].dropna().unique()):
113
+ stratum_data = df_with_gps[df_with_gps['gps_stratum'] == stratum_idx]
114
+ stratum_detail = {"stratum_index": int(stratum_idx), "n_obs": len(stratum_data)}
115
+
116
+ if len(stratum_data) < 10: # Need a minimum number of observations for stable regression
117
+ stratum_detail["status"] = "Skipped (too few observations)"
118
+ stratum_detail["coefficient_on_treatment"] = np.nan
119
+ stratum_detail["p_value_on_treatment"] = np.nan
120
+ balance_results[cov]["strata_details"].append(stratum_detail)
121
+ continue
122
+
123
+ # Ensure covariate and treatment have variance within the stratum
124
+ if stratum_data[cov].nunique() < 2 or stratum_data[treatment_var].nunique() < 2:
125
+ stratum_detail["status"] = "Skipped (no variance in cov or treatment)"
126
+ stratum_detail["coefficient_on_treatment"] = np.nan
127
+ stratum_detail["p_value_on_treatment"] = np.nan
128
+ balance_results[cov]["strata_details"].append(stratum_detail)
129
+ continue
130
+
131
+ try:
132
+ X_balance = sm.add_constant(stratum_data[[treatment_var]])
133
+ y_balance = stratum_data[cov]
134
+
135
+ # Drop NaNs for this specific regression within stratum
136
+ temp_df = pd.concat([y_balance, X_balance], axis=1).dropna()
137
+ if len(temp_df) < X_balance.shape[1] +1: # Check for enough data points after NaNs for regression
138
+ stratum_detail["status"] = "Skipped (too few non-NaN obs for regression)"
139
+ stratum_detail["coefficient_on_treatment"] = np.nan
140
+ stratum_detail["p_value_on_treatment"] = np.nan
141
+ balance_results[cov]["strata_details"].append(stratum_detail)
142
+ continue
143
+
144
+ y_balance_fit = temp_df[cov]
145
+ X_balance_fit = temp_df[[col for col in temp_df.columns if col != cov]]
146
+
147
+ balance_model = sm.OLS(y_balance_fit, X_balance_fit).fit()
148
+ coeff = balance_model.params.get(treatment_var, np.nan)
149
+ p_value = balance_model.pvalues.get(treatment_var, np.nan)
150
+
151
+ coeffs_for_cov.append(coeff)
152
+ p_values_for_cov.append(p_value)
153
+ overall_summary["all_strata_coefficients"][cov].append(coeff)
154
+ overall_summary["all_strata_p_values"][cov].append(p_value)
155
+
156
+ stratum_detail["status"] = "Analyzed"
157
+ stratum_detail["coefficient_on_treatment"] = coeff
158
+ stratum_detail["p_value_on_treatment"] = p_value
159
+ if not pd.isna(p_value) and p_value < 0.05:
160
+ balance_results[cov]["num_significant_strata_p005"] += 1
161
+ balance_results[cov]["balanced_heuristic"] = False # If any stratum is unbalanced
162
+
163
+ except Exception as e_bal:
164
+ logger.debug(f"Balance check regression failed for {cov} in stratum {stratum_idx}: {e_bal}")
165
+ stratum_detail["status"] = f"Error: {str(e_bal)}"
166
+ stratum_detail["coefficient_on_treatment"] = np.nan
167
+ stratum_detail["p_value_on_treatment"] = np.nan
168
+
169
+ balance_results[cov]["strata_details"].append(stratum_detail)
170
+
171
+ if coeffs_for_cov:
172
+ balance_results[cov]["mean_abs_coefficient"] = np.nanmean(np.abs(coeffs_for_cov))
173
+ else:
174
+ balance_results[cov]["mean_abs_coefficient"] = np.nan # No strata were analyzable
175
+
176
+ overall_summary["num_covariates_potentially_imbalanced_p005"] = sum(
177
+ 1 for cov_data in balance_results.values() if not cov_data["balanced_heuristic"]
178
+ )
179
+
180
+ except Exception as e:
181
+ logger.error(f"Error during GPS balance assessment: {e}", exc_info=True)
182
+ overall_summary["error"] = f"Overall assessment error: {str(e)}"
183
+ return {
184
+ "error": str(e),
185
+ "balance_results": balance_results,
186
+ "summary_stats": overall_summary,
187
+ "summary": "Balance assessment failed due to an unexpected error."
188
+ }
189
+
190
+ logger.info("GPS balance assessment complete.")
191
+
192
+ return {
193
+ "balance_results_per_covariate": balance_results,
194
+ "summary_stats": overall_summary,
195
+ "summary": "GPS balance assessment finished. Review strata details and mean absolute coefficients."
196
+ }
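For intuition, a condensed, self-contained sketch of the stratified balance check implemented above (synthetic data with a single covariate; real use goes through assess_gps_balance):

import numpy as np
import pandas as pd
import statsmodels.api as sm

rng = np.random.default_rng(0)
n = 500
x = rng.normal(size=n)                                        # covariate
t = 0.8 * x + rng.normal(size=n)                              # continuous treatment
gps = np.exp(-((t - 0.8 * x) ** 2) / 2) / np.sqrt(2 * np.pi)  # conditional density of t given x
df = pd.DataFrame({"x": x, "t": t, "gps": gps})

df["stratum"] = pd.qcut(df["gps"], 5, labels=False, duplicates="drop")
for stratum, grp in df.groupby("stratum"):
    fit = sm.OLS(grp["x"], sm.add_constant(grp[["t"]])).fit()
    # Within a stratum of similar GPS, the check looks for a small, insignificant coefficient on t.
    print(int(stratum), round(fit.params["t"], 3), round(fit.pvalues["t"], 3))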
auto_causal/methods/generalized_propensity_score/estimator.py ADDED
@@ -0,0 +1,386 @@
1
+ """
2
+ Core estimation logic for the Generalized Propensity Score (GPS) method.
3
+ """
4
+ from typing import Dict, List, Any
5
+ import pandas as pd
6
+ import logging
7
+ import numpy as np
8
+ import statsmodels.api as sm
9
+
10
+ from .diagnostics import assess_gps_balance # Import for balance check
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ def estimate_effect_gps(
15
+ df: pd.DataFrame,
16
+ treatment: str,
17
+ outcome: str,
18
+ covariates: List[str],
19
+ **kwargs: Any
20
+ ) -> Dict[str, Any]:
21
+ """
22
+ Estimates the causal effect using the Generalized Propensity Score method
23
+ for continuous treatments.
24
+
25
+ This function will be called by the method_executor_tool.
26
+
27
+ Args:
28
+ df: The input DataFrame.
29
+ treatment: The name of the continuous treatment variable column.
30
+ outcome: The name of the outcome variable column.
31
+ covariates: A list of covariate column names.
32
+ **kwargs: Additional arguments for controlling the estimation, including:
33
+ - gps_model_spec (dict): Specification for the GPS model (T ~ X).
34
+ - outcome_model_spec (dict): Specification for the outcome model (Y ~ T, GPS).
35
+ - t_values_range (list or dict): Specification for treatment levels for ADRF.
36
+ - n_bootstraps (int): Number of bootstrap replications for SEs.
37
+
38
+ Returns:
39
+ A dictionary containing the estimation results, including:
40
+ - "effect_estimate": Typically the ADRF or a specific contrast.
41
+ - "standard_error": Standard error for the primary effect estimate.
42
+ - "confidence_interval": Confidence interval for the primary estimate.
43
+ - "adrf_curve": Data representing the Average Dose-Response Function.
44
+ - "specific_contrasts": Any calculated specific contrasts.
45
+ - "diagnostics": Results from diagnostic checks (e.g., balance).
46
+ - "method_details": Description of the method and models used.
47
+ - "parameters_used": Dictionary of parameters used.
48
+ """
49
+ logger.info(f"Starting GPS estimation for treatment '{treatment}', outcome '{outcome}'.")
50
+
51
+ # --- Parameter Extraction and Defaults ---
52
+ gps_model_spec = kwargs.get('gps_model_spec', {"type": "linear"})
53
+ outcome_model_spec = kwargs.get('outcome_model_spec', {"type": "polynomial", "degree": 2, "interaction": True})
54
+
55
+ # Get t_values for ADRF from llm_assist or kwargs, default to 10 points over observed range
56
+ # For simplicity, we'll use a simple range here. In a full impl, this might call llm_assist.
57
+ t_values_for_adrf = kwargs.get('t_values_for_adrf')
58
+ if t_values_for_adrf is None:
59
+ min_t_obs = df[treatment].min()
60
+ max_t_obs = df[treatment].max()
61
+ if pd.isna(min_t_obs) or pd.isna(max_t_obs) or min_t_obs == max_t_obs:
62
+ logger.warning(f"Cannot determine a valid range for treatment '{treatment}' for ADRF. Using limited points.")
63
+ t_values_for_adrf = sorted(list(df[treatment].dropna().unique()))[:10] # Fallback
64
+ else:
65
+ t_values_for_adrf = np.linspace(min_t_obs, max_t_obs, 10).tolist()
66
+
67
+ n_bootstraps = kwargs.get('n_bootstraps', 0) # Default to 0, meaning no bootstrap for now
68
+
69
+ logger.info(f"Using GPS model spec: {gps_model_spec}")
70
+ logger.info(f"Using outcome model spec: {outcome_model_spec}")
71
+ logger.info(f"Evaluating ADRF at t-values: {t_values_for_adrf}")
72
+
73
+ try:
74
+ # 2. Estimate GPS Values
75
+ df_with_gps, gps_estimation_diagnostics = _estimate_gps_values(
76
+ df.copy(), treatment, covariates, gps_model_spec
77
+ )
78
+ if 'gps_score' not in df_with_gps.columns or df_with_gps['gps_score'].isnull().all():
79
+ logger.error("GPS estimation failed or resulted in all NaNs.")
80
+ return {
81
+ "error": "GPS estimation failed.",
82
+ "diagnostics": gps_estimation_diagnostics,
83
+ "method_details": "GPS (Failed)",
84
+ "parameters_used": kwargs
85
+ }
86
+
87
+ # Drop rows where GPS or outcome or necessary modeling variables are NaN before proceeding
88
+ modeling_cols = [outcome, treatment, 'gps_score'] + covariates
89
+ df_with_gps.dropna(subset=modeling_cols, inplace=True)
90
+ if df_with_gps.empty:
91
+ logger.error("DataFrame is empty after GPS estimation and NaN removal.")
92
+ return {"error": "No data available after GPS estimation and NaN removal.", "method_details": "GPS (Failed)", "parameters_used": kwargs}
93
+
94
+
95
+ # 3. Assess GPS Balance (call diagnostics.assess_gps_balance)
96
+ balance_diagnostics = assess_gps_balance(
97
+ df_with_gps, treatment, covariates, 'gps_score' # kwargs for assess_gps_balance can be passed if needed
98
+ )
99
+
100
+ # 4. Estimate Outcome Model
101
+ fitted_outcome_model = _estimate_outcome_model(
102
+ df_with_gps, outcome, treatment, 'gps_score', outcome_model_spec
103
+ )
104
+
105
+ # 5. Generate Dose-Response Function
106
+ adrf_results = _generate_dose_response_function(
107
+ df_with_gps, fitted_outcome_model, treatment, 'gps_score', outcome_model_spec, t_values_for_adrf
108
+ )
109
+ adrf_curve_data = {"t_levels": t_values_for_adrf, "expected_outcomes": adrf_results}
110
+
111
+ # 6. Calculate specific contrasts if requested (Placeholder)
112
+ specific_contrasts = {"info": "Specific contrasts not implemented in this version."}
113
+
114
+ # 7. Perform bootstrapping for SEs if requested (Placeholder for now)
115
+ standard_error_info = {"info": "Bootstrap SEs not implemented in this version."}
116
+ confidence_interval_info = {"info": "Bootstrap CIs not implemented in this version."}
117
+ if n_bootstraps > 0:
118
+ logger.info(f"Bootstrapping with {n_bootstraps} replications (placeholder).")
119
+ # Actual bootstrapping logic would go here.
120
+ # For now, we'll just note that it's not implemented.
121
+
122
+ logger.info("GPS estimation steps completed.")
123
+
124
+ # Consolidate diagnostics
125
+ all_diagnostics = {
126
+ "gps_estimation_diagnostics": gps_estimation_diagnostics,
127
+ "balance_check": balance_diagnostics, # Now using the actual balance check results
128
+ "outcome_model_summary": str(fitted_outcome_model.summary()) if fitted_outcome_model else "Outcome model not fitted.",
129
+ "warnings": [], # Populate with any warnings during the process
130
+ "summary": "GPS estimation complete."
131
+ }
132
+
133
+ return {
134
+ "effect_estimate": adrf_curve_data, # The ADRF is the primary "effect"
135
+ "standard_error_info": standard_error_info, # Placeholder
136
+ "confidence_interval_info": confidence_interval_info, # Placeholder
137
+ "adrf_curve": adrf_curve_data,
138
+ "specific_contrasts": specific_contrasts, # Placeholder
139
+ "diagnostics": all_diagnostics,
140
+ "method_details": f"Generalized Propensity Score (GPS) with {gps_model_spec.get('type', 'N/A')} GPS model and {outcome_model_spec.get('type', 'N/A')} outcome model.",
141
+ "parameters_used": {
142
+ "treatment_var": treatment,
143
+ "outcome_var": outcome,
144
+ "covariate_vars": covariates,
145
+ "gps_model_spec": gps_model_spec,
146
+ "outcome_model_spec": outcome_model_spec,
147
+ "t_values_for_adrf": t_values_for_adrf,
148
+ "n_bootstraps": n_bootstraps,
149
+ **kwargs
150
+ }
151
+ }
152
+ except Exception as e:
153
+ logger.error(f"Error during GPS estimation pipeline: {e}", exc_info=True)
154
+ return {
155
+ "error": f"Pipeline failed: {str(e)}",
156
+ "method_details": "GPS (Failed)",
157
+ "diagnostics": {"error": f"Pipeline failed during GPS estimation: {str(e)}"}, # Add diagnostics here too
158
+ "parameters_used": kwargs
159
+ }
160
+
161
+
162
+ # Placeholder for internal helper functions
163
+ def _estimate_gps_values(
164
+ df: pd.DataFrame,
165
+ treatment: str,
166
+ covariates: List[str],
167
+ gps_model_spec: Dict
168
+ ) -> tuple[pd.DataFrame, Dict]:
169
+ """
170
+ Estimates Generalized Propensity Scores.
171
+ Assumes T | X ~ N(X*beta, sigma^2), so GPS is the conditional density.
172
+ """
173
+ logger.info(f"Estimating GPS for treatment '{treatment}' using covariates: {covariates}")
174
+ diagnostics = {}
175
+
176
+ if not covariates:
177
+ logger.error("No covariates provided for GPS estimation.")
178
+ diagnostics["error"] = "No covariates provided."
179
+ df['gps_score'] = np.nan # Ensure gps_score column is added
180
+ return df, diagnostics
181
+
182
+ X_df = df[covariates]
183
+ T_series = df[treatment]
184
+
185
+ # Handle potential NaN values in covariates or treatment before modeling
186
+ valid_indices = X_df.dropna().index.intersection(T_series.dropna().index)
187
+ if len(valid_indices) < len(df):
188
+ logger.warning(f"Dropped {len(df) - len(valid_indices)} rows due to NaNs in treatment/covariates before GPS estimation.")
189
+ diagnostics["pre_estimation_nan_rows_dropped"] = len(df) - len(valid_indices)
190
+
191
+ X = X_df.loc[valid_indices]
192
+ T = T_series.loc[valid_indices]
193
+
194
+ if X.empty or T.empty:
195
+ logger.error("Covariate or treatment data is empty after NaN handling.")
196
+ diagnostics["error"] = "Covariate or treatment data is empty after NaN handling."
197
+ return df, diagnostics
198
+
199
+ X_sm = sm.add_constant(X, has_constant='add')
200
+
201
+ try:
202
+ if gps_model_spec.get("type") == 'linear':
203
+ model = sm.OLS(T, X_sm).fit()
204
+ t_hat = model.predict(X_sm)
205
+ residuals = T - t_hat
206
+ # MSE: sum of squared residuals / (n - k) where k is number of regressors (including const)
207
+ if len(T) <= X_sm.shape[1]:
208
+ logger.error("Not enough degrees of freedom to estimate sigma_sq_hat.")
209
+ diagnostics["error"] = "Not enough degrees of freedom for GPS variance."
210
+ df['gps_score'] = np.nan
211
+ return df, diagnostics
212
+
213
+ sigma_sq_hat = np.sum(residuals**2) / (len(T) - X_sm.shape[1])
214
+
215
+ if sigma_sq_hat <= 1e-9: # Check for effectively zero or very small variance
216
+ logger.warning(f"Estimated residual variance (sigma_sq_hat) is very close to zero ({sigma_sq_hat}). GPS will be set to NaN.")
217
+ diagnostics["warning_sigma_sq_hat_near_zero"] = sigma_sq_hat
218
+ df['gps_score'] = np.nan # Set GPS to NaN as density is ill-defined
219
+ if sigma_sq_hat == 0: # if it is exactly zero, add specific error
220
+ diagnostics["error_sigma_sq_hat_is_zero"] = "Residual variance is exactly zero."
221
+ return df, diagnostics
222
+
223
+
224
+ # Calculate GPS: (1 / sqrt(2*pi*sigma_hat^2)) * exp(-(T_i - T_hat_i)^2 / (2*sigma_hat^2))
225
+ # Ensure calculation is done on the original T values (T_series.loc[valid_indices])
226
+ # and corresponding t_hat for those valid_indices
227
+ gps_values_calculated = (1 / np.sqrt(2 * np.pi * sigma_sq_hat)) * np.exp(-((T - t_hat)**2) / (2 * sigma_sq_hat))
228
+
229
+ # Assign back to the original DataFrame using .loc to ensure alignment
230
+ df['gps_score'] = np.nan # Initialize column
231
+ df.loc[valid_indices, 'gps_score'] = gps_values_calculated
232
+
233
+ diagnostics["gps_model_type"] = "linear_ols"
234
+ diagnostics["gps_model_rsquared"] = model.rsquared
235
+ diagnostics["gps_residual_variance_mse"] = sigma_sq_hat
236
+ diagnostics["num_observations_for_gps_model"] = len(T)
237
+
238
+ else:
239
+ logger.error(f"GPS model type '{gps_model_spec.get('type')}' not implemented.")
240
+ diagnostics["error"] = f"GPS model type '{gps_model_spec.get('type')}' not implemented."
241
+ df['gps_score'] = np.nan
242
+
243
+ except Exception as e:
244
+ logger.error(f"Error during GPS model estimation: {e}", exc_info=True)
245
+ diagnostics["error"] = f"Exception during GPS estimation: {str(e)}"
246
+ df['gps_score'] = np.nan
247
+
248
+ # Ensure the gps_score column exists even if an early exit above skipped the assignment
249
+ if 'gps_score' not in df.columns:
250
+ df['gps_score'] = np.nan
251
+
252
+ return df, diagnostics
253
+
254
+ def _estimate_outcome_model(
255
+ df_with_gps: pd.DataFrame,
256
+ outcome: str,
257
+ treatment: str,
258
+ gps_col_name: str,
259
+ outcome_model_spec: Dict
260
+ ) -> Any: # Returns a fitted statsmodels model
261
+ """
262
+ Estimates the outcome model Y ~ f(T, GPS).
263
+ """
264
+ logger.info(f"Estimating outcome model for '{outcome}' using T='{treatment}', GPS='{gps_col_name}'")
265
+
266
+ Y = df_with_gps[outcome]
267
+ T_val = pd.Series(df_with_gps[treatment].values, index=df_with_gps.index)
268
+ GPS_val = pd.Series(df_with_gps[gps_col_name].values, index=df_with_gps.index)
269
+
270
+ X_outcome_dict = {'intercept': np.ones(len(df_with_gps))}
271
+
272
+ model_type = outcome_model_spec.get("type", "polynomial")
273
+ degree = outcome_model_spec.get("degree", 2)
274
+ interaction = outcome_model_spec.get("interaction", True)
275
+
276
+ if model_type == "polynomial":
277
+ X_outcome_dict['T'] = T_val
278
+ X_outcome_dict['GPS'] = GPS_val
279
+ if degree >= 2:
280
+ X_outcome_dict['T_sq'] = T_val**2
281
+ X_outcome_dict['GPS_sq'] = GPS_val**2
282
+ if degree >=3: # Example for higher order, can be made more general
283
+ X_outcome_dict['T_cub'] = T_val**3
284
+ X_outcome_dict['GPS_cub'] = GPS_val**3
285
+ if interaction:
286
+ X_outcome_dict['T_x_GPS'] = T_val * GPS_val
287
+ if degree >=2: # Interaction with squared terms if degree allows
288
+ X_outcome_dict['T_sq_x_GPS'] = (T_val**2) * GPS_val
289
+ X_outcome_dict['T_x_GPS_sq'] = T_val * (GPS_val**2)
290
+
291
+ # Add more model types as needed (e.g., splines)
292
+ else:
293
+ logger.warning(f"Outcome model type '{model_type}' not fully recognized. Defaulting to T + GPS.")
294
+ X_outcome_dict['T'] = T_val
295
+ X_outcome_dict['GPS'] = GPS_val
296
+ # Fallback to linear if spec is unknown or simple
297
+
298
+ X_outcome_df = pd.DataFrame(X_outcome_dict, index=df_with_gps.index)
299
+
300
+ # Drop rows with NaNs that might have been introduced by transformations if T or GPS were NaN
301
+ # (though earlier dropna should handle most of this for input T/GPS)
302
+ valid_outcome_model_indices = Y.dropna().index.intersection(X_outcome_df.dropna().index)
303
+ if len(valid_outcome_model_indices) < len(df_with_gps):
304
+ logger.warning(f"Dropped {len(df_with_gps) - len(valid_outcome_model_indices)} rows due to NaNs before outcome model fitting.")
305
+
306
+ Y_fit = Y.loc[valid_outcome_model_indices]
307
+ X_outcome_df_fit = X_outcome_df.loc[valid_outcome_model_indices]
308
+
309
+ if Y_fit.empty or X_outcome_df_fit.empty:
310
+ logger.error("Not enough data to fit outcome model after NaN handling.")
311
+ raise ValueError("Empty data for outcome model fitting.")
312
+
313
+ try:
314
+ model = sm.OLS(Y_fit, X_outcome_df_fit).fit()
315
+ logger.info("Outcome model estimated successfully.")
316
+ return model
317
+ except Exception as e:
318
+ logger.error(f"Error during outcome model estimation: {e}", exc_info=True)
319
+ raise # Re-raise the exception to be caught by the main try-except block
320
+
321
+ def _generate_dose_response_function(
322
+ df_with_gps: pd.DataFrame,
323
+ fitted_outcome_model: Any,
324
+ treatment: str,
325
+ gps_col_name: str,
326
+ outcome_model_spec: Dict, # To know how to construct X_pred features
327
+ t_values_to_evaluate: List[float]
328
+ ) -> List[float]:
329
+ """
330
+ Calculates the Average Dose-Response Function (ADRF).
331
+ E[Y(t)] = integral over E[Y | T=t, GPS=g] * f(g) dg
332
+ ~= (1/N) * sum_i E[Y | T=t, GPS=g_i] (using observed GPS values)
333
+ """
334
+ logger.info(f"Calculating ADRF for treatment levels: {t_values_to_evaluate}")
335
+ adrf_estimates = []
336
+
337
+ if not t_values_to_evaluate: # Handle empty list case
338
+ logger.warning("t_values_to_evaluate is empty. ADRF calculation will be skipped.")
339
+ return []
340
+
341
+ model_exog_names = fitted_outcome_model.model.exog_names
342
+
343
+ # Original GPS values from the dataframe
344
+ original_gps_values = pd.Series(df_with_gps[gps_col_name].values, index=df_with_gps.index)
345
+
346
+ for t_level in t_values_to_evaluate:
347
+ # Create a new DataFrame for prediction at this t_level
348
+ # Each row corresponds to an original observation's GPS, but with T set to t_level
349
+ X_pred_dict = {'intercept': np.ones(len(df_with_gps))}
350
+
351
+ # Reconstruct features based on outcome_model_spec and model_exog_names
352
+ # This mirrors the construction in _estimate_outcome_model
353
+ degree = outcome_model_spec.get("degree", 2)
354
+ interaction = outcome_model_spec.get("interaction", True)
355
+
356
+ if 'T' in model_exog_names: X_pred_dict['T'] = t_level
357
+ if 'GPS' in model_exog_names: X_pred_dict['GPS'] = original_gps_values
358
+
359
+ if 'T_sq' in model_exog_names: X_pred_dict['T_sq'] = t_level**2
360
+ if 'GPS_sq' in model_exog_names: X_pred_dict['GPS_sq'] = original_gps_values**2
361
+
362
+ if 'T_cub' in model_exog_names: X_pred_dict['T_cub'] = t_level**3 # Example
363
+ if 'GPS_cub' in model_exog_names: X_pred_dict['GPS_cub'] = original_gps_values**3 # Example
364
+
365
+ if 'T_x_GPS' in model_exog_names and interaction:
366
+ X_pred_dict['T_x_GPS'] = t_level * original_gps_values
367
+ if 'T_sq_x_GPS' in model_exog_names and interaction and degree >=2:
368
+ X_pred_dict['T_sq_x_GPS'] = (t_level**2) * original_gps_values
369
+ if 'T_x_GPS_sq' in model_exog_names and interaction and degree >=2:
370
+ X_pred_dict['T_x_GPS_sq'] = t_level * (original_gps_values**2)
371
+
372
+ X_pred_df = pd.DataFrame(X_pred_dict, index=df_with_gps.index)
373
+
374
+ # Ensure all required columns are present and in the correct order
375
+ # Drop any rows that might have NaNs if original_gps_values had NaNs (though they should be filtered before this)
376
+ X_pred_df_fit = X_pred_df[model_exog_names].dropna()
377
+
378
+ if X_pred_df_fit.empty:
379
+ logger.warning(f"Prediction data for t_level={t_level} is empty after NaN drop. Assigning NaN to ADRF point.")
380
+ adrf_estimates.append(np.nan)
381
+ continue
382
+
383
+ predicted_outcomes_at_t = fitted_outcome_model.predict(X_pred_df_fit)
384
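+ # Average over the empirical distribution of observed GPS values: the sample analogue of
+ # integrating E[Y | T=t, GPS] over the GPS density (see the docstring formula above).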
+ adrf_estimates.append(np.mean(predicted_outcomes_at_t))
385
+
386
+ return adrf_estimates
auto_causal/methods/generalized_propensity_score/llm_assist.py ADDED
@@ -0,0 +1,208 @@
1
+ """
2
+ LLM-assisted components for the Generalized Propensity Score (GPS) method.
3
+
4
+ These functions help in suggesting model specifications or parameters
5
+ by leveraging an LLM, providing intelligent defaults when not specified by the user.
6
+ """
7
+ from typing import Dict, List, Any, Optional
8
+ import pandas as pd
+ import numpy as np
9
+ import logging
10
+ from auto_causal.utils.llm_helpers import call_llm_with_json_output # Hypothetical import
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ def suggest_treatment_model_spec(
15
+ df: pd.DataFrame,
16
+ treatment_var: str,
17
+ covariate_vars: List[str],
18
+ query: Optional[str] = None,
19
+ llm_client: Optional[Any] = None
20
+ ) -> Dict[str, Any]:
21
+ """
22
+ Suggests a model specification for the treatment mechanism (T ~ X) in GPS.
23
+
24
+ Args:
25
+ df: The input DataFrame.
26
+ treatment_var: The name of the continuous treatment variable.
27
+ covariate_vars: A list of covariate names.
28
+ query: Optional user query for context.
29
+ llm_client: Optional LLM client for making a call.
30
+
31
+ Returns:
32
+ A dictionary representing the suggested model specification.
33
+ E.g., {"type": "linear", "formula": "T ~ X1 + X2"} or
34
+ {"type": "random_forest", "params": {...}}
35
+ """
36
+ logger.info(f"Suggesting treatment model spec for: {treatment_var}")
37
+
38
+ # Example of constructing a more detailed prompt for an LLM
39
+ prompt_parts = [
40
+ f"You are an expert econometrician. The user wants to estimate a Generalized Propensity Score (GPS) for a continuous treatment variable '{treatment_var}'.",
41
+ f"The available covariates are: {covariate_vars}.",
42
+ f"The user's research query is: '{query if query else 'Not specified'}'.",
43
+ "Based on this information and general best practices for GPS estimation:",
44
+ "1. Suggest a suitable model type for estimating the treatment (T) given covariates (X). Common choices include 'linear' (OLS), or flexible models like 'random_forest' or 'gradient_boosting' if non-linearities are suspected.",
45
+ "2. If suggesting a regression model like OLS, provide a Patsy-style formula string (e.g., 'treatment ~ cov1 + cov2 + cov1*cov2').",
46
+ "3. If suggesting a machine learning model, list key hyperparameters and reasonable starting values (e.g., n_estimators, max_depth).",
47
+ "Return your suggestion as a JSON object with the following structure:",
48
+ '''
49
+ {
50
+ "model_type": "<e.g., linear, random_forest>",
51
+ "formula": "<Patsy formula if model_type is linear/glm, else null>",
52
+ "parameters": { // if applicable for ML models
53
+ "<param1_name>": "<param1_value>",
54
+ "<param2_name>": "<param2_value>"
55
+ },
56
+ "reasoning": "<Brief justification for your suggestion>"
57
+ }
58
+ '''
59
+ ]
60
+ full_prompt = "\n".join(prompt_parts)
61
+
62
+ if llm_client:
63
+ logger.info("LLM client provided. Sending constructed prompt (actual call is hypothetical).")
64
+ logger.debug(f"LLM Prompt for treatment model spec:\n{full_prompt}")
65
+ # In a real implementation:
66
+ # response_json = call_llm_with_json_output(llm_client, full_prompt)
67
+ # if response_json and isinstance(response_json, dict):
68
+ # return response_json
69
+ # else:
70
+ # logger.warning("LLM did not return a valid JSON dict for treatment model spec.")
71
+ pass # Pass for now as it's a hypothetical call
72
+
73
+ # Default suggestion if no LLM or LLM fails
74
+ return {
75
+ "model_type": "linear",
76
+ "formula": f"{treatment_var} ~ {' + '.join(covariate_vars) if covariate_vars else '1'}",
77
+ "parameters": None,
78
+ "reasoning": "Defaulting to a linear model for T ~ X. Consider a more flexible model if non-linearities are expected.",
79
+ "comment": "This is a default suggestion."
80
+ }
81
+
82
+ def suggest_outcome_model_spec(
83
+ df: pd.DataFrame,
84
+ outcome_var: str,
85
+ treatment_var: str,
86
+ gps_col_name: str,
87
+ query: Optional[str] = None,
88
+ llm_client: Optional[Any] = None
89
+ ) -> Dict[str, Any]:
90
+ """
91
+ Suggests a model specification for the outcome mechanism (Y ~ T, GPS) in GPS.
92
+
93
+ Args:
94
+ df: The input DataFrame.
95
+ outcome_var: The name of the outcome variable.
96
+ treatment_var: The name of the continuous treatment variable.
97
+ gps_col_name: The name of the GPS column.
98
+ query: Optional user query for context.
99
+ llm_client: Optional LLM client for making a call.
100
+
101
+ Returns:
102
+ A dictionary representing the suggested model specification.
103
+ E.g., {"type": "polynomial", "degree": 2, "interaction": True,
104
+ "formula": "Y ~ T + T^2 + GPS + GPS^2 + T*GPS"}
105
+ """
106
+ logger.info(f"Suggesting outcome model spec for: {outcome_var}")
107
+
108
+ prompt_parts = [
109
+ f"You are an expert econometrician. For a Generalized Propensity Score (GPS) analysis, the user needs to model the outcome '{outcome_var}' conditional on the continuous treatment '{treatment_var}' and the estimated GPS (column name '{gps_col_name}').",
110
+ "The goal is to flexibly capture the relationship E[Y | T, GPS]. A common approach is to use a polynomial specification for T and GPS, including interaction terms.",
111
+ f"The user's research query is: '{query if query else 'Not specified'}'.",
112
+ "Suggest a specification for this outcome model. Consider:",
113
+ "1. The functional form for T (e.g., linear, quadratic, cubic).",
114
+ "2. The functional form for GPS (e.g., linear, quadratic, cubic).",
115
+ "3. Whether to include interaction terms between T and GPS (e.g., T*GPS, T^2*GPS, T*GPS^2).",
116
+ "Return your suggestion as a JSON object with the following structure:",
117
+ '''
118
+ {
119
+ "model_type": "polynomial", // Or other types like "splines"
120
+ "treatment_terms": ["T", "T_sq"], // e.g., ["T"] for linear, ["T", "T_sq"] for quadratic
121
+ "gps_terms": ["GPS", "GPS_sq"], // e.g., ["GPS"] for linear, ["GPS", "GPS_sq"] for quadratic
122
+ "interaction_terms": ["T_x_GPS", "T_sq_x_GPS", "T_x_GPS_sq"], // Interactions to include, or empty list
123
+ "reasoning": "<Brief justification for your suggestion>"
124
+ }
125
+ '''
126
+ ]
127
+ full_prompt = "\n".join(prompt_parts)
128
+
129
+ if llm_client:
130
+ logger.info("LLM client provided. Sending constructed prompt for outcome model (hypothetical call).")
131
+ logger.debug(f"LLM Prompt for outcome model spec:\n{full_prompt}")
132
+ # In a real implementation:
133
+ # response_json = call_llm_with_json_output(llm_client, full_prompt)
134
+ # if response_json and isinstance(response_json, dict):
135
+ # # Basic validation of expected keys for outcome model could go here
136
+ # return response_json
137
+ # else:
138
+ # logger.warning("LLM did not return a valid JSON dict for outcome model spec.")
139
+ pass # Pass for now
140
+
141
+ # Default suggestion
142
+ return {
143
+ "model_type": "polynomial",
144
+ "treatment_terms": ["T", "T_sq"],
145
+ "gps_terms": ["GPS", "GPS_sq"],
146
+ "interaction_terms": ["T_x_GPS"],
147
+ "reasoning": "Defaulting to a quadratic specification for T and GPS with a simple T*GPS interaction. This is a common starting point.",
148
+ "comment": "This is a default suggestion."
149
+ }
150
+
151
+ def suggest_dose_response_t_values(
152
+ df: pd.DataFrame,
153
+ treatment_var: str,
154
+ num_points: int = 20,
155
+ query: Optional[str] = None,
+ llm_client: Optional[Any] = None
156
+ ) -> List[float]:
157
+ """
158
+ Suggests a relevant range and number of points for estimating the ADRF.
159
+
160
+ Args:
161
+ df: The input DataFrame.
162
+ treatment_var: The name of the continuous treatment variable.
163
+ num_points: Desired number of points for the ADRF curve.
164
+ query: Optional user query for context, referenced in the LLM prompt.
+ llm_client: Optional LLM client for making a call.
165
+
166
+ Returns:
167
+ A list of treatment values at which to evaluate the ADRF.
168
+ """
169
+ logger.info(f"Suggesting dose response t-values for: {treatment_var}")
170
+
171
+ prompt_parts = [
172
+ f"For a Generalized Propensity Score (GPS) analysis with continuous treatment '{treatment_var}', the user needs to estimate an Average Dose-Response Function (ADRF).",
173
+ f"The observed range of '{treatment_var}' is from {df[treatment_var].min():.2f} to {df[treatment_var].max():.2f}.",
174
+ f"The user desires approximately {num_points} points for the ADRF curve.",
175
+ f"The user's research query is: '{query if query else 'Not specified'}'.",
176
+ "Suggest a list of specific treatment values (t_values) at which to evaluate the ADRF. Consider:",
177
+ "1. Covering the observed range of the treatment.",
178
+ "2. Potentially including specific points of policy interest if deducible from the query (though this is advanced).",
179
+ "3. Ensuring a reasonable distribution of points (e.g., equally spaced, or based on quantiles).",
180
+ "Return your suggestion as a JSON object with a single key 't_values' holding a list of floats:",
181
+ '''
182
+ {
183
+ "t_values": [<float>, <float>, ..., <float>],
184
+ "reasoning": "<Brief justification for the choice/distribution of these t_values>"
185
+ }
186
+ '''
187
+ ]
188
+ full_prompt = "\n".join(prompt_parts)
189
+
190
+ if llm_client:
191
+ logger.info("LLM client provided. Sending prompt for t-values (hypothetical call).")
192
+ logger.debug(f"LLM Prompt for t-values:\n{full_prompt}")
193
+ # In a real implementation:
194
+ # response_json = call_llm_with_json_output(llm_client, full_prompt)
195
+ # if response_json and isinstance(response_json, dict) and 't_values' in response_json and isinstance(response_json['t_values'], list):
196
+ # return response_json['t_values'] # Assuming it returns the list directly based on current function signature
197
+ # else:
198
+ # logger.warning("LLM did not return a valid JSON with 't_values' list for ADRF points.")
199
+ pass # Pass for now
200
+
201
+ # Default: Linearly spaced points
202
+ min_t = df[treatment_var].min()
203
+ max_t = df[treatment_var].max()
204
+ if pd.isna(min_t) or pd.isna(max_t) or min_t == max_t:
205
+ logger.warning(f"Could not determine a valid range for treatment '{treatment_var}'. Returning empty list.")
206
+ return []
207
+
208
+ return [float(t) for t in np.linspace(min_t, max_t, num_points)]
auto_causal/methods/instrumental_variable/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .estimator import estimate_effect
auto_causal/methods/instrumental_variable/diagnostics.py ADDED
@@ -0,0 +1,218 @@
1
+ # Placeholder for IV-specific diagnostic functions
2
+ import pandas as pd
3
+ import statsmodels.api as sm
4
+ from statsmodels.regression.linear_model import OLS
5
+ # from statsmodels.sandbox.regression.gmm import IV2SLSResults # Removed problematic import
6
+ from typing import Dict, Any, List, Tuple, Optional
7
+ import logging # Import logging
8
+ import numpy as np # Import numpy for np.zeros
9
+
10
+ # Configure logger
11
+ logger = logging.getLogger(__name__)
12
+
13
+ def calculate_first_stage_f_statistic(df: pd.DataFrame, treatment: str, instruments: List[str], covariates: List[str]) -> Tuple[Optional[float], Optional[float]]:
14
+ """
15
+ Calculates the F-statistic for instrument relevance in the first stage regression.
16
+
17
+ Regresses treatment ~ instruments + covariates.
18
+ Tests the joint significance of the instrument coefficients.
19
+
20
+ Args:
21
+ df: Input DataFrame.
22
+ treatment: Name of the treatment variable.
23
+ instruments: List of instrument variable names.
24
+ covariates: List of covariate names.
25
+
26
+ Returns:
27
+ A tuple containing (F-statistic, p-value). Returns (None, None) on error.
28
+ """
29
+ logger.info("Diagnostics: Calculating First-Stage F-statistic...")
30
+ try:
31
+ df_copy = df.copy()
32
+ df_copy['intercept'] = 1
33
+ exog_vars = ['intercept'] + covariates
34
+ all_first_stage_exog = list(dict.fromkeys(exog_vars + instruments)) # Ensure unique columns
35
+
36
+ endog = df_copy[treatment]
37
+ exog = df_copy[all_first_stage_exog]
38
+
39
+ # Check for perfect multicollinearity before fitting
40
+ if exog.shape[1] > 1:
41
+ corr_matrix = exog.corr()
42
+ # Check if correlation matrix calculation failed (e.g., constant columns) or high correlation
43
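+ # The diagonal alone contributes exog.shape[1] entries of exactly 1, so any additional
+ # near-unit entries indicate an (almost) perfectly collinear pair of regressors.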
+ if corr_matrix.isnull().values.any() or (corr_matrix.abs() > 0.9999).sum().sum() > exog.shape[1]: # Check off-diagonal elements
44
+ logger.warning("High multicollinearity or constant column detected in first stage exogenous variables.")
45
+ # Note: statsmodels OLS might handle perfect collinearity by dropping columns, but F-test might be unreliable.
46
+
47
+ first_stage_model = OLS(endog, exog).fit()
48
+
49
+ # Construct the restriction matrix (R) to test H0: instrument coeffs = 0
50
+ num_instruments = len(instruments)
51
+ if num_instruments == 0:
52
+ logger.warning("No instruments provided for F-statistic calculation.")
53
+ return None, None
54
+ num_exog_total = len(all_first_stage_exog)
55
+
56
+ # Ensure instruments are actually in the fitted model's exog names (in case statsmodels dropped some)
57
+ fitted_exog_names = first_stage_model.model.exog_names
58
+ valid_instruments = [inst for inst in instruments if inst in fitted_exog_names]
59
+ if not valid_instruments:
60
+ logger.error("None of the provided instruments were included in the first-stage regression model (possibly due to collinearity).")
61
+ return None, None
62
+ if len(valid_instruments) < len(instruments):
63
+ logger.warning(f"Instruments dropped by OLS: {set(instruments) - set(valid_instruments)}")
64
+
65
+ instrument_indices = [fitted_exog_names.index(inst) for inst in valid_instruments]
66
+
67
+ # Need to adjust R matrix size based on fitted model's exog
68
+ R = np.zeros((len(valid_instruments), len(fitted_exog_names)))
69
+ for i, idx in enumerate(instrument_indices):
70
+ R[i, idx] = 1
71
+
72
+ # Perform F-test
73
+ f_test_result = first_stage_model.f_test(R)
74
+
75
+ f_statistic = float(f_test_result.fvalue)
76
+ p_value = float(f_test_result.pvalue)
77
+
78
+ logger.info(f" F-statistic: {f_statistic:.4f}, p-value: {p_value:.4f}")
79
+ return f_statistic, p_value
80
+
81
+ except Exception as e:
82
+ logger.error(f"Error calculating first-stage F-statistic: {e}", exc_info=True)
83
+ return None, None
84
+
85
+ def run_overidentification_test(sm_results: Optional[Any], df: pd.DataFrame, treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> Tuple[Optional[float], Optional[float], Optional[str]]:
86
+ """
87
+ Runs an overidentification test (Sargan-Hansen) if applicable.
88
+
89
+ This test is only valid if the number of instruments exceeds the number
90
+ of endogenous regressors (typically 1, the treatment variable).
91
+
92
+ Requires results from a statsmodels IV estimation.
93
+
94
+ Args:
95
+ sm_results: The fitted results object from statsmodels IV2SLS.fit().
96
+ df: Input DataFrame.
97
+ treatment: Name of the treatment variable.
98
+ outcome: Name of the outcome variable.
99
+ instruments: List of instrument variable names.
100
+ covariates: List of covariate names.
101
+
102
+ Returns:
103
+ Tuple: (test_statistic, p_value, status_message) or (None, None, error_message)
104
+ """
105
+ logger.info("Diagnostics: Running Overidentification Test...")
106
+ num_instruments = len(instruments)
107
+ num_endog = 1 # Assuming only one treatment variable is endogenous
108
+
109
+ if num_instruments <= num_endog:
110
+ logger.info(" Over-ID test not applicable (model is exactly identified or underidentified).")
111
+ return None, None, "Test not applicable (Need more instruments than endogenous regressors)"
112
+
113
+ if sm_results is None or not hasattr(sm_results, 'resid'):
114
+ logger.warning(" Over-ID test requires valid statsmodels results object with residuals.")
115
+ return None, None, "Statsmodels results object not available or invalid for test."
116
+
117
+ try:
118
+ # Statsmodels IV2SLSResults does not seem to have a direct method for this test (as of common versions).
119
+ # We need to calculate it manually using residuals and instruments.
120
+ # Formula: N * R^2 from regressing residuals (u_hat) on all exogenous variables (instruments + covariates).
121
+ # Degrees of freedom = num_instruments - num_endogenous_vars
122
+
123
+ residuals = sm_results.resid
124
+ df_copy = df.copy()
125
+ df_copy['intercept'] = 1
126
+ exog_vars = ['intercept'] + covariates
127
+ all_exog_instruments = list(dict.fromkeys(exog_vars + instruments))
128
+
129
+ # Ensure columns exist in the dataframe before selecting
130
+ missing_cols = [col for col in all_exog_instruments if col not in df_copy.columns]
131
+ if missing_cols:
132
+ raise ValueError(f"Missing columns required for Over-ID test: {missing_cols}")
133
+
134
+ exog_for_test = df_copy[all_exog_instruments]
135
+
136
+ # Check shapes match after potential NA handling in main estimator
137
+ if len(residuals) != exog_for_test.shape[0]:
138
+ # Attempt to align based on index if lengths differ (might happen if NAs were dropped)
139
+ logger.warning(f"Residual length ({len(residuals)}) differs from exog_for_test rows ({exog_for_test.shape[0]}). Trying to align indices.")
140
+ common_index = residuals.index.intersection(exog_for_test.index)
141
+ if len(common_index) == 0:
142
+ raise ValueError("Cannot align residuals and exogenous variables for Over-ID test after NA handling.")
143
+ residuals = residuals.loc[common_index]
144
+ exog_for_test = exog_for_test.loc[common_index]
145
+ logger.warning(f"Aligned to {len(common_index)} common observations.")
146
+
147
+
148
+ # Regress residuals on all exogenous instruments
149
+ aux_model = OLS(residuals, exog_for_test).fit()
150
+ r_squared = aux_model.rsquared
151
+ n_obs = len(residuals) # Use length of residuals after potential alignment
152
+
153
+ test_statistic = n_obs * r_squared
154
+
155
+ # Calculate p-value from Chi-squared distribution
156
+ from scipy.stats import chi2
157
+ degrees_of_freedom = num_instruments - num_endog
158
+ if degrees_of_freedom < 0:
159
+ # This shouldn't happen if the initial check passed, but as a safeguard
160
+ raise ValueError("Degrees of freedom for Sargan test are negative.")
161
+ elif degrees_of_freedom == 0:
162
+ # R-squared should be 0 if exactly identified, but handle edge case
163
+ p_value = 1.0 if np.isclose(test_statistic, 0) else 0.0
164
+ else:
165
+ p_value = chi2.sf(test_statistic, degrees_of_freedom)
166
+
167
+ logger.info(f" Sargan Test Statistic: {test_statistic:.4f}, p-value: {p_value:.4f}, df: {degrees_of_freedom}")
168
+ return test_statistic, p_value, "Test successful"
169
+
170
+ except Exception as e:
171
+ logger.error(f"Error running overidentification test: {e}", exc_info=True)
172
+ return None, None, f"Error during test: {e}"
173
+
174
+ def run_iv_diagnostics(df: pd.DataFrame, treatment: str, outcome: str, instruments: List[str], covariates: List[str], sm_results: Optional[Any] = None, dw_results: Optional[Any] = None) -> Dict[str, Any]:
175
+ """
176
+ Runs standard IV diagnostic checks.
177
+
178
+ Args:
179
+ df: Input DataFrame.
180
+ treatment: Name of the treatment variable.
181
+ outcome: Name of the outcome variable.
182
+ instruments: List of instrument variable names.
183
+ covariates: List of covariate names.
184
+ sm_results: Optional fitted results object from statsmodels IV2SLS.fit().
185
+ dw_results: Optional results object from DoWhy (structure may vary).
186
+
187
+ Returns:
188
+ Dictionary containing diagnostic results.
189
+ """
190
+ diagnostics = {}
191
+
192
+ # 1. Instrument Relevance / Weak Instrument Test (First-Stage F-statistic)
193
+ f_stat, f_p_val = calculate_first_stage_f_statistic(df, treatment, instruments, covariates)
194
+ diagnostics['first_stage_f_statistic'] = f_stat
195
+ diagnostics['first_stage_p_value'] = f_p_val
196
+ diagnostics['is_instrument_weak'] = (f_stat < 10) if f_stat is not None else None # Common rule of thumb
197
+ if f_stat is None:
198
+ diagnostics['weak_instrument_test_status'] = "Error during calculation"
199
+ elif diagnostics['is_instrument_weak']:
200
+ diagnostics['weak_instrument_test_status'] = "Warning: Instrument(s) may be weak (F < 10)"
201
+ else:
202
+ diagnostics['weak_instrument_test_status'] = "Instrument(s) appear sufficiently strong (F >= 10)"
203
+
204
+
205
+ # 2. Overidentification Test (e.g., Sargan-Hansen)
206
+ overid_stat, overid_p_val, overid_status = run_overidentification_test(sm_results, df, treatment, outcome, instruments, covariates)
207
+ diagnostics['overid_test_statistic'] = overid_stat
208
+ diagnostics['overid_test_p_value'] = overid_p_val
209
+ diagnostics['overid_test_status'] = overid_status
210
+ diagnostics['overid_test_applicable'] = bool(overid_status) and "not applicable" not in overid_status.lower()
211
+
212
+ # 3. Exogeneity/Exclusion Restriction (Conceptual Check)
213
+ diagnostics['exclusion_restriction_assumption'] = "Assumed based on graph/input; cannot be statistically tested directly. Qualitative LLM check recommended."
214
+
215
+ # Potential future additions:
216
+ # - Endogeneity tests (e.g., Hausman test - requires comparing OLS and IV estimates)
217
+
218
+ return diagnostics
auto_causal/methods/instrumental_variable/estimator.py ADDED
@@ -0,0 +1,370 @@
1
+ import pandas as pd
2
+ import statsmodels.api as sm
3
+ from statsmodels.sandbox.regression.gmm import IV2SLS
4
+ from dowhy import CausalModel # Primary path
5
+ from typing import Dict, Any, List, Union, Optional
6
+ import logging
7
+ from langchain.chat_models.base import BaseChatModel
8
+
9
+ from .diagnostics import run_iv_diagnostics
10
+ from .llm_assist import identify_instrument_variable, validate_instrument_assumptions_qualitative, interpret_iv_results
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ def build_iv_graph_gml(treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> str:
15
+ """
16
+ Constructs a GML string representing the causal graph for IV.
17
+
18
+ Assumptions:
19
+ - Instruments cause Treatment
20
+ - Covariates cause Treatment and Outcome
21
+ - Treatment causes Outcome
22
+ - Instruments do NOT directly cause Outcome (Exclusion)
23
+ - Instruments are NOT caused by Covariates (can be relaxed if needed)
24
+ - Unobserved Confounder (U) affects Treatment and Outcome
25
+
26
+ Args:
27
+ treatment: Name of the treatment variable.
28
+ outcome: Name of the outcome variable.
29
+ instruments: List of instrument variable names.
30
+ covariates: List of covariate names.
31
+
32
+ Returns:
33
+ A GML graph string.
34
+ """
35
+ nodes = []
36
+ edges = []
37
+
38
+ # Define nodes - ensure no duplicates if a variable is both instrument and covariate (SHOULD NOT HAPPEN)
39
+ # Use a set to ensure unique variable names
40
+ all_vars_set = set([treatment, outcome] + instruments + covariates + ['U'])
41
+ all_vars = list(all_vars_set)
42
+
43
+ for var in all_vars:
44
+ nodes.append(f'node [ id "{var}" label "{var}" ]')
45
+
46
+ # Define edges
47
+ # Instruments -> Treatment
48
+ for inst in instruments:
49
+ edges.append(f'edge [ source "{inst}" target "{treatment}" ]')
50
+
51
+ # Covariates -> Treatment
52
+ for cov in covariates:
53
+ # Ensure we don't add self-loops or duplicate edges if cov == treatment (shouldn't happen)
54
+ if cov != treatment:
55
+ edges.append(f'edge [ source "{cov}" target "{treatment}" ]')
56
+
57
+ # Covariates -> Outcome
58
+ for cov in covariates:
59
+ if cov != outcome:
60
+ edges.append(f'edge [ source "{cov}" target "{outcome}" ]')
61
+
62
+ # Treatment -> Outcome
63
+ edges.append(f'edge [ source "{treatment}" target "{outcome}" ]')
64
+
65
+ # Unobserved Confounder -> Treatment and Outcome
66
+ edges.append(f'edge [ source "U" target "{treatment}" ]')
67
+ edges.append(f'edge [ source "U" target "{outcome}" ]')
68
+
69
+ # Core IV Assumption: Instruments are NOT caused by U (implicitly handled by not adding edge)
70
+ # Core IV Assumption: Instruments do NOT directly cause Outcome (handled by not adding edge)
71
+
72
+ # Format nodes and edges with indentation before inserting into f-string
73
+ formatted_nodes = '\n '.join(nodes)
74
+ formatted_edges = '\n '.join(edges)
75
+
76
+ gml_string = f"""
77
+ graph [
78
+ directed 1
79
+ {formatted_nodes}
80
+ {formatted_edges}
81
+ ]
82
+ """
83
+ # Convert print to logger
84
+ logger.debug("\n--- Generated GML Graph ---")
85
+ logger.debug(gml_string)
86
+ logger.debug("-------------------------\n")
87
+ return gml_string
88
+
89
+ def format_iv_results(estimate: Optional[float], raw_results: Dict, diagnostics: Dict, treatment: str, outcome: str, instrument: List[str], method_used: str, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
90
+ """
91
+ Formats the results from IV estimation into a standardized dictionary.
92
+
93
+ Args:
94
+ estimate: The point estimate of the causal effect.
95
+ raw_results: Dictionary containing raw outputs from DoWhy/statsmodels.
96
+ diagnostics: Dictionary containing diagnostic results.
97
+ treatment: Name of the treatment variable.
98
+ outcome: Name of the outcome variable.
99
+ instrument: List of instrument variable names.
100
+ method_used: 'dowhy' or 'statsmodels'.
101
+ llm: Optional LLM instance for interpretation.
102
+
103
+ Returns:
104
+ Standardized results dictionary.
105
+ """
106
+ formatted = {
107
+ "effect_estimate": estimate,
108
+ "treatment_variable": treatment,
109
+ "outcome_variable": outcome,
110
+ "instrument_variables": instrument,
111
+ "method_used": method_used,
112
+ "diagnostics": diagnostics,
113
+ "raw_results": {k: str(v) for k, v in raw_results.items() if "object" not in k}, # Avoid serializing large objects
114
+ "confidence_interval": None,
115
+ "standard_error": None,
116
+ "p_value": None,
117
+ "interpretation": "Placeholder"
118
+ }
119
+
120
+ # Extract details from statsmodels results if available
121
+ sm_results = raw_results.get('statsmodels_results_object')
122
+ if method_used == 'statsmodels' and sm_results:
123
+ try:
124
+ # Use .bse for standard error in statsmodels results
125
+ formatted["standard_error"] = float(sm_results.bse[treatment])
126
+ formatted["p_value"] = float(sm_results.pvalues[treatment])
127
+ conf_int = sm_results.conf_int().loc[treatment].tolist()
128
+ formatted["confidence_interval"] = [float(ci) for ci in conf_int]
129
+ except AttributeError as e:
130
+ logger.warning(f"Could not extract all details from statsmodels results object (likely missing attribute): {e}")
131
+ except Exception as e:
132
+ logger.warning(f"Error extracting details from statsmodels results: {e}")
133
+
134
+ # Extract details from DoWhy results if available
135
+ # Note: DoWhy's CausalEstimate object structure needs inspection
136
+ dw_results = raw_results.get('dowhy_results_object')
137
+ if method_used == 'dowhy' and dw_results:
138
+ try:
139
+ # Attempt common attributes, may need adjustment based on DoWhy version/output
140
+ if hasattr(dw_results, 'stderr'):
141
+ formatted["standard_error"] = float(dw_results.stderr)
142
+ if hasattr(dw_results, 'p_value'):
143
+ formatted["p_value"] = float(dw_results.p_value)
144
+ if hasattr(dw_results, 'conf_intervals'):
145
+ # Assuming it's stored similarly to statsmodels, might need adjustment
146
+ ci = dw_results.conf_intervals().loc[treatment].tolist() # Fictional attribute/method - check DoWhy docs!
147
+ formatted["confidence_interval"] = [float(c) for c in ci]
148
+ elif hasattr(dw_results, 'get_confidence_intervals'):
149
+ ci = dw_results.get_confidence_intervals() # Check DoWhy docs for format
150
+ # Check format of ci before converting
151
+ if isinstance(ci, (list, tuple)) and len(ci) == 2:
152
+ formatted["confidence_interval"] = [float(c) for c in ci] # Adapt parsing
153
+ else:
154
+ logger.warning(f"Could not parse confidence intervals from DoWhy object: {ci}")
155
+
156
+ except Exception as e:
157
+ logger.warning(f"Could not extract all details from DoWhy results: {e}. Structure might be different.", exc_info=True)
158
+ # Avoid printing dir in production code, use logger.debug if needed for dev
159
+ # logger.debug(f"DoWhy result object dir(): {dir(dw_results)}")
160
+
161
+ # Generate LLM interpretation - pass llm object
162
+ if estimate is not None:
163
+ formatted["interpretation"] = interpret_iv_results(formatted, diagnostics, llm=llm)
164
+ else:
165
+ formatted["interpretation"] = "Estimation failed, cannot interpret results."
166
+
167
+
168
+ return formatted
169
+
170
+ def estimate_effect(
171
+ df: pd.DataFrame,
172
+ treatment: str,
173
+ outcome: str,
174
+ covariates: List[str],
175
+ query: Optional[str] = None,
176
+ dataset_description: Optional[str] = None,
177
+ llm: Optional[BaseChatModel] = None,
178
+ **kwargs
179
+ ) -> Dict[str, Any]:
180
+
181
+ instrument = kwargs.get('instrument_variable')
182
+ if not instrument:
183
+ return {"error": "Instrument variable ('instrument_variable') not found in kwargs.", "method_used": "none", "diagnostics": {}}
184
+
185
+ instrument_list = [instrument] if isinstance(instrument, str) else instrument
186
+ valid_instruments = [inst for inst in instrument_list if isinstance(inst, str)]
187
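+ # Drop any instrument that was also passed as a covariate: keeping it in the outcome equation
+ # would contradict the exclusion restriction assumed by the IV design.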
+ clean_covariates = [cov for cov in covariates if cov not in valid_instruments]
188
+
189
+ logger.info(f"\n--- Starting Instrumental Variable Estimation ---")
190
+ logger.info(f"Treatment: {treatment}, Outcome: {outcome}, Instrument(s): {valid_instruments}, Original Covariates: {covariates}, Cleaned Covariates: {clean_covariates}")
191
+ results = {}
192
+ method_used = "none"
193
+ sm_results_obj = None
194
+ dw_results_obj = None
195
+ identified_estimand = None # Initialize
196
+ model = None # Initialize
197
+ refutation_results = {} # Initialize
198
+
199
+ # --- Input Validation ---
200
+ required_cols = [treatment, outcome] + valid_instruments + clean_covariates
201
+ missing_cols = [col for col in required_cols if col not in df.columns]
202
+ if missing_cols:
203
+ return {"error": f"Missing required columns in DataFrame: {missing_cols}", "method_used": method_used, "diagnostics": {}}
204
+ if not valid_instruments:
205
+ return {"error": "Instrument variable(s) must be provided and valid.", "method_used": method_used, "diagnostics": {}}
206
+
207
+ # --- LLM Pre-Checks ---
208
+ if query and llm:
209
+ qualitative_check = validate_instrument_assumptions_qualitative(treatment, outcome, valid_instruments, clean_covariates, query, llm=llm)
210
+ results['llm_assumption_check'] = qualitative_check
211
+ logger.info(f"LLM Qualitative Assumption Check: {qualitative_check}")
212
+
213
+ # --- Build Graph and Instantiate CausalModel (Do this before estimation attempts) ---
214
+ # This allows using identify_effect and refute_estimate even if DoWhy estimation fails
215
+ try:
216
+ graph = build_iv_graph_gml(treatment, outcome, valid_instruments, clean_covariates)
217
+ if not graph:
218
+ raise ValueError("Failed to build GML graph for DoWhy.")
219
+
220
+ model = CausalModel(data=df, treatment=treatment, outcome=outcome, graph=graph)
221
+
222
+ # Identify Effect (essential for refutation later)
223
+ identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
224
+ logger.debug("\nDoWhy Identified Estimand:")
225
+ logger.debug(identified_estimand)
226
+ if not identified_estimand:
227
+ raise ValueError("DoWhy could not identify a valid estimand.")
228
+
229
+ except Exception as model_init_e:
230
+ logger.error(f"Failed to initialize CausalModel or identify effect: {model_init_e}", exc_info=True)
231
+ # Cannot proceed without model/estimand for DoWhy or refutation
232
+ results['error'] = f"Failed to initialize CausalModel: {model_init_e}"
233
+ # Attempt statsmodels anyway? Or return error? Let's try statsmodels.
234
+ pass # Allow falling through to statsmodels if desired
235
+
236
+ # --- Primary Path: DoWhy Estimation ---
237
+ if model and identified_estimand and not kwargs.get('force_statsmodels', False):
238
+ logger.info("\nAttempting estimation with DoWhy...")
239
+ try:
240
+ dw_results_obj = model.estimate_effect(
241
+ identified_estimand,
242
+ method_name="iv.instrumental_variable",
243
+ method_params={'iv_instrument_name': valid_instruments}
244
+ )
245
+ logger.debug("\nDoWhy Estimation Result:")
246
+ logger.debug(dw_results_obj)
247
+ results['dowhy_estimate'] = dw_results_obj.value
248
+ results['dowhy_results_object'] = dw_results_obj
249
+ method_used = 'dowhy'
250
+ logger.info("DoWhy estimation successful.")
251
+ except Exception as e:
252
+ logger.error(f"DoWhy IV estimation failed: {e}", exc_info=True)
253
+ results['dowhy_error'] = str(e)
254
+ if not kwargs.get('allow_fallback', True):
255
+ logger.warning("Fallback to statsmodels disabled. Estimation failed.")
256
+ method_used = "dowhy_failed"
257
+ # Still run diagnostics and format output
258
+ else:
259
+ logger.info("Proceeding to statsmodels fallback.")
260
+ elif not model or not identified_estimand:
261
+ logger.warning("Skipping DoWhy estimation due to CausalModel initialization/identification failure.")
262
+ # Ensure we proceed to statsmodels if fallback is allowed
263
+ if not kwargs.get('allow_fallback', True):
264
+ logger.error("Cannot estimate effect: CausalModel failed and fallback disabled.")
265
+ method_used = "dowhy_failed"
266
+ else:
267
+ logger.info("Proceeding to statsmodels fallback.")
268
+
269
+ # --- Fallback Path: statsmodels IV2SLS ---
270
+ if method_used not in ['dowhy', 'dowhy_failed']:
271
+ logger.info("\nAttempting estimation with statsmodels IV2SLS...")
272
+ try:
273
+ df_copy = df.copy().dropna(subset=required_cols)
274
+ if df_copy.empty:
275
+ raise ValueError("DataFrame becomes empty after dropping NAs in required columns.")
276
+ df_copy['intercept'] = 1
277
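+ # Layout expected by statsmodels' sandbox IV2SLS: `exog` holds the included exogenous regressors
+ # plus the endogenous treatment, while `instrument` holds the included exogenous regressors plus
+ # the excluded instruments.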
+ exog_regressors = ['intercept'] + clean_covariates
278
+ endog_var = treatment
279
+ all_instruments_for_sm = list(dict.fromkeys(exog_regressors + valid_instruments))
280
+ endog_data = df_copy[outcome]
281
+ exog_data_sm_cols = list(dict.fromkeys(exog_regressors + [endog_var]))
282
+ exog_data_sm = df_copy[exog_data_sm_cols]
283
+ instrument_data_sm = df_copy[all_instruments_for_sm]
284
+ num_endog = 1
285
+ num_external_iv = len(valid_instruments)
286
+ if num_endog > num_external_iv:
287
+ raise ValueError(f"Model underidentified: More endogenous regressors ({num_endog}) than unique external instruments ({num_external_iv}).")
288
+ iv_model = IV2SLS(endog=endog_data, exog=exog_data_sm, instrument=instrument_data_sm)
289
+ sm_results_obj = iv_model.fit()
290
+ logger.info("\nStatsmodels Estimation Summary:")
291
+ logger.info(f" Estimate for {treatment}: {sm_results_obj.params[treatment]}")
292
+ logger.info(f" Std Error: {sm_results_obj.bse[treatment]}")
293
+ logger.info(f" P-value: {sm_results_obj.pvalues[treatment]}")
294
+ results['statsmodels_estimate'] = sm_results_obj.params[treatment]
295
+ results['statsmodels_results_object'] = sm_results_obj
296
+ method_used = 'statsmodels'
297
+ logger.info("Statsmodels estimation successful.")
298
+ except Exception as sm_e:
299
+ logger.error(f"Statsmodels IV estimation also failed: {sm_e}", exc_info=True)
300
+ results['statsmodels_error'] = str(sm_e)
301
+ method_used = 'statsmodels_failed' if method_used == "none" else "dowhy_failed_sm_failed"
302
+
303
+ # --- Diagnostics ---
304
+ logger.info("\nRunning diagnostics...")
305
+ diagnostics = run_iv_diagnostics(df, treatment, outcome, valid_instruments, clean_covariates, sm_results_obj, dw_results_obj)
306
+ results['diagnostics'] = diagnostics
307
+
308
+ # --- Refutation Step ---
309
+ final_estimate_value = results.get('dowhy_estimate') if method_used == 'dowhy' else results.get('statsmodels_estimate')
310
+
311
+ # Only run permute refuter if estimate is valid AND came from DoWhy
312
+ if method_used == 'dowhy' and dw_results_obj and final_estimate_value is not None:
313
+ logger.info("\nRunning refutation test (Placebo Treatment - Permute - requires DoWhy estimate object)...")
314
+ try:
315
+ # Pass the actual DoWhy estimate object
316
+ refuter_result = model.refute_estimate(
317
+ identified_estimand,
318
+ dw_results_obj, # Pass the original DoWhy result object
319
+ method_name="placebo_treatment_refuter",
320
+ placebo_type="permute" # Necessary for IV according to docs/examples
321
+ )
322
+ logger.info("Refutation test completed.")
323
+ logger.debug(f"Refuter Result:\n{refuter_result}")
324
+ # Store relevant info from refuter_result (check its structure)
325
+ refutation_results = {
326
+ "refuter": "placebo_treatment_refuter",
327
+ "new_effect": getattr(refuter_result, 'new_effect', 'N/A'),
328
+ "p_value": getattr(refuter_result, 'refutation_result', {}).get('p_value', 'N/A') if hasattr(refuter_result, 'refutation_result') else 'N/A',
329
+ # Passed if p-value > 0.05 (or not statistically significant)
330
+ "passed": getattr(refuter_result, 'refutation_result', {}).get('is_statistically_significant', None) == False if hasattr(refuter_result, 'refutation_result') else None
331
+ }
332
+ except Exception as refute_e:
333
+ logger.error(f"Refutation test failed: {refute_e}", exc_info=True)
334
+ refutation_results = {"error": f"Refutation failed: {refute_e}"}
335
+
336
+ elif final_estimate_value is not None and method_used == 'statsmodels':
337
+ logger.warning("Skipping placebo permutation refuter: Estimate was generated by statsmodels, not DoWhy's IV estimator.")
338
+ refutation_results = {"status": "skipped_wrong_estimator_for_permute"}
339
+
340
+ elif final_estimate_value is None:
341
+ logger.warning("Skipping refutation test because estimation failed.")
342
+ refutation_results = {"status": "skipped_due_to_failed_estimation"}
343
+
344
+ else: # Model or estimand failed earlier, or unknown method_used
345
+ logger.warning(f"Skipping refutation test due to earlier failure (method_used: {method_used}).")
346
+ refutation_results = {"status": "skipped_due_to_model_failure_or_unknown"}
347
+
348
+ results['refutation_results'] = refutation_results # Add to main results
349
+
350
+ # --- Formatting Results ---
351
+ if final_estimate_value is None and method_used not in ['dowhy', 'statsmodels']:
352
+ logger.error("ERROR: Both estimation methods failed.")
353
+ # Ensure error key exists if not set earlier
354
+ if 'error' not in results:
355
+ results['error'] = "Both DoWhy and statsmodels IV estimation failed."
356
+
357
+ logger.info("\n--- Formatting Final Results ---")
358
+ formatted_results = format_iv_results(
359
+ final_estimate_value, # Pass the numeric value
360
+ results, # Pass the dict containing estimate objects and refutation results
361
+ diagnostics,
362
+ treatment,
363
+ outcome,
364
+ valid_instruments,
365
+ method_used,
366
+ llm=llm
367
+ )
368
+
369
+ logger.info("--- Instrumental Variable Estimation Complete ---\n")
370
+ return formatted_results
auto_causal/methods/instrumental_variable/llm_assist.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ LLM assistance functions for Instrumental Variable (IV) analysis.
3
+
4
+ This module provides functions for LLM-based assistance in instrumental variable analysis,
5
+ including identifying potential instruments, validating IV assumptions, and interpreting results.
6
+ """
7
+
8
+ from typing import List, Dict, Any, Optional
9
+ import logging
10
+
11
+ # Imported for type hinting
12
+ from langchain.chat_models.base import BaseChatModel
13
+
14
+ # Import shared LLM helpers
15
+ from auto_causal.utils.llm_helpers import call_llm_with_json_output
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ def identify_instrument_variable(
20
+ df_cols: List[str],
21
+ query: str,
22
+ llm: Optional[BaseChatModel] = None
23
+ ) -> List[str]:
24
+ """
25
+ Use LLM to identify potential instrumental variables from available columns.
26
+
27
+ Args:
28
+ df_cols: List of column names from the dataset
29
+ query: User's causal query text
30
+ llm: Optional LLM model instance
31
+
32
+ Returns:
33
+ List of column names identified as potential instruments
34
+ """
35
+ if llm is None:
36
+ logger.warning("No LLM provided for instrument identification")
37
+ return []
38
+
39
+ prompt = f"""
40
+ You are assisting with an instrumental variable analysis.
41
+
42
+ Available columns in the dataset: {df_cols}
43
+ User query: {query}
44
+
45
+ Identify potential instrumental variable(s) from the available columns based on the query.
46
+ The treatment and outcome should NOT be included as instruments.
47
+
48
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
49
+ {{
50
+ "potential_instruments": ["column_name1", "column_name2", ...]
51
+ }}
52
+ """
53
+
54
+ response = call_llm_with_json_output(llm, prompt)
55
+
56
+ if response and "potential_instruments" in response and isinstance(response["potential_instruments"], list):
57
+ # Basic validation: ensure items are strings (column names)
58
+ valid_instruments = [item for item in response["potential_instruments"] if isinstance(item, str)]
59
+ if len(valid_instruments) != len(response["potential_instruments"]):
60
+ logger.warning("LLM returned non-string items in potential_instruments list.")
61
+ return valid_instruments
62
+
63
+ logger.warning(f"Failed to get valid instrument recommendations from LLM. Response: {response}")
64
+ return []
65
+
66
+ def validate_instrument_assumptions_qualitative(
67
+ treatment: str,
68
+ outcome: str,
69
+ instrument: List[str],
70
+ covariates: List[str],
71
+ query: str,
72
+ llm: Optional[BaseChatModel] = None
73
+ ) -> Dict[str, str]:
74
+ """
75
+ Use LLM to provide qualitative assessment of IV assumptions.
76
+
77
+ Args:
78
+ treatment: Treatment variable name
79
+ outcome: Outcome variable name
80
+ instrument: List of instrumental variable names
81
+ covariates: List of covariate variable names
82
+ query: User's causal query text
83
+ llm: Optional LLM model instance
84
+
85
+ Returns:
86
+ Dictionary with qualitative assessments of exclusion and exogeneity assumptions
87
+ """
88
+ default_fail = {
89
+ "exclusion_assessment": "LLM Check Failed",
90
+ "exogeneity_assessment": "LLM Check Failed"
91
+ }
92
+
93
+ if llm is None:
94
+ return {
95
+ "exclusion_assessment": "LLM Not Provided",
96
+ "exogeneity_assessment": "LLM Not Provided"
97
+ }
98
+
99
+ prompt = f"""
100
+ You are assisting with assessing the validity of instrumental variable assumptions.
101
+
102
+ Treatment variable: {treatment}
103
+ Outcome variable: {outcome}
104
+ Instrumental variable(s): {instrument}
105
+ Covariates: {covariates}
106
+ User query: {query}
107
+
108
+ Assess the core Instrumental Variable (IV) assumptions based *only* on the provided variable names and query context:
109
+ 1. Exclusion restriction: Plausibility that the instrument(s) affect the outcome ONLY through the treatment.
110
+ 2. Exogeneity (also called Independence): Plausibility that the instrument(s) are not correlated with unobserved confounders that also affect the outcome.
111
+
112
+ Provide a brief, qualitative assessment (e.g., 'Plausible', 'Unlikely', 'Requires Domain Knowledge', 'Potentially Violated').
113
+
114
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
115
+ {{
116
+ "exclusion_assessment": "<brief assessment of exclusion restriction>",
117
+ "exogeneity_assessment": "<brief assessment of exogeneity assumption>"
118
+ }}
119
+ """
120
+
121
+ response = call_llm_with_json_output(llm, prompt)
122
+
123
+ if response and isinstance(response, dict) and \
124
+ "exclusion_assessment" in response and isinstance(response["exclusion_assessment"], str) and \
125
+ "exogeneity_assessment" in response and isinstance(response["exogeneity_assessment"], str):
126
+ return response
127
+
128
+ logger.warning(f"Failed to get valid assumption assessment from LLM. Response: {response}")
129
+ return default_fail
130
+
131
+ def interpret_iv_results(
132
+ results: Dict[str, Any],
133
+ diagnostics: Dict[str, Any],
134
+ llm: Optional[BaseChatModel] = None
135
+ ) -> str:
136
+ """
137
+ Use LLM to interpret IV results in natural language.
138
+
139
+ Args:
140
+ results: Dictionary of estimation results (e.g., effect_estimate, p_value, confidence_interval)
141
+ diagnostics: Dictionary of diagnostic test results (e.g., first_stage_f_statistic, overid_test)
142
+ llm: Optional LLM model instance
143
+
144
+ Returns:
145
+ String containing natural language interpretation of results
146
+ """
147
+ if llm is None:
148
+ return "LLM was not available to provide interpretation. Please review the numeric results manually."
149
+
150
+ # Construct a concise summary of inputs for the prompt
151
+ results_summary = {}
152
+
153
+ effect = results.get('effect_estimate')
154
+ if effect is not None:
155
+ try:
156
+ results_summary['Effect Estimate'] = f"{float(effect):.3f}"
157
+ except (ValueError, TypeError):
158
+ results_summary['Effect Estimate'] = 'N/A (Invalid Format)'
159
+ else:
160
+ results_summary['Effect Estimate'] = 'N/A'
161
+
162
+ p_value = results.get('p_value')
163
+ if p_value is not None:
164
+ try:
165
+ results_summary['P-value'] = f"{float(p_value):.3f}"
166
+ except (ValueError, TypeError):
167
+ results_summary['P-value'] = 'N/A (Invalid Format)'
168
+ else:
169
+ results_summary['P-value'] = 'N/A'
170
+
171
+ ci = results.get('confidence_interval')
172
+ if ci is not None and isinstance(ci, (list, tuple)) and len(ci) == 2:
173
+ try:
174
+ results_summary['Confidence Interval'] = f"[{float(ci[0]):.3f}, {float(ci[1]):.3f}]"
175
+ except (ValueError, TypeError):
176
+ results_summary['Confidence Interval'] = 'N/A (Invalid Format)'
177
+ else:
178
+ # Handle cases where CI is None or not a 2-element list/tuple
179
+ results_summary['Confidence Interval'] = str(ci) if ci is not None else 'N/A'
180
+
181
+ if 'treatment_variable' in results:
182
+ results_summary['Treatment'] = results['treatment_variable']
183
+ if 'outcome_variable' in results:
184
+ results_summary['Outcome'] = results['outcome_variable']
185
+
186
+ diagnostics_summary = {}
187
+ f_stat = diagnostics.get('first_stage_f_statistic')
188
+ if f_stat is not None:
189
+ try:
190
+ diagnostics_summary['First-Stage F-statistic'] = f"{float(f_stat):.2f}"
191
+ except (ValueError, TypeError):
192
+ diagnostics_summary['First-Stage F-statistic'] = 'N/A (Invalid Format)'
193
+ else:
194
+ diagnostics_summary['First-Stage F-statistic'] = 'N/A'
195
+
196
+ if 'weak_instrument_test_status' in diagnostics:
197
+ diagnostics_summary['Weak Instrument Test'] = diagnostics['weak_instrument_test_status']
198
+
199
+ overid_p = diagnostics.get('overid_test_p_value')
200
+ if overid_p is not None:
201
+ try:
202
+ diagnostics_summary['Overidentification Test P-value'] = f"{float(overid_p):.3f}"
203
+ diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
204
+ except (ValueError, TypeError):
205
+ diagnostics_summary['Overidentification Test P-value'] = 'N/A (Invalid Format)'
206
+ diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
207
+ else:
208
+ # Explicitly state if not applicable or not available
209
+ if diagnostics.get('overid_test_applicable') == False:
210
+ diagnostics_summary['Overidentification Test'] = 'Not Applicable'
211
+ else:
212
+ diagnostics_summary['Overidentification Test P-value'] = 'N/A'
213
+ diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
214
+
215
+ prompt = f"""
216
+ You are assisting with interpreting instrumental variable (IV) analysis results.
217
+
218
+ Estimation results summary: {results_summary}
219
+ Diagnostic test results summary: {diagnostics_summary}
220
+
221
+ Explain these Instrumental Variable (IV) results in clear, concise language (2-4 sentences).
222
+ Focus on:
223
+ 1. The estimated causal effect (magnitude, direction, statistical significance based on p-value < 0.05).
224
+ 2. The strength of the instrument(s) (based on F-statistic, typically > 10 indicates strength).
225
+ 3. Any implications from other diagnostic tests (e.g., overidentification test suggesting instrument validity issues if p < 0.05).
226
+
227
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
228
+ {{
229
+ "interpretation": "<your concise interpretation text>"
230
+ }}
231
+ """
232
+
233
+ response = call_llm_with_json_output(llm, prompt)
234
+
235
+ if response and isinstance(response, dict) and \
236
+ "interpretation" in response and isinstance(response["interpretation"], str):
237
+ return response["interpretation"]
238
+
239
+ logger.warning(f"Failed to get valid interpretation from LLM. Response: {response}")
240
+ return "LLM interpretation could not be generated. Please review the numeric results manually."
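Note: the interpretation prompt above leans on the first-stage F-statistic (rule of thumb: > 10 indicates a strong instrument). As a rough, self-contained sketch of how that statistic can be computed with statsmodels (the helper name and column names are illustrative assumptions, not part of this module):

```python
# Hypothetical sketch: first-stage F-statistic for instrument strength.
# Regress the treatment on the instrument(s) plus exogenous covariates and
# jointly test that the instrument coefficients are zero.
import numpy as np
import pandas as pd
import statsmodels.api as sm

def first_stage_f_statistic(df: pd.DataFrame, treatment: str,
                            instruments: list, exog: list) -> float:
    X = sm.add_constant(df[instruments + exog])
    first_stage = sm.OLS(df[treatment], X).fit()
    hypotheses = " = 0, ".join(instruments) + " = 0"  # e.g. "z1 = 0, z2 = 0"
    res = first_stage.f_test(hypotheses)
    return float(np.asarray(res.fvalue).reshape(-1)[0])
```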
auto_causal/methods/linear_regression/__init__.py ADDED
File without changes
auto_causal/methods/linear_regression/diagnostics.py ADDED
@@ -0,0 +1,76 @@
1
+ """
2
+ Diagnostic checks for Linear Regression models.
3
+ """
4
+
5
+ from typing import Dict, Any
6
+ import statsmodels.api as sm
7
+ from statsmodels.stats.diagnostic import het_breuschpagan, normal_ad
8
+ from statsmodels.stats.stattools import jarque_bera
9
+ from statsmodels.regression.linear_model import RegressionResultsWrapper
10
+ import pandas as pd
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ def run_lr_diagnostics(results: RegressionResultsWrapper, X: pd.DataFrame) -> Dict[str, Any]:
16
+ """
17
+ Runs diagnostic checks on a fitted OLS model.
18
+
19
+ Args:
20
+ results: A fitted statsmodels OLS results object.
21
+ X: The design matrix (including constant) used for the regression.
22
+ Needed for heteroskedasticity tests.
23
+
24
+ Returns:
25
+ Dictionary containing diagnostic metrics.
26
+ """
27
+
28
+ diagnostics = {}
29
+
30
+ try:
31
+ diagnostics['r_squared'] = results.rsquared
32
+ diagnostics['adj_r_squared'] = results.rsquared_adj
33
+ diagnostics['f_statistic'] = results.fvalue
34
+ diagnostics['f_p_value'] = results.f_pvalue
35
+ diagnostics['n_observations'] = int(results.nobs)
36
+ diagnostics['degrees_of_freedom_resid'] = int(results.df_resid)
37
+
38
+ # --- Normality of Residuals (Jarque-Bera) ---
39
+ try:
40
+ jb_value, jb_p_value, skew, kurtosis = jarque_bera(results.resid)
41
+ diagnostics['residuals_normality_jb_stat'] = jb_value
42
+ diagnostics['residuals_normality_jb_p_value'] = jb_p_value
43
+ diagnostics['residuals_skewness'] = skew
44
+ diagnostics['residuals_kurtosis'] = kurtosis
45
+ diagnostics['residuals_normality_status'] = "Normal" if jb_p_value > 0.05 else "Non-Normal"
46
+ except Exception as e:
47
+ logger.warning(f"Could not run Jarque-Bera test: {e}")
48
+ diagnostics['residuals_normality_status'] = "Test Failed"
49
+
50
+ # --- Homoscedasticity (Breusch-Pagan) ---
51
+ # Requires the design matrix X used in the model fitting
52
+ try:
53
+ lm_stat, lm_p_value, f_stat, f_p_value = het_breuschpagan(results.resid, X)
54
+ diagnostics['homoscedasticity_bp_lm_stat'] = lm_stat
55
+ diagnostics['homoscedasticity_bp_lm_p_value'] = lm_p_value
56
+ diagnostics['homoscedasticity_bp_f_stat'] = f_stat
57
+ diagnostics['homoscedasticity_bp_f_p_value'] = f_p_value
58
+ diagnostics['homoscedasticity_status'] = "Homoscedastic" if lm_p_value > 0.05 else "Heteroscedastic"
59
+ except Exception as e:
60
+ logger.warning(f"Could not run Breusch-Pagan test: {e}")
61
+ diagnostics['homoscedasticity_status'] = "Test Failed"
62
+
63
+ # --- Linearity (Basic check - often requires visual inspection) ---
64
+ # No standard quantitative test implemented here. Usually assessed via residual plots.
65
+ diagnostics['linearity_check'] = "Requires visual inspection (e.g., residual vs fitted plot)"
66
+
67
+ # --- Multicollinearity (Placeholder - requires VIF calculation) ---
68
+ # VIF requires iterating through predictors, more involved
69
+ diagnostics['multicollinearity_check'] = "Not Implemented (Requires VIF)"
70
+
71
+ return {"status": "Success", "details": diagnostics}
72
+
73
+ except Exception as e:
74
+ logger.error(f"Error running LR diagnostics: {e}")
75
+ return {"status": "Failed", "error": str(e), "details": diagnostics} # Return partial results if possible
76
+
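The multicollinearity check above is deliberately left unimplemented. A minimal sketch of what a VIF-based check could look like is shown below; the helper name `compute_vifs` and the threshold of 10 are illustrative assumptions, not part of this module.

```python
# Hypothetical VIF sketch for the "Not Implemented (Requires VIF)" placeholder above.
import pandas as pd
from statsmodels.stats.outliers_influence import variance_inflation_factor

def compute_vifs(X: pd.DataFrame) -> dict:
    """Variance inflation factor for each column of the design matrix (constant included)."""
    return {col: float(variance_inflation_factor(X.values, i))
            for i, col in enumerate(X.columns)}

# Possible wiring inside run_lr_diagnostics (assumed, not implemented above):
# vifs = compute_vifs(X)
# diagnostics['multicollinearity_vif'] = vifs
# diagnostics['multicollinearity_status'] = (
#     "High" if any(v > 10 for c, v in vifs.items() if c != 'const') else "OK"
# )
```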
auto_causal/methods/linear_regression/estimator.py ADDED
@@ -0,0 +1,355 @@
1
+ """
2
+ Linear Regression Estimator for Causal Inference.
3
+
4
+ Uses Ordinary Least Squares (OLS) to estimate the treatment effect, potentially
5
+ adjusting for covariates.
6
+ """
7
+ import pandas as pd
8
+ import statsmodels.api as sm
9
+ import statsmodels.formula.api as smf
10
+ from typing import Dict, Any, List, Optional, Union
11
+ import logging
12
+ from langchain.chat_models.base import BaseChatModel
13
+ import re
14
+ import json
15
+ from pydantic import BaseModel, ValidationError
16
+ from langchain_core.messages import HumanMessage
17
+ from langchain_core.exceptions import OutputParserException
18
+
19
+
20
+ from auto_causal.models import LLMIdentifiedRelevantParams
21
+ from auto_causal.prompts.regression_prompts import STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE
22
+ from auto_causal.config import get_llm_client
23
+
24
+ # Placeholder for potential future LLM assistance integration
25
+ # from .llm_assist import interpret_lr_results, suggest_lr_covariates
26
+ # Placeholder for potential future diagnostics integration
27
+ # from .diagnostics import run_lr_diagnostics
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
32
+ """Helper to call LLM with structured output and handle errors."""
33
+ try:
34
+ messages = [HumanMessage(content=prompt)]
35
+ structured_llm = llm.with_structured_output(pydantic_model)
36
+ parsed_result = structured_llm.invoke(messages)
37
+ return parsed_result
38
+ except (OutputParserException, ValidationError) as e:
39
+ logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {e}")
40
+ except Exception as e:
41
+ logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {e}", exc_info=True)
42
+ return None
43
+
44
+ # Define module-level helper function
45
+ def _clean_variable_name_for_patsy_local(name: str) -> str:
46
+ if not isinstance(name, str):
47
+ name = str(name)
48
+ name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
49
+ if not re.match(r'^[a-zA-Z_]', name):
50
+ name = 'var_' + name
51
+ return name
52
+
53
+
54
+ def estimate_effect(
55
+ df: pd.DataFrame,
56
+ treatment: str,
57
+ outcome: str,
58
+ covariates: Optional[List[str]] = None,
59
+ query_str: Optional[str] = None, # For potential LLM use
60
+ llm: Optional[BaseChatModel] = None, # For potential LLM use
61
+ **kwargs # To capture any other potential arguments
62
+ ) -> Dict[str, Any]:
63
+ """
64
+ Estimates the causal effect using Linear Regression (OLS).
65
+
66
+ Args:
67
+ df: Input DataFrame.
68
+ treatment: Name of the treatment variable column.
69
+ outcome: Name of the outcome variable column.
70
+ covariates: Optional list of covariate names.
71
+ query_str: Optional user query for context (e.g., for LLM).
72
+ llm: Optional Language Model instance.
73
+ **kwargs: Additional keyword arguments.
74
+
75
+ Returns:
76
+ Dictionary containing estimation results:
77
+ - 'effect_estimate': The estimated coefficient for the treatment variable.
78
+ - 'p_value': The p-value associated with the treatment coefficient.
79
+ - 'confidence_interval': The 95% confidence interval for the effect.
80
+ - 'standard_error': The standard error of the treatment coefficient.
81
+ - 'formula': The regression formula used.
82
+ - 'model_summary': Summary object from statsmodels.
83
+ - 'diagnostics': Placeholder for diagnostic results.
84
+ - 'interpretation': Placeholder for LLM interpretation.
85
+ """
86
+ if covariates is None:
87
+ covariates = []
88
+
89
+ # Retrieve additional args from kwargs
90
+ interaction_term_suggested = kwargs.get('interaction_term_suggested', False)
91
+ # interaction_variable_candidate is the *original* name from query_interpreter
92
+ interaction_variable_candidate_orig_name = kwargs.get('interaction_variable_candidate')
93
+ treatment_reference_level = kwargs.get('treatment_reference_level')
94
+ column_mappings = kwargs.get('column_mappings', {})
95
+
96
+ required_cols = [treatment, outcome] + covariates
97
+ # If interaction variable is suggested, ensure it (or its processed form) is in df for analysis
98
+ # This check is complex here as interaction_variable_candidate_orig_name needs mapping to processed column(s)
99
+ # We'll rely on df_analysis.dropna() and formula construction to handle missing interaction var columns later
100
+
101
+ missing_cols = [col for col in required_cols if col not in df.columns]
102
+ if missing_cols:
103
+ raise ValueError(f"Missing required columns: {missing_cols}")
104
+
105
+ # Prepare data for statsmodels (add constant, handle potential NaNs)
106
+ df_analysis = df[required_cols].dropna()
107
+ if df_analysis.empty:
108
+ raise ValueError("No data remaining after dropping NaNs for required columns.")
109
+
110
+ X = df_analysis[[treatment] + covariates]
111
+ X = sm.add_constant(X) # Add intercept
112
+ y = df_analysis[outcome]
113
+
114
+ # --- Formula Construction ---
115
+ outcome_col_name = outcome # Name in processed df
116
+ treatment_col_name = treatment # Name in processed df
117
+ processed_covariate_col_names = covariates # List of names in processed df
118
+
119
+ rhs_terms = []
120
+
121
+ # 1. Treatment Term
122
+ treatment_patsy_term = treatment_col_name # Default
123
+ original_treatment_info = column_mappings.get(treatment_col_name, {}) # Info from preprocess_data
124
+
125
+ is_binary_encoded = original_treatment_info.get('transformed_as') == 'label_encoded_binary'
126
+ is_still_categorical_in_df = df_analysis[treatment_col_name].dtype.name in ['object', 'category']
127
+
128
+ if is_still_categorical_in_df and not is_binary_encoded: # Covers multi-level and binary categoricals not yet numeric
129
+ if treatment_reference_level:
130
+ treatment_patsy_term = f"C({treatment_col_name}, Treatment(reference='{treatment_reference_level}'))"
131
+ logger.info(f"Treating '{treatment_col_name}' as multi-level categorical with reference '{treatment_reference_level}'.")
132
+ else:
133
+ # Default C() wrapping for categoricals if no specific reference is given.
134
+ # This applies to multi-level or binary categoricals that were not label_encoded to 0/1 by preprocess_data.
135
+ treatment_patsy_term = f"C({treatment_col_name})"
136
+ logger.info(f"Treating '{treatment_col_name}' as categorical (Patsy will pick reference).")
137
+ elif is_binary_encoded: # Was binary and explicitly label encoded to 0/1 by preprocess_data
138
+ # Even if it's now numeric 0/1, C() ensures Patsy treats it categorically for parameter naming consistency.
139
+ treatment_patsy_term = f"C({treatment_col_name})"
140
+ logger.info(f"Treating label-encoded binary '{treatment_col_name}' as categorical for Patsy.")
141
+ else: # Assumed to be already numeric (continuous or discrete numeric not needing C() for main effect)
142
+ # treatment_patsy_term remains treatment_col_name (default)
143
+ logger.info(f"Treating '{treatment_col_name}' as numeric for Patsy formula.")
144
+
145
+ rhs_terms.append(treatment_patsy_term)
146
+
147
+ # 2. Covariate Terms
148
+ for cov_col_name in processed_covariate_col_names:
149
+ if cov_col_name == treatment_col_name: # Should not happen if covariates list is clean
150
+ continue
151
+ # Assume covariates are already numeric/dummy. If one was object/category in df_analysis (unlikely), C() it.
152
+ if df_analysis[cov_col_name].dtype.name in ['object', 'category']:
153
+ rhs_terms.append(f"C({cov_col_name})")
154
+ else:
155
+ rhs_terms.append(cov_col_name)
156
+
157
+ # 3. Interaction Term (Simplified: interaction_variable_candidate_orig_name must map to a single column in df_analysis)
158
+ actual_interaction_term_added_to_formula = None
159
+ if interaction_term_suggested and interaction_variable_candidate_orig_name:
160
+ processed_interaction_col_name = None
161
+ interaction_var_info = column_mappings.get(interaction_variable_candidate_orig_name, {})
162
+
163
+ if interaction_var_info.get('transformed_as') == 'one_hot_encoded':
164
+ logger.warning(f"Interaction with one-hot encoded variable '{interaction_variable_candidate_orig_name}' is complex. Currently skipping this interaction for Linear Regression.")
165
+ elif interaction_var_info.get('new_column_name') and interaction_var_info['new_column_name'] in df_analysis.columns:
166
+ processed_interaction_col_name = interaction_var_info['new_column_name']
167
+ elif interaction_variable_candidate_orig_name in df_analysis.columns: # Was not in mappings, or mapping didn't change name (e.g. numeric)
168
+ processed_interaction_col_name = interaction_variable_candidate_orig_name
169
+
170
+ if processed_interaction_col_name:
171
+ interaction_var_patsy_term = processed_interaction_col_name
172
+ # If the processed interaction column itself is categorical (e.g. label encoded binary)
173
+ if df_analysis[processed_interaction_col_name].dtype.name in ['object', 'category', 'bool'] or \
174
+ interaction_var_info.get('original_dtype') in ['bool', 'category']:
175
+ interaction_var_patsy_term = f"C({processed_interaction_col_name})"
176
+
177
+ actual_interaction_term_added_to_formula = f"{treatment_patsy_term}:{interaction_var_patsy_term}"
178
+ rhs_terms.append(actual_interaction_term_added_to_formula)
179
+ logger.info(f"Adding interaction term to formula: {actual_interaction_term_added_to_formula}")
180
+ elif interaction_variable_candidate_orig_name: # Log if it was suggested but couldn't be mapped/found
181
+ logger.warning(f"Could not resolve interaction variable candidate '{interaction_variable_candidate_orig_name}' to a single usable column in processed data. Skipping interaction term.")
182
+
183
+ # Build the formula string for reporting and fitting
184
+ if not rhs_terms: # Should always have at least treatment
185
+ formula = f"{outcome_col_name} ~ 1"
186
+ else:
187
+ formula = f"{outcome_col_name} ~ {' + '.join(rhs_terms)}"
188
+ logger.info(f"Using formula for Linear Regression: {formula}")
189
+
190
+ try:
191
+ model = smf.ols(formula=formula, data=df_analysis)
192
+ results = model.fit()
193
+ logger.info("OLS model fitted successfully.")
194
+ logger.debug(results.summary())  # Debug level keeps default logging less verbose
195
+
196
+ # --- Result Extraction: LLM attempt first, then Regex fallback ---
197
+ effect_estimates_by_level = {}
198
+ all_params_extracted = False # Default to False
199
+ llm_extraction_successful = False
200
+
201
+ # Attempt LLM-based extraction if llm client and query are available
202
+ llm = get_llm_client()
203
+ if llm and query_str:
204
+ logger.info(f"Attempting LLM-based result extraction (informed by query: '{query_str[:50]}...').")
205
+ try:
206
+ param_names_list = results.params.index.tolist()
207
+ param_estimates_list = results.params.tolist()
208
+ param_p_values_list = results.pvalues.tolist()
209
+ param_std_errs_list = results.bse.tolist()
210
+
211
+ conf_int_df = results.conf_int(alpha=0.05)
212
+ param_conf_ints_low_list = []
213
+ param_conf_ints_high_list = []
214
+
215
+ if not conf_int_df.empty and len(conf_int_df.columns) == 2:
216
+ aligned_conf_int_df = conf_int_df.reindex(results.params.index)
217
+ param_conf_ints_low_list = aligned_conf_int_df.iloc[:, 0].fillna(float('nan')).tolist()
218
+ param_conf_ints_high_list = aligned_conf_int_df.iloc[:, 1].fillna(float('nan')).tolist()
219
+ else:
220
+ nan_list_ci = [float('nan')] * len(param_names_list)
221
+ param_conf_ints_low_list = nan_list_ci
222
+ param_conf_ints_high_list = nan_list_ci
223
+
224
+ # Placeholder for the new prompt template tailored for this extraction task
225
+ # MOVED TO causalscientist/auto_causal/prompts/regression_prompts.py
226
+
227
+ is_multilevel_case_for_prompt = bool(treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded)
228
+ reference_level_for_prompt_str = str(treatment_reference_level) if is_multilevel_case_for_prompt else "N/A"
229
+
230
+ indexed_param_names_for_prompt = [f"{idx}: '{name}'" for idx, name in enumerate(param_names_list)]
231
+ indexed_param_names_str_for_prompt = "\n".join(indexed_param_names_for_prompt)
232
+
233
+ prompt_text_for_identification = STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE.format(
234
+ user_query=query_str,
235
+ treatment_patsy_term=treatment_patsy_term,
236
+ treatment_col_name=treatment_col_name,
237
+ is_multilevel_case=is_multilevel_case_for_prompt,
238
+ reference_level_for_prompt=reference_level_for_prompt_str,
239
+ indexed_param_names_str=indexed_param_names_str_for_prompt, # Pass the indexed list as a string
240
+ llm_response_schema_json=json.dumps(LLMIdentifiedRelevantParams.model_json_schema(), indent=2)
241
+ )
242
+
243
+ llm_identification_response = _call_llm_for_var(llm, prompt_text_for_identification, LLMIdentifiedRelevantParams)
244
+
245
+ if llm_identification_response and llm_identification_response.identified_params:
246
+ logger.info("LLM identified relevant parameters. Proceeding with programmatic extraction.")
247
+ for item in llm_identification_response.identified_params:
248
+ param_idx = item.param_index
249
+ # Validate index against actual list length
250
+ if 0 <= param_idx < len(results.params.index):
251
+ actual_param_name = results.params.index[param_idx]
252
+ # Sanity check if LLM returned name matches actual name at index
253
+ if item.param_name != actual_param_name:
254
+ logger.warning(f"LLM returned param_name '{item.param_name}' but name at index {param_idx} is '{actual_param_name}'. Using actual name from results.")
255
+
256
+ current_effect_stats = {
257
+ 'estimate': results.params.iloc[param_idx],
258
+ 'p_value': results.pvalues.iloc[param_idx],
259
+ 'conf_int': results.conf_int(alpha=0.05).iloc[param_idx].tolist(),
260
+ 'std_err': results.bse.iloc[param_idx]
261
+ }
262
+
263
+ key_for_effect_dict = 'treatment_effect' # Default for single/binary
264
+ if is_multilevel_case_for_prompt: # If it was a multi-level case
265
+ match = re.search(r'\[T\.([^]]+)]', actual_param_name) # Use actual_param_name
266
+ if match:
267
+ level = match.group(1)
268
+ if level != reference_level_for_prompt_str: # Ensure it's not the ref level itself
269
+ key_for_effect_dict = level
270
+ else:
271
+ logger.warning(f"Could not parse level from LLM-identified param: {actual_param_name}. Storing with raw name.")
272
+ key_for_effect_dict = actual_param_name # Fallback key
273
+
274
+ effect_estimates_by_level[key_for_effect_dict] = current_effect_stats
275
+ else:
276
+ logger.warning(f"LLM returned an invalid parameter index: {param_idx}. Skipping.")
277
+
278
+ if effect_estimates_by_level: # If any effects were successfully processed
279
+ all_params_extracted = llm_identification_response.all_parameters_successfully_identified
280
+ llm_extraction_successful = True
281
+ logger.info(f"Successfully processed LLM-identified parameters. all_parameters_successfully_identified={all_params_extracted}")
282
+ logger.debug(f"effect_estimates_by_level: {effect_estimates_by_level}")
283
+ else:
284
+ logger.warning("LLM identified parameters, but none could be processed into effect_estimates_by_level. Falling back to regex.")
285
+ else:
286
+ logger.warning("LLM parameter identification did not yield usable parameters. Falling back to regex.")
287
+
288
+ except Exception as e_llm:
289
+ logger.warning(f"LLM-based result extraction failed: {e_llm}. Falling back to regex.", exc_info=True)
290
+
291
+
292
+ # --- End of Existing Regex Logic Block ---
293
+
294
+ # Primary effect_estimate for simple reporting (e.g. first level or the only one)
295
+ # For multi-level, this is ambiguous. For now, let's report None or the first one.
296
+ # The full details are in effect_estimates_by_level.
297
+ main_effect_estimate = None
298
+ main_p_value = None
299
+ main_conf_int = [None, None] # Default for single or if no effects
300
+ main_std_err = None
301
+
302
+ if effect_estimates_by_level:
303
+ if 'treatment_effect' in effect_estimates_by_level: # Single effect case
304
+ single_effect_data = effect_estimates_by_level['treatment_effect']
305
+ main_effect_estimate = single_effect_data['estimate']
306
+ main_p_value = single_effect_data['p_value']
307
+ main_conf_int = single_effect_data['conf_int']
308
+ main_std_err = single_effect_data['std_err']
309
+ else: # Multi-level case
310
+ logger.info("Multi-level treatment effects extracted. Populating dicts for main estimate fields.")
311
+ effect_estimate_dict = {}
312
+ p_value_dict = {}
313
+ conf_int_dict = {}
314
+ std_err_dict = {}
315
+ for level, stats in effect_estimates_by_level.items():
316
+ effect_estimate_dict[level] = stats.get('estimate')
317
+ p_value_dict[level] = stats.get('p_value')
318
+ conf_int_dict[level] = stats.get('conf_int') # This is already a list [low, high]
319
+ std_err_dict[level] = stats.get('std_err')
320
+
321
+ main_effect_estimate = effect_estimate_dict
322
+ main_p_value = p_value_dict
323
+ main_conf_int = conf_int_dict
324
+ main_std_err = std_err_dict
325
+
326
+ interpretation_details = {}
327
+ if actual_interaction_term_added_to_formula and actual_interaction_term_added_to_formula in results.params.index:
328
+ interpretation_details['interaction_term_coefficient'] = results.params[actual_interaction_term_added_to_formula]
329
+ interpretation_details['interaction_term_p_value'] = results.pvalues[actual_interaction_term_added_to_formula]
330
+ logger.info(f"Interaction term '{actual_interaction_term_added_to_formula}' coeff: {interpretation_details['interaction_term_coefficient']}")
331
+
332
+ diag_results = {}
333
+ interpretation = "Interpretation not available."
334
+
335
+ output_dict = {
336
+ 'effect_estimate': main_effect_estimate,
337
+ 'p_value': main_p_value,
338
+ 'confidence_interval': main_conf_int,
339
+ 'standard_error': main_std_err,
340
+ 'estimated_effects_by_level': effect_estimates_by_level if (treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded and effect_estimates_by_level) else None,
341
+ 'reference_level_used': treatment_reference_level if (treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded) else None,
342
+ 'formula': formula,
343
+ 'model_summary_text': results.summary().as_text(), # Store as text for easier serialization
344
+ 'diagnostics': diag_results,
345
+ 'interpretation_details': interpretation_details, # Added interaction details
346
+ 'interpretation': interpretation,
347
+ 'method_used': 'Linear Regression (OLS)'
348
+ }
349
+ if not all_params_extracted:
350
+ output_dict['warnings'] = ["Could not reliably extract all requested parameters from model results. Please check model_summary_text."]
351
+ return output_dict
352
+
353
+ except Exception as e:
354
+ logger.error(f"Linear Regression failed: {e}")
355
+ raise # Re-raise the exception after logging
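For orientation, a minimal call to `estimate_effect` might look like the sketch below. The data and column names are invented, and the result-extraction step above depends on an LLM client being configured via `get_llm_client`; without one, the point estimates may come back empty and the `warnings` field populated.

```python
# Illustrative usage of estimate_effect (synthetic data; assumes an LLM client is configured).
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 500
df = pd.DataFrame({"treated": rng.integers(0, 2, n), "age": rng.normal(40.0, 10.0, n)})
df["income"] = 2.0 * df["treated"] + 0.5 * df["age"] + rng.normal(0.0, 1.0, n)

result = estimate_effect(df, treatment="treated", outcome="income",
                         covariates=["age"], query_str="Effect of treatment on income?")
print(result["formula"])                                   # e.g. "income ~ treated + age"
print(result.get("effect_estimate"), result.get("warnings"))
```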
auto_causal/methods/linear_regression/llm_assist.py ADDED
@@ -0,0 +1,146 @@
1
+ """
2
+ LLM assistance functions for Linear Regression analysis.
3
+ """
4
+
5
+ from typing import List, Dict, Any, Optional
6
+ import logging
7
+
8
+ # Imported for type hinting
9
+ from langchain.chat_models.base import BaseChatModel
10
+ from statsmodels.regression.linear_model import RegressionResultsWrapper
11
+
12
+ # Import shared LLM helpers
13
+ from auto_causal.utils.llm_helpers import call_llm_with_json_output
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ def suggest_lr_covariates(
18
+ df_cols: List[str],
19
+ treatment: str,
20
+ outcome: str,
21
+ query: str,
22
+ llm: Optional[BaseChatModel] = None
23
+ ) -> List[str]:
24
+ """
25
+ (Placeholder) Use LLM to suggest relevant covariates for linear regression.
26
+
27
+ Args:
28
+ df_cols: List of available column names.
29
+ treatment: Treatment variable name.
30
+ outcome: Outcome variable name.
31
+ query: User's causal query text.
32
+ llm: Optional LLM model instance.
33
+
34
+ Returns:
35
+ List of suggested covariate names.
36
+ """
37
+ logger.info("LLM covariate suggestion for LR is not implemented yet.")
38
+ if llm:
39
+ # Placeholder: Call LLM here in future
40
+ pass
41
+ return []
42
+
43
+ def interpret_lr_results(
44
+ results: RegressionResultsWrapper,
45
+ diagnostics: Dict[str, Any],
46
+ treatment_var: str, # Need treatment variable name to extract coefficient
47
+ llm: Optional[BaseChatModel] = None
48
+ ) -> str:
49
+ """
50
+ Use LLM to interpret Linear Regression results.
51
+
52
+ Args:
53
+ results: Fitted statsmodels OLS results object.
54
+ diagnostics: Dictionary of diagnostic test results.
55
+ treatment_var: Name of the treatment variable.
56
+ llm: Optional LLM model instance.
57
+
58
+ Returns:
59
+ String containing natural language interpretation.
60
+ """
61
+ default_interpretation = "LLM interpretation not available for Linear Regression."
62
+ if llm is None:
63
+ logger.info("LLM not provided for LR interpretation.")
64
+ return default_interpretation
65
+
66
+ try:
67
+ # --- Prepare summary for LLM ---
68
+ results_summary = {}
69
+ treatment_val = results.params.get(treatment_var)
70
+ pval_val = results.pvalues.get(treatment_var)
71
+
72
+ if treatment_val is not None:
73
+ results_summary['Treatment Effect Estimate'] = f"{treatment_val:.3f}"
74
+ else:
75
+ logger.warning(f"Treatment variable '{treatment_var}' not found in regression parameters.")
76
+ results_summary['Treatment Effect Estimate'] = "Not Found"
77
+
78
+ if pval_val is not None:
79
+ results_summary['Treatment P-value'] = f"{pval_val:.3f}"
80
+ else:
81
+ logger.warning(f"P-value for treatment variable '{treatment_var}' not found in regression results.")
82
+ results_summary['Treatment P-value'] = "Not Found"
83
+
84
+ try:
85
+ conf_int = results.conf_int().loc[treatment_var]
86
+ results_summary['Treatment 95% CI'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
87
+ except KeyError:
88
+ logger.warning(f"Confidence interval for treatment variable '{treatment_var}' not found.")
89
+ results_summary['Treatment 95% CI'] = "Not Found"
90
+ except Exception as ci_e:
91
+ logger.warning(f"Could not extract confidence interval for '{treatment_var}': {ci_e}")
92
+ results_summary['Treatment 95% CI'] = "Error"
93
+
94
+ results_summary['R-squared'] = f"{results.rsquared:.3f}"
95
+ results_summary['Adj. R-squared'] = f"{results.rsquared_adj:.3f}"
96
+
97
+ diag_summary = {}
98
+ if diagnostics.get("status") == "Success":
99
+ diag_details = diagnostics.get("details", {})
100
+ # Format p-values only if they are numbers
101
+ jb_p = diag_details.get('residuals_normality_jb_p_value')
102
+ bp_p = diag_details.get('homoscedasticity_bp_lm_p_value')
103
+ diag_summary['Residuals Normality (Jarque-Bera P-value)'] = f"{jb_p:.3f}" if isinstance(jb_p, (int, float)) else str(jb_p)
104
+ diag_summary['Homoscedasticity (Breusch-Pagan P-value)'] = f"{bp_p:.3f}" if isinstance(bp_p, (int, float)) else str(bp_p)
105
+ diag_summary['Homoscedasticity Status'] = diag_details.get('homoscedasticity_status', 'N/A')
106
+ diag_summary['Residuals Normality Status'] = diag_details.get('residuals_normality_status', 'N/A')
107
+ else:
108
+ diag_summary['Status'] = diagnostics.get("status", "Unknown")
109
+ if "error" in diagnostics:
110
+ diag_summary['Error'] = diagnostics["error"]
111
+
112
+ # --- Construct Prompt ---
113
+ prompt = f"""
114
+ You are assisting with interpreting Linear Regression (OLS) results for causal inference.
115
+
116
+ Model Results Summary:
117
+ {results_summary}
118
+
119
+ Model Diagnostics Summary:
120
+ {diag_summary}
121
+
122
+ Explain these results in 2-4 concise sentences. Focus on:
123
+ 1. The estimated causal effect of the treatment variable '{treatment_var}' (magnitude, direction, statistical significance based on p-value < 0.05).
124
+ 2. Overall model fit (using R-squared as a rough guide).
125
+ 3. Key diagnostic findings (specifically, mention if residuals are non-normal or if heteroscedasticity is detected, as these violate OLS assumptions and can affect inference).
126
+
127
+ Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
128
+ {{
129
+ "interpretation": "<your concise interpretation text>"
130
+ }}
131
+ """
132
+
133
+ # --- Call LLM ---
134
+ response = call_llm_with_json_output(llm, prompt)
135
+
136
+ # --- Process Response ---
137
+ if response and isinstance(response, dict) and \
138
+ "interpretation" in response and isinstance(response["interpretation"], str):
139
+ return response["interpretation"]
140
+ else:
141
+ logger.warning(f"Failed to get valid interpretation from LLM. Response: {response}")
142
+ return default_interpretation
143
+
144
+ except Exception as e:
145
+ logger.error(f"Error during LLM interpretation for LR: {e}")
146
+ return f"Error generating interpretation: {e}"
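A compact sketch of how this interpreter could be wired up after an OLS fit and `run_lr_diagnostics`; the data are synthetic and `get_llm_client` is assumed to return a configured chat model.

```python
# Illustrative wiring of interpret_lr_results (synthetic data; LLM client assumed configured).
import numpy as np
import pandas as pd
import statsmodels.api as sm
from auto_causal.methods.linear_regression.diagnostics import run_lr_diagnostics
from auto_causal.config import get_llm_client

rng = np.random.default_rng(1)
df = pd.DataFrame({"treated": rng.integers(0, 2, 200), "age": rng.normal(40.0, 10.0, 200)})
df["income"] = 1.5 * df["treated"] + 0.3 * df["age"] + rng.normal(0.0, 1.0, 200)

X = sm.add_constant(df[["treated", "age"]])
ols_results = sm.OLS(df["income"], X).fit()
diagnostics = run_lr_diagnostics(ols_results, X)

print(interpret_lr_results(ols_results, diagnostics, treatment_var="treated", llm=get_llm_client()))
```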
auto_causal/methods/propensity_score/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .base import estimate_propensity_scores
2
+ from .matching import estimate_effect as estimate_matching_effect
3
+ from .weighting import estimate_effect as estimate_weighting_effect
4
+ from .diagnostics import assess_balance, plot_overlap, plot_balance
5
+
6
+ __all__ = [
7
+ "estimate_propensity_scores",
8
+ "estimate_matching_effect",
9
+ "estimate_weighting_effect",
10
+ "assess_balance",
11
+ "plot_overlap",
12
+ "plot_balance"
13
+ ]
auto_causal/methods/propensity_score/base.py ADDED
@@ -0,0 +1,80 @@
1
+ # Base functionality for Propensity Score methods
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.linear_model import LogisticRegression
5
+ from sklearn.preprocessing import StandardScaler
6
+ from typing import List, Optional, Dict, Any
7
+
8
+ # Placeholder for LLM interaction to select model type
9
+ def select_propensity_model(df: pd.DataFrame, treatment: str, covariates: List[str],
10
+ query: Optional[str] = None) -> str:
11
+ '''Selects the appropriate propensity score model type (e.g., logistic, GBM).
12
+
13
+ Placeholder: Currently defaults to Logistic Regression.
14
+ '''
15
+ # TODO: Implement LLM call or heuristic to select model based on data characteristics
16
+ return "logistic"
17
+
18
+ def estimate_propensity_scores(df: pd.DataFrame, treatment: str,
19
+ covariates: List[str], model_type: str = 'logistic',
20
+ **kwargs) -> np.ndarray:
21
+ '''Estimate propensity scores using a specified model.
22
+
23
+ Args:
24
+ df: DataFrame containing the data
25
+ treatment: Name of the treatment variable
26
+ covariates: List of covariate variable names
27
+ model_type: Type of model to use ('logistic' supported for now)
28
+ **kwargs: Additional arguments for the model
29
+
30
+ Returns:
31
+ Array of propensity scores
32
+ '''
33
+
34
+ X = df[covariates]
35
+ y = df[treatment]
36
+
37
+ # Standardize covariates for logistic regression
38
+ scaler = StandardScaler()
39
+ X_scaled = scaler.fit_transform(X)
40
+
41
+ if model_type.lower() == 'logistic':
42
+ # Fit logistic regression
43
+ model = LogisticRegression(max_iter=kwargs.get('max_iter', 1000),
44
+ solver=kwargs.get('solver', 'liblinear'), # Use liblinear for L1/L2
45
+ C=kwargs.get('C', 1.0),
46
+ penalty=kwargs.get('penalty', 'l2'))
47
+ model.fit(X_scaled, y)
48
+
49
+ # Predict probabilities
50
+ propensity_scores = model.predict_proba(X_scaled)[:, 1]
51
+ # TODO: Add other model types like Gradient Boosting, etc.
52
+ # elif model_type.lower() == 'gbm':
53
+ # from sklearn.ensemble import GradientBoostingClassifier
54
+ # model = GradientBoostingClassifier(...)
55
+ # model.fit(X, y)
56
+ # propensity_scores = model.predict_proba(X)[:, 1]
57
+ else:
58
+ raise ValueError(f"Unsupported propensity score model type: {model_type}")
59
+
60
+ # Clip scores to avoid extremes which can cause issues in weighting/matching
61
+ propensity_scores = np.clip(propensity_scores, 0.01, 0.99)
62
+
63
+ return propensity_scores
64
+
65
+ # Common formatting function (can be expanded)
66
+ def format_ps_results(effect_estimate: float, effect_se: float,
67
+ diagnostics: Dict[str, Any], method_details: str,
68
+ parameters: Dict[str, Any]) -> Dict[str, Any]:
69
+ '''Standard formatter for PS method results.'''
70
+ ci_lower = effect_estimate - 1.96 * effect_se
71
+ ci_upper = effect_estimate + 1.96 * effect_se
72
+ return {
73
+ "effect_estimate": float(effect_estimate),
74
+ "effect_se": float(effect_se),
75
+ "confidence_interval": [float(ci_lower), float(ci_upper)],
76
+ "diagnostics": diagnostics,
77
+ "method_details": method_details,
78
+ "parameters": parameters
79
+ # Add p-value if needed (can be calculated from estimate and SE)
80
+ }
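A small usage sketch for the estimator above; the data and the selection mechanism are invented purely for illustration. Note the scores come back clipped to [0.01, 0.99] by the estimator.

```python
# Illustrative call to estimate_propensity_scores on synthetic data.
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n = 1_000
df = pd.DataFrame({"age": rng.normal(40.0, 10.0, n), "income": rng.normal(50.0, 15.0, n)})
# Treatment probability rises with age (simple logistic selection, invented for the example).
p_treat = 1.0 / (1.0 + np.exp(-(df["age"] - 40.0) / 10.0))
df["treated"] = (rng.random(n) < p_treat).astype(int)

scores = estimate_propensity_scores(df, treatment="treated", covariates=["age", "income"])
print(scores.min(), scores.max())  # always within [0.01, 0.99] due to the clipping above
```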
auto_causal/methods/propensity_score/diagnostics.py ADDED
@@ -0,0 +1,74 @@
1
+ # Balance and sensitivity analysis diagnostics for Propensity Score methods
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+ from typing import Dict, List, Optional, Any
6
+
7
+ # Import necessary plotting libraries if visualizations are needed
8
+ # import matplotlib.pyplot as plt
9
+ # import seaborn as sns
10
+
11
+ # Import utility for standardized differences if needed
12
+ from auto_causal.methods.utils import calculate_standardized_differences
13
+
14
+ def assess_balance(df_original: pd.DataFrame, df_matched_or_weighted: pd.DataFrame,
15
+ treatment: str, covariates: List[str],
16
+ method: str,
17
+ propensity_scores_original: Optional[np.ndarray] = None,
18
+ propensity_scores_matched: Optional[np.ndarray] = None,
19
+ weights: Optional[np.ndarray] = None) -> Dict[str, Any]:
20
+ '''Assesses covariate balance before and after matching/weighting.
21
+
22
+ Placeholder: Returns dummy diagnostic data.
23
+ '''
24
+ print(f"Assessing balance for {method}...")
25
+ # TODO: Implement actual balance checking using standardized differences,
26
+ # variance ratios, KS tests, etc.
27
+ # Example using standardized differences (needs calculate_standardized_differences):
28
+ # std_diff_before = calculate_standardized_differences(df_original, treatment, covariates)
29
+ # std_diff_after = calculate_standardized_differences(df_matched_or_weighted, treatment, covariates, weights=weights)
30
+
31
+ dummy_balance_metric = {cov: np.random.rand() * 0.1 for cov in covariates} # Simulate good balance
32
+
33
+ return {
34
+ "balance_metrics": dummy_balance_metric,
35
+ "balance_achieved": True, # Placeholder
36
+ "problematic_covariates": [], # Placeholder
37
+ # Add plots or paths to plots if generated
38
+ "plots": {
39
+ "balance_plot": "balance_plot.png",
40
+ "overlap_plot": "overlap_plot.png"
41
+ }
42
+ }
43
+
44
+ def assess_weight_distribution(weights: np.ndarray, treatment_indicator: pd.Series) -> Dict[str, Any]:
45
+ '''Assesses the distribution of IPW weights.
46
+
47
+ Placeholder: Returns dummy diagnostic data.
48
+ '''
49
+ print("Assessing weight distribution...")
50
+ # TODO: Implement checks for extreme weights, effective sample size, etc.
51
+ return {
52
+ "min_weight": float(np.min(weights)),
53
+ "max_weight": float(np.max(weights)),
54
+ "mean_weight": float(np.mean(weights)),
55
+ "std_dev_weight": float(np.std(weights)),
56
+ "effective_sample_size": len(weights) / (1 + np.std(weights)**2 / np.mean(weights)**2), # Kish's ESS approx
57
+ "potential_issues": np.max(weights) > 20 # Example check
58
+ }
59
+
60
+ def plot_overlap(df: pd.DataFrame, treatment: str, propensity_scores: np.ndarray, save_path: str = 'overlap_plot.png'):
61
+ '''Generates plot showing propensity score overlap.
62
+ Placeholder: Does nothing.
63
+ '''
64
+ print(f"Generating overlap plot (placeholder) -> {save_path}")
65
+ # TODO: Implement actual plotting (e.g., using seaborn histplot or kdeplot)
66
+ pass
67
+
68
+ def plot_balance(balance_metrics_before: Dict[str, float], balance_metrics_after: Dict[str, float], save_path: str = 'balance_plot.png'):
69
+ '''Generates plot showing covariate balance before/after.
70
+ Placeholder: Does nothing.
71
+ '''
72
+ print(f"Generating balance plot (placeholder) -> {save_path}")
73
+ # TODO: Implement actual plotting (e.g., Love plot)
74
+ pass
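`assess_balance` above currently returns simulated metrics; the standardized-difference computation it references (imported from `auto_causal.methods.utils`, not shown here) could look roughly like the unweighted sketch below. The function name and the 0.1 imbalance threshold are illustrative assumptions.

```python
# Hypothetical sketch of an (unweighted) standardized mean difference computation.
import numpy as np
import pandas as pd
from typing import Dict, List

def standardized_differences(df: pd.DataFrame, treatment: str, covariates: List[str]) -> Dict[str, float]:
    """Absolute standardized mean difference per covariate, using a pooled-SD denominator."""
    treated, control = df[df[treatment] == 1], df[df[treatment] == 0]
    diffs = {}
    for cov in covariates:
        pooled_sd = np.sqrt((treated[cov].var() + control[cov].var()) / 2.0)
        diffs[cov] = 0.0 if pooled_sd == 0 else abs(treated[cov].mean() - control[cov].mean()) / pooled_sd
    return diffs

# A common convention flags covariates with SMD > 0.1 as imbalanced after matching/weighting.
```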
auto_causal/methods/propensity_score/llm_assist.py ADDED
@@ -0,0 +1,45 @@
1
+ # LLM Integration points for Propensity Score methods
2
+ import pandas as pd
3
+ from typing import List, Optional, Dict, Any
4
+
5
+ def determine_optimal_caliper(df: pd.DataFrame, treatment: str,
6
+ covariates: List[str],
7
+ query: Optional[str] = None) -> float:
8
+ '''Determines optimal caliper for PSM using data or LLM.
9
+
10
+ Placeholder: Returns a default value.
11
+ '''
12
+ # TODO: Implement data-driven (e.g., based on PS distribution) or LLM-assisted caliper selection.
13
+ # Common rule of thumb is 0.2 * std dev of logit(PS), but that requires calculating PS first.
14
+ return 0.2
15
+
16
+ def determine_optimal_weight_type(df: pd.DataFrame, treatment: str,
17
+ query: Optional[str] = None) -> str:
18
+ '''Determines the optimal type of IPW weights (ATE, ATT, etc.).
19
+
20
+ Placeholder: Defaults to ATE.
21
+ '''
22
+ # TODO: Implement LLM or rule-based selection.
23
+ return "ATE"
24
+
25
+ def determine_optimal_trim_threshold(df: pd.DataFrame, treatment: str,
26
+ propensity_scores: Optional[pd.Series] = None,
27
+ query: Optional[str] = None) -> Optional[float]:
28
+ '''Determines optimal threshold for trimming extreme propensity scores.
29
+
30
+ Placeholder: Defaults to no trimming (None).
31
+ '''
32
+ # TODO: Implement data-driven or LLM-assisted threshold selection (e.g., based on score distribution).
33
+ return None # Corresponds to no trimming by default
34
+
35
+ # Placeholder for calling LLM to get parameters (can use the one in utils if general enough)
36
+ def get_llm_parameters(df: pd.DataFrame, query: str, method: str) -> Dict[str, Any]:
37
+ '''Placeholder to get parameters via LLM based on dataset and query.'''
38
+ # In reality, call something like analyze_dataset_for_method from utils.llm_helpers
39
+ print(f"Simulating LLM call to get parameters for {method}...")
40
+ if method == "PS.Matching":
41
+ return {"parameters": {"caliper": 0.15}, "validation": {"check_balance": True}}
42
+ elif method == "PS.Weighting":
43
+ return {"parameters": {"weight_type": "ATE", "trim_threshold": 0.05}, "validation": {"check_weights": True}}
44
+ else:
45
+ return {"parameters": {}, "validation": {}}
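The caliper placeholder above mentions the common 0.2 × SD(logit(PS)) rule of thumb, which `matching.py` below applies inline. A standalone sketch of that heuristic (the helper name is an assumption):

```python
# Sketch of the 0.2 * SD(logit(PS)) caliper rule of thumb referenced in the docstring above.
import numpy as np

def caliper_from_propensity_scores(pscores: np.ndarray, multiplier: float = 0.2) -> float:
    """Caliper implied by the common 0.2 * SD(logit(PS)) rule of thumb."""
    eps = 1e-7
    ps = np.clip(pscores, eps, 1.0 - eps)   # avoid log(0) / log(1)
    logit = np.log(ps / (1.0 - ps))
    return multiplier * float(np.nanstd(logit))
```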
auto_causal/methods/propensity_score/matching.py ADDED
@@ -0,0 +1,341 @@
1
+ # Propensity Score Matching Implementation
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.neighbors import NearestNeighbors
5
+ import statsmodels.api as sm # For bias adjustment regression
6
+ import logging # For logging fallback
7
+ from typing import Dict, List, Optional, Any
8
+
9
+ # Import DoWhy
10
+ from dowhy import CausalModel
11
+
12
+ from .base import estimate_propensity_scores, format_ps_results, select_propensity_model
13
+ from .diagnostics import assess_balance #, plot_overlap, plot_balance # Import diagnostic functions
14
+ # Remove determine_optimal_caliper, it will be replaced by a heuristic
15
+ from .llm_assist import get_llm_parameters # Import LLM helpers
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ def _calculate_logit(pscore):
20
+ """Calculate logit of propensity score, clipping to avoid inf."""
21
+ # Clip pscore to prevent log(0) or log(1) issues which lead to inf
22
+ epsilon = 1e-7
23
+ pscore_clipped = np.clip(pscore, epsilon, 1 - epsilon)
24
+ return np.log(pscore_clipped / (1 - pscore_clipped))
25
+
26
+ def _perform_matching_and_get_att(
27
+ df_sample: pd.DataFrame,
28
+ treatment: str,
29
+ outcome: str,
30
+ covariates: List[str],
31
+ propensity_model_type: str,
32
+ n_neighbors: int,
33
+ caliper: float,
34
+ perform_bias_adjustment: bool,
35
+ **kwargs
36
+ ) -> float:
37
+ """
38
+ Helper to perform Custom KNN PSM and calculate ATT, potentially with bias adjustment.
39
+ Returns the ATT estimate.
40
+ """
41
+ df_ps = df_sample.copy()
42
+ try:
43
+ propensity_scores = estimate_propensity_scores(
44
+ df_ps, treatment, covariates, model_type=propensity_model_type, **kwargs
45
+ )
46
+ except Exception as e:
47
+ logger.warning(f"Propensity score estimation failed in helper: {e}")
48
+ return np.nan # Cannot proceed without propensity scores
49
+
50
+ df_ps['propensity_score'] = propensity_scores
51
+
52
+ treated = df_ps[df_ps[treatment] == 1]
53
+ control = df_ps[df_ps[treatment] == 0]
54
+
55
+ if treated.empty or control.empty:
56
+ return np.nan
57
+
58
+ nn = NearestNeighbors(n_neighbors=n_neighbors, radius=caliper if caliper is not None else np.inf, metric='minkowski', p=2)
59
+ try:
60
+ # Ensure control PS are valid before fitting
61
+ control_ps_values = control[['propensity_score']].values
62
+ if np.isnan(control_ps_values).any():
63
+ logger.warning("NaN values found in control propensity scores before NN fitting.")
64
+ return np.nan
65
+ nn.fit(control_ps_values)
66
+
67
+ # Ensure treated PS are valid before querying
68
+ treated_ps_values = treated[['propensity_score']].values
69
+ if np.isnan(treated_ps_values).any():
70
+ logger.warning("NaN values found in treated propensity scores before NN query.")
71
+ return np.nan
72
+ distances, indices = nn.kneighbors(treated_ps_values)
73
+
74
+ except ValueError as e:
75
+ # Handles case where control group might be too small or have NaN PS scores
76
+ logger.warning(f"NearestNeighbors fitting/query failed: {e}")
77
+ return np.nan
78
+
79
+ matched_outcomes_treated = []
80
+ matched_outcomes_control_means = []
81
+ propensity_diffs = []
82
+
83
+ for i in range(len(treated)):
84
+ treated_unit = treated.iloc[[i]]
85
+ valid_neighbors_mask = distances[i] <= (caliper if caliper is not None else np.inf)
86
+ valid_neighbors_idx = indices[i][valid_neighbors_mask]
87
+
88
+ if len(valid_neighbors_idx) > 0:
89
+ matched_controls_for_this_treated = control.iloc[valid_neighbors_idx]
90
+ if matched_controls_for_this_treated.empty:
91
+ continue # Should not happen with valid_neighbors_idx check, but safety
92
+
93
+ matched_outcomes_treated.append(treated_unit[outcome].values[0])
94
+ matched_outcomes_control_means.append(matched_controls_for_this_treated[outcome].mean())
95
+
96
+ if perform_bias_adjustment:
97
+ # Ensure PS scores are valid before calculating difference
98
+ treated_ps = treated_unit['propensity_score'].values[0]
99
+ control_ps_mean = matched_controls_for_this_treated['propensity_score'].mean()
100
+ if np.isnan(treated_ps) or np.isnan(control_ps_mean):
101
+ logger.warning("NaN propensity score encountered during bias adjustment calculation.")
102
+ # Cannot perform bias adjustment for this unit, potentially skip or handle
103
+ # For now, let's skip adding to propensity_diffs if NaN found
104
+ continue
105
+ propensity_diff = treated_ps - control_ps_mean
106
+ propensity_diffs.append(propensity_diff)
107
+
108
+ if not matched_outcomes_treated:
109
+ return np.nan
110
+
111
+ raw_att_components = np.array(matched_outcomes_treated) - np.array(matched_outcomes_control_means)
112
+
113
+ if perform_bias_adjustment:
114
+ # Ensure lengths match *after* potential skips due to NaNs
115
+ if not propensity_diffs or len(raw_att_components) != len(propensity_diffs):
116
+ logger.warning("Bias adjustment skipped due to inconsistent data lengths after NaN checks.")
117
+ return np.mean(raw_att_components)
118
+
119
+ try:
120
+ X_bias_adj = sm.add_constant(np.array(propensity_diffs))
121
+ y_bias_adj = raw_att_components
122
+ # Add check for NaNs/Infs in inputs to OLS
123
+ if np.isnan(X_bias_adj).any() or np.isnan(y_bias_adj).any() or \
124
+ np.isinf(X_bias_adj).any() or np.isinf(y_bias_adj).any():
125
+ logger.warning("NaN/Inf values detected in OLS inputs for bias adjustment. Falling back.")
126
+ return np.mean(raw_att_components)
127
+
128
+ bias_model = sm.OLS(y_bias_adj, X_bias_adj).fit()
129
+ bias_adjusted_att = bias_model.params[0]
130
+ return bias_adjusted_att
131
+ except Exception as e:
132
+ logger.warning(f"OLS for bias adjustment failed: {e}. Falling back to raw ATT.")
133
+ return np.mean(raw_att_components)
134
+ else:
135
+ return np.mean(raw_att_components)
136
+
137
+ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
138
+ covariates: List[str], **kwargs) -> Dict[str, Any]:
139
+ '''Estimate ATT using Propensity Score Matching.
140
+ Tries DoWhy's PSM first, falls back to custom implementation if DoWhy fails.
141
+ Uses bootstrap SE based on the custom implementation regardless.
142
+ '''
143
+ query = kwargs.get('query')
144
+ n_bootstraps = kwargs.get('n_bootstraps', 100)
145
+
146
+ # --- Parameter Setup (as before) ---
147
+ llm_params = get_llm_parameters(df, query, "PS.Matching")
148
+ llm_suggested_params = llm_params.get("parameters", {})
149
+
150
+ caliper = kwargs.get('caliper', llm_suggested_params.get('caliper'))
151
+ temp_propensity_scores_for_caliper = None
152
+ try:
153
+ temp_propensity_scores_for_caliper = estimate_propensity_scores(
154
+ df, treatment, covariates,
155
+ model_type=llm_suggested_params.get('propensity_model_type', 'logistic'),
156
+ **kwargs
157
+ )
158
+ if caliper is None and temp_propensity_scores_for_caliper is not None:
159
+ logit_ps = _calculate_logit(temp_propensity_scores_for_caliper)
160
+ if not np.isnan(logit_ps).all(): # Check if logit calculation was successful
161
+ caliper = 0.2 * np.nanstd(logit_ps) # Use nanstd for robustness
162
+ else:
163
+ logger.warning("Logit of propensity scores resulted in NaNs, cannot calculate heuristic caliper.")
164
+ caliper = None
165
+ elif caliper is None:
166
+ logger.warning("Could not estimate propensity scores for caliper heuristic.")
167
+ caliper = None
168
+
169
+ except Exception as e:
170
+ logger.warning(f"Failed to estimate initial propensity scores for caliper heuristic: {e}. Caliper set to None.")
171
+ caliper = None # Proceed without caliper if heuristic fails
172
+
173
+ n_neighbors = kwargs.get('n_neighbors', llm_suggested_params.get('n_neighbors', 1))
174
+ propensity_model_type = kwargs.get('propensity_model_type',
175
+ llm_suggested_params.get('propensity_model_type',
176
+ select_propensity_model(df, treatment, covariates, query)))
177
+
178
+ # --- Attempt DoWhy PSM for Point Estimate ---
179
+ att_estimate = np.nan
180
+ method_used_for_att = "Fallback Custom PSM"
181
+ dowhy_model = None
182
+ identified_estimand = None
183
+
184
+ try:
185
+ logger.info("Attempting estimation using DoWhy Propensity Score Matching...")
186
+ dowhy_model = CausalModel(
187
+ data=df,
188
+ treatment=treatment,
189
+ outcome=outcome,
190
+ common_causes=covariates,
191
+ estimand_type='nonparametric-ate' # Provide list of names directly
192
+ )
193
+ # Identify estimand (optional step, but good practice)
194
+ identified_estimand = dowhy_model.identify_effect(proceed_when_unidentifiable=True)
195
+ logger.info(f"DoWhy identified estimand: {identified_estimand}")
196
+
197
+ # Estimate effect using DoWhy's PSM
198
+ estimate = dowhy_model.estimate_effect(
199
+ identified_estimand,
200
+ method_name="backdoor.propensity_score_matching",
201
+ target_units="att",
202
+ method_params={}
203
+ )
204
+ att_estimate = estimate.value
205
+ method_used_for_att = "DoWhy PSM"
206
+ logger.info(f"DoWhy PSM successful. ATT Estimate: {att_estimate}")
207
+
208
+ except Exception as e:
209
+ logger.warning(f"DoWhy PSM failed: {e}. Falling back to custom PSM implementation.")
210
+ # Fallback is triggered implicitly if att_estimate remains NaN
211
+
212
+ # --- Fallback or if DoWhy failed ---
213
+ if np.isnan(att_estimate):
214
+ logger.info("Calculating ATT estimate using fallback custom PSM...")
215
+ att_estimate = _perform_matching_and_get_att(
216
+ df, treatment, outcome, covariates,
217
+ propensity_model_type, n_neighbors, caliper,
218
+ perform_bias_adjustment=True, **kwargs # Bias adjust the fallback
219
+ )
220
+ method_used_for_att = "Fallback Custom PSM" # Confirm it's fallback
221
+ if np.isnan(att_estimate):
222
+ raise ValueError("Fallback custom PSM estimation also failed. Cannot proceed.")
223
+ logger.info(f"Fallback Custom PSM successful. ATT Estimate: {att_estimate}")
224
+
225
+ # --- Bootstrap SE (using custom helper for consistency) ---
226
+ logger.info(f"Calculating Bootstrap SE using custom helper ({n_bootstraps} iterations)...")
227
+ bootstrap_atts = []
228
+ for i in range(n_bootstraps):
229
+ try:
230
+ # Ensure bootstrap samples are drawn correctly
231
+ df_boot = df.sample(n=len(df), replace=True, random_state=np.random.randint(1000000) + i)
232
+ # Bias adjustment in bootstrap can be slow, optionally disable it
233
+ boot_att = _perform_matching_and_get_att(
234
+ df_boot, treatment, outcome, covariates,
235
+ propensity_model_type, n_neighbors, caliper,
236
+ perform_bias_adjustment=False, **kwargs # Set bias adjustment to False for speed in bootstrap
237
+ )
238
+ if not np.isnan(boot_att):
239
+ bootstrap_atts.append(boot_att)
240
+ except Exception as boot_e:
241
+ logger.warning(f"Bootstrap iteration {i+1} failed: {boot_e}")
242
+ continue # Skip failed bootstrap iteration
243
+
244
+ att_se = np.nanstd(bootstrap_atts) if bootstrap_atts else np.nan # Use nanstd
245
+ actual_bootstrap_iterations = len(bootstrap_atts)
246
+ logger.info(f"Bootstrap SE calculated: {att_se} from {actual_bootstrap_iterations} successful iterations.")
247
+
248
+ # --- Diagnostics (using custom matching logic for consistency) ---
249
+ logger.info("Performing diagnostic checks using custom matching logic...")
250
+ diagnostics = {"error": "Diagnostics failed to run."}
251
+ propensity_scores_orig = temp_propensity_scores_for_caliper # Reuse if available and not None
252
+
253
+ if propensity_scores_orig is None:
254
+ try:
255
+ propensity_scores_orig = estimate_propensity_scores(
256
+ df, treatment, covariates, model_type=propensity_model_type, **kwargs
257
+ )
258
+ except Exception as e:
259
+ logger.error(f"Failed to estimate propensity scores for diagnostics: {e}")
260
+ propensity_scores_orig = None
261
+
262
+ if propensity_scores_orig is not None and not np.isnan(propensity_scores_orig).all():
263
+ df_ps_orig = df.copy()
264
+ df_ps_orig['propensity_score'] = propensity_scores_orig
265
+ treated_orig = df_ps_orig[df_ps_orig[treatment] == 1]
266
+ control_orig = df_ps_orig[df_ps_orig[treatment] == 0]
267
+ unmatched_treated_count = 0
268
+
269
+ # Drop rows with NaN propensity scores before diagnostics
270
+ treated_orig = treated_orig.dropna(subset=['propensity_score'])
271
+ control_orig = control_orig.dropna(subset=['propensity_score'])
272
+
273
+ if not treated_orig.empty and not control_orig.empty:
274
+ try:
275
+ nn_diag = NearestNeighbors(n_neighbors=n_neighbors, radius=caliper if caliper is not None else np.inf, metric='minkowski', p=2)
276
+ nn_diag.fit(control_orig[['propensity_score']].values)
277
+ distances_diag, indices_diag = nn_diag.kneighbors(treated_orig[['propensity_score']].values)
278
+
279
+ matched_treated_indices_diag = []
280
+ matched_control_indices_diag = []
281
+
282
+ for i in range(len(treated_orig)):
283
+ valid_neighbors_mask_diag = distances_diag[i] <= (caliper if caliper is not None else np.inf)
284
+ valid_neighbors_idx_diag = indices_diag[i][valid_neighbors_mask_diag]
285
+ if len(valid_neighbors_idx_diag) > 0:
286
+ # Get original DataFrame indices from control_orig based on iloc indices
287
+ selected_control_original_indices = control_orig.index[valid_neighbors_idx_diag]
288
+ matched_treated_indices_diag.extend([treated_orig.index[i]] * len(selected_control_original_indices))
289
+ matched_control_indices_diag.extend(selected_control_original_indices)
290
+ else:
291
+ unmatched_treated_count += 1
292
+
293
+ if matched_control_indices_diag:
294
+ # Use unique indices for creating the diagnostic dataframe
295
+ unique_matched_control_indices = list(set(matched_control_indices_diag))
296
+ unique_matched_treated_indices = list(set(matched_treated_indices_diag))
297
+
298
+ matched_control_df_diag = df.loc[unique_matched_control_indices]
299
+ matched_treated_df_for_diag = df.loc[unique_matched_treated_indices]
300
+ matched_df_diag = pd.concat([matched_treated_df_for_diag, matched_control_df_diag]).drop_duplicates()
301
+
302
+ # Retrieve propensity scores for the specific units in matched_df_diag
303
+ ps_matched_for_diag = df_ps_orig.loc[matched_df_diag.index, 'propensity_score']  # propensity_scores_orig is a NumPy array, so look up scores via the DataFrame copy
304
+
305
+ diagnostics = assess_balance(df, matched_df_diag, treatment, covariates,
306
+ method="PSM",
307
+ propensity_scores_original=propensity_scores_orig,
308
+ propensity_scores_matched=ps_matched_for_diag)
309
+ else:
310
+ diagnostics = {"message": "No units could be matched for diagnostic assessment."}
311
+ # If no controls were matched, all treated were unmatched
312
+ unmatched_treated_count = len(treated_orig) if not treated_orig.empty else 0
313
+ except Exception as diag_e:
314
+ logger.error(f"Error during diagnostic matching/balance assessment: {diag_e}")
315
+ diagnostics = {"error": f"Diagnostics failed: {diag_e}"}
316
+ else:
317
+ diagnostics = {"message": "Treatment or control group empty after dropping NaN PS, diagnostics skipped."}
318
+ unmatched_treated_count = len(treated_orig) if not treated_orig.empty else 0
319
+
320
+ # Ensure unmatched count calculation is safe
321
+ if 'unmatched_treated_count' not in locals():
322
+ unmatched_treated_count = 0 # Initialize if loop didn't run
323
+ diagnostics["unmatched_treated_count"] = unmatched_treated_count
324
+ diagnostics["percent_treated_matched"] = (len(treated_orig) - unmatched_treated_count) / len(treated_orig) * 100 if len(treated_orig) > 0 else 0
325
+ else:
326
+ diagnostics = {"error": "Propensity scores could not be estimated for diagnostics."}
327
+
328
+ # Add final details to diagnostics
329
+ diagnostics["att_estimation_method"] = method_used_for_att
330
+ diagnostics["propensity_score_model"] = propensity_model_type
331
+ diagnostics["bootstrap_iterations_for_se"] = actual_bootstrap_iterations
332
+ diagnostics["final_caliper_used"] = caliper
333
+
334
+ # --- Format and return results ---
335
+ logger.info(f"Formatting results. ATT Estimate: {att_estimate}, SE: {att_se}, Method: {method_used_for_att}")
336
+ return format_ps_results(att_estimate, att_se, diagnostics,
337
+ method_details=f"PSM ({method_used_for_att})",
338
+ parameters={"caliper": caliper,
339
+ "n_neighbors": n_neighbors, # n_neighbors used in fallback/bootstrap/diag
340
+ "propensity_model": propensity_model_type,
341
+ "n_bootstraps_config": n_bootstraps})
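To close the loop, a minimal call to this PSM estimator could look like the following. Data and column names are synthetic, `n_bootstraps` is lowered to keep the example fast, and DoWhy must be installed for the first-choice path (the custom fallback runs otherwise).

```python
# Illustrative call to the PSM estimate_effect above (synthetic data, small bootstrap).
import numpy as np
import pandas as pd

rng = np.random.default_rng(7)
n = 800
df = pd.DataFrame({"age": rng.normal(40.0, 10.0, n), "income": rng.normal(50.0, 15.0, n)})
p_treat = 1.0 / (1.0 + np.exp(-(df["age"] - 40.0) / 10.0))   # selection on age (invented)
df["treated"] = (rng.random(n) < p_treat).astype(int)
df["outcome"] = 2.0 * df["treated"] + 0.1 * df["age"] + rng.normal(0.0, 1.0, n)

result = estimate_effect(df, treatment="treated", outcome="outcome",
                         covariates=["age", "income"], n_bootstraps=20)
print(result["effect_estimate"], result["confidence_interval"])
print(result["diagnostics"]["att_estimation_method"])
```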