Spaces:
Running
Running
Commit
·
efbcc96
1
Parent(s):
b7dc123
added decision tree visualization
Browse files- app.py +103 -23
- visualise.py +328 -0
app.py
CHANGED
@@ -6,6 +6,45 @@ import gradio as gr
|
|
6 |
import time
|
7 |
import smtplib
|
8 |
from email.message import EmailMessage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Make your repo importable (expecting a folder named causal-agent at repo root)
|
11 |
sys.path.append(str(Path(__file__).parent / "causal-agent"))
|
@@ -120,18 +159,33 @@ def _ok_html(text):
|
|
120 |
return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
|
121 |
|
122 |
# --- Email support ---
|
123 |
-
|
124 |
-
|
125 |
-
host = os.getenv("SMTP_HOST")
|
126 |
-
port = int(os.getenv("SMTP_PORT", "587"))
|
127 |
-
user = os.getenv("SMTP_USER")
|
128 |
-
pwd = os.getenv("SMTP_PASS")
|
129 |
-
from_addr = os.getenv("EMAIL_FROM")
|
130 |
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
try:
|
|
|
135 |
msg = EmailMessage()
|
136 |
msg["From"] = from_addr
|
137 |
msg["To"] = recipient
|
@@ -142,10 +196,17 @@ def send_email(recipient: str, subject: str, body_text: str, attachment_name: st
|
|
142 |
payload = json.dumps(attachment_json, indent=2).encode("utf-8")
|
143 |
msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
return ""
|
150 |
except Exception as e:
|
151 |
return f"Email send failed: {e}"
|
@@ -154,35 +215,43 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
|
|
154 |
start = time.time()
|
155 |
|
156 |
processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
|
157 |
-
yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})
|
158 |
|
159 |
if not os.getenv("OPENAI_API_KEY"):
|
160 |
-
yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
|
161 |
return
|
162 |
if not csv_path:
|
163 |
-
yield (_warn_html("Please upload a CSV dataset."), "", "", {})
|
164 |
return
|
165 |
|
166 |
try:
|
167 |
step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
|
168 |
-
yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})
|
169 |
|
170 |
result = run_causal_analysis(
|
171 |
query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
|
172 |
dataset_path=csv_path,
|
173 |
dataset_description=(dataset_description or "").strip(),
|
174 |
)
|
175 |
-
|
176 |
llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
|
177 |
-
yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}})
|
178 |
|
179 |
except Exception as e:
|
180 |
-
yield (_err_html(str(e)), "", "", {})
|
181 |
return
|
182 |
|
183 |
try:
|
184 |
payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
|
|
|
185 |
method = payload.get("method_used", "N/A")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
|
188 |
|
@@ -199,7 +268,7 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
|
|
199 |
explanation_html = _warn_html(f"LLM summary failed: {e}")
|
200 |
|
201 |
except Exception as e:
|
202 |
-
yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
|
203 |
return
|
204 |
|
205 |
# Optional email send (best-effort)
|
@@ -225,8 +294,12 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
|
|
225 |
explanation_html += _warn_html(email_err)
|
226 |
else:
|
227 |
explanation_html += _ok_html(f"Results emailed to {email.strip()}")
|
|
|
228 |
|
229 |
-
|
|
|
|
|
|
|
230 |
|
231 |
with gr.Blocks() as demo:
|
232 |
gr.Markdown("# Causal AI Scientist")
|
@@ -310,16 +383,23 @@ with gr.Blocks() as demo:
|
|
310 |
with gr.Row():
|
311 |
explanation_out = gr.HTML(label="Detailed Explanation")
|
312 |
|
|
|
|
|
|
|
|
|
|
|
313 |
with gr.Accordion("Raw Results (Advanced)", open=False):
|
314 |
raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
|
315 |
|
316 |
run_btn.click(
|
317 |
fn=run_agent,
|
318 |
inputs=[query, csv_file, dataset_description, email],
|
319 |
-
outputs=[method_out, effects_out, explanation_out, raw_results],
|
320 |
show_progress=True
|
321 |
)
|
322 |
|
|
|
|
|
323 |
|
324 |
|
325 |
if __name__ == "__main__":
|
|
|
import time
import smtplib
from email.message import EmailMessage
from visualise import render_from_json
from pathlib import Path
import os, json, tempfile
from huggingface_hub import HfApi, HfFileSystem, create_repo

# --- Run-log cache on the Hugging Face Hub ---
REPO = "CausalNLP/cais-demo-cache"  # dataset repo id
# FIX: use .get() instead of os.environ[...] so a missing HF_WRITE_TOKEN
# Space secret disables logging instead of crashing the app at import time.
TOKEN = os.environ.get("HF_WRITE_TOKEN")  # set as Space secret
api = HfApi(token=TOKEN) if TOKEN else None
fs = HfFileSystem(token=TOKEN) if TOKEN else None

# 1) ensure repo exists (best-effort; skipped when no token is configured)
if TOKEN:
    try:
        create_repo(REPO, repo_type="dataset", private=True, exist_ok=True, token=TOKEN)
    except Exception:
        pass  # any real problem will surface again (harmlessly) in cache_run

def cache_run(query, payload, artifacts=None):
    """Append one run record to logs.jsonl in the cache dataset repo.

    Best-effort: returns True on success and False on any failure (missing
    token, network error, ...) so a logging problem never aborts an analysis
    run — the original version raised, which failed run_agent's try block.

    Args:
        query: the user query string for this run.
        payload: the raw analysis result (must be JSON-serialisable).
        artifacts: optional dict of artifact paths/metadata.
    """
    if api is None or fs is None:
        return False  # logging disabled: no HF_WRITE_TOKEN configured
    ts = time.strftime("%Y-%m-%dT%H:%M:%S")
    row = {"timestamp": ts, "query": query, "payload": payload, "artifacts": artifacts or {}}

    hub_path = f"datasets/{REPO}/logs.jsonl"
    try:
        # 2) download existing (if any), append, and push in one commit.
        # NOTE(review): this read-modify-write is racy under concurrent runs;
        # acceptable for best-effort demo logging.
        with tempfile.TemporaryDirectory() as td:
            local = os.path.join(td, "logs.jsonl")
            try:
                with fs.open(hub_path, "rb") as fsrc, open(local, "wb") as fdst:
                    fdst.write(fsrc.read())
            except FileNotFoundError:
                open(local, "w").close()  # first run: start an empty log

            with open(local, "a", encoding="utf-8") as f:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

            api.upload_file(
                path_or_fileobj=local,
                path_in_repo="logs.jsonl",
                repo_id=REPO,
                repo_type="dataset",
                commit_message=f"append log {ts}",
            )
        return True
    except Exception:
        return False
48 |
|
49 |
# Make your repo importable (expecting a folder named causal-agent at repo root)
|
50 |
sys.path.append(str(Path(__file__).parent / "causal-agent"))
|
|
|
159 |
return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
|
160 |
|
161 |
# --- Email support ---
|
162 |
+
import base64, json, requests
|
163 |
+
from email.message import EmailMessage
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
+
def _gmail_access_token() -> str:
    """Exchange the stored Gmail OAuth refresh token for a short-lived access token.

    Reads GMAIL_CLIENT_ID / GMAIL_CLIENT_SECRET / GMAIL_REFRESH_TOKEN from the
    environment and hits Google's token endpoint. Raises on HTTP errors.
    """
    response = requests.post(
        "https://oauth2.googleapis.com/token",
        data={
            "client_id": os.getenv("GMAIL_CLIENT_ID"),
            "client_secret": os.getenv("GMAIL_CLIENT_SECRET"),
            "refresh_token": os.getenv("GMAIL_REFRESH_TOKEN"),
            "grant_type": "refresh_token",
        },
        timeout=20,
    )
    response.raise_for_status()
    return response.json()["access_token"]
|
176 |
+
|
177 |
+
def send_email(recipient: str, subject: str, body_text: str,
|
178 |
+
attachment_name: str = None, attachment_json: dict = None) -> str:
|
179 |
+
"""
|
180 |
+
Sends via Gmail API. Returns '' on success, or an error string.
|
181 |
+
"""
|
182 |
+
from_addr = os.getenv("EMAIL_FROM")
|
183 |
+
if not all([os.getenv("GMAIL_CLIENT_ID"), os.getenv("GMAIL_CLIENT_SECRET"),
|
184 |
+
os.getenv("GMAIL_REFRESH_TOKEN"), from_addr]):
|
185 |
+
return "Gmail API not configured (set GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET, GMAIL_REFRESH_TOKEN, EMAIL_FROM)."
|
186 |
|
187 |
try:
|
188 |
+
# Build MIME message
|
189 |
msg = EmailMessage()
|
190 |
msg["From"] = from_addr
|
191 |
msg["To"] = recipient
|
|
|
196 |
payload = json.dumps(attachment_json, indent=2).encode("utf-8")
|
197 |
msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)
|
198 |
|
199 |
+
# Base64url encode the raw RFC822 message
|
200 |
+
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
201 |
+
|
202 |
+
# Get access token and send
|
203 |
+
access_token = _gmail_access_token()
|
204 |
+
api_url = "https://gmail.googleapis.com/gmail/v1/users/me/messages/send"
|
205 |
+
headers = {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}
|
206 |
+
r = requests.post(api_url, headers=headers, json={"raw": raw}, timeout=20)
|
207 |
+
|
208 |
+
if r.status_code >= 400:
|
209 |
+
return f"Gmail API error {r.status_code}: {r.text[:300]}"
|
210 |
return ""
|
211 |
except Exception as e:
|
212 |
return f"Email send failed: {e}"
|
|
|
215 |
start = time.time()
|
216 |
|
217 |
processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
|
218 |
+
yield (processing_html, processing_html, processing_html, {"status": "Processing started..."}, None, None)
|
219 |
|
220 |
if not os.getenv("OPENAI_API_KEY"):
|
221 |
+
yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {}, None, None)
|
222 |
return
|
223 |
if not csv_path:
|
224 |
+
yield (_warn_html("Please upload a CSV dataset."), "", "", {}, None, None)
|
225 |
return
|
226 |
|
227 |
try:
|
228 |
step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
|
229 |
+
yield (step_html, step_html, step_html, {"status": "Running causal analysis..."}, None, None)
|
230 |
|
231 |
result = run_causal_analysis(
|
232 |
query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
|
233 |
dataset_path=csv_path,
|
234 |
dataset_description=(dataset_description or "").strip(),
|
235 |
)
|
236 |
+
cache_run(query, result)
|
237 |
llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
|
238 |
+
yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}, None, None)
|
239 |
|
240 |
except Exception as e:
|
241 |
+
yield (_err_html(str(e)), "", "", {}, None, None)
|
242 |
return
|
243 |
|
244 |
try:
|
245 |
payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
|
246 |
+
|
247 |
method = payload.get("method_used", "N/A")
|
248 |
+
# --- Decision tree render ---
|
249 |
+
artifacts_dir = Path("artifacts")
|
250 |
+
artifacts_dir.mkdir(exist_ok=True)
|
251 |
+
ts = time.strftime("%Y%m%d-%H%M%S")
|
252 |
+
out_stem = str(artifacts_dir / f"decision_tree_{ts}")
|
253 |
+
|
254 |
+
# This creates: out_stem.dot, out_stem.svg, out_stem.png
|
255 |
|
256 |
method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
|
257 |
|
|
|
268 |
explanation_html = _warn_html(f"LLM summary failed: {e}")
|
269 |
|
270 |
except Exception as e:
|
271 |
+
yield (_err_html(f"Failed to parse results: {e}"), "", "", {}, "", None)
|
272 |
return
|
273 |
|
274 |
# Optional email send (best-effort)
|
|
|
294 |
explanation_html += _warn_html(email_err)
|
295 |
else:
|
296 |
explanation_html += _ok_html(f"Results emailed to {email.strip()}")
|
297 |
+
render_from_json(result, out_stem)
|
298 |
|
299 |
+
tree_png = f"{out_stem}.png"
|
300 |
+
tree_svg = f"{out_stem}.svg"
|
301 |
+
tree_dot = f"{out_stem}.dot"
|
302 |
+
yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {}, tree_png, [tree_svg, tree_dot, tree_png])
|
303 |
|
304 |
with gr.Blocks() as demo:
|
305 |
gr.Markdown("# Causal AI Scientist")
|
|
|
383 |
with gr.Row():
|
384 |
explanation_out = gr.HTML(label="Detailed Explanation")
|
385 |
|
386 |
+
with gr.Row():
|
387 |
+
tree_img = gr.Image(label="Decision Tree", type="filepath")
|
388 |
+
with gr.Row():
|
389 |
+
tree_files = gr.Files(label="Download decision tree artifacts (.svg / .dot / .png)")
|
390 |
+
|
391 |
with gr.Accordion("Raw Results (Advanced)", open=False):
|
392 |
raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
|
393 |
|
394 |
run_btn.click(
|
395 |
fn=run_agent,
|
396 |
inputs=[query, csv_file, dataset_description, email],
|
397 |
+
outputs=[method_out, effects_out, explanation_out, raw_results, tree_img, tree_files],
|
398 |
show_progress=True
|
399 |
)
|
400 |
|
401 |
+
|
402 |
+
|
403 |
|
404 |
|
405 |
if __name__ == "__main__":
|
visualise.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Render a JSON-aware visualization of CAIS's rule-based method selector.
|
4 |
+
- Parses a CAIS run payload (dict) and highlights ALL plausible candidates (green).
|
5 |
+
- The actually selected method receives a thicker border.
|
6 |
+
- The traversed decision path edges are colored.
|
7 |
+
|
8 |
+
Usage:
|
9 |
+
render_from_json(payload_dict, out_stem="artifacts/decision_tree")
|
10 |
+
|
11 |
+
(Optional) CLI:
|
12 |
+
python decision_tree.py payload.json
|
13 |
+
"""
|
14 |
+
|
15 |
+
from graphviz import Digraph
|
16 |
+
import json, sys
|
17 |
+
from typing import Dict, Any, List, Set, Tuple, Optional
|
18 |
+
|
19 |
+
from auto_causal.components.decision_tree import (
|
20 |
+
DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
|
21 |
+
INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
|
22 |
+
GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT
|
23 |
+
)
|
24 |
+
|
25 |
+
LABEL = {
|
26 |
+
DIFF_IN_MEANS: "Diff-in-Means (RCT)",
|
27 |
+
LINEAR_REGRESSION: "Linear Regression",
|
28 |
+
DIFF_IN_DIFF: "Difference-in-Differences",
|
29 |
+
REGRESSION_DISCONTINUITY: "Regression Discontinuity",
|
30 |
+
INSTRUMENTAL_VARIABLE: "Instrumental Variables",
|
31 |
+
PROPENSITY_SCORE_MATCHING: "PS Matching",
|
32 |
+
PROPENSITY_SCORE_WEIGHTING: "PS Weighting",
|
33 |
+
GENERALIZED_PROPENSITY_SCORE: "Generalized PS (continuous T)",
|
34 |
+
BACKDOOR_ADJUSTMENT: "Backdoor Adjustment",
|
35 |
+
FRONTDOOR_ADJUSTMENT: "Frontdoor Adjustment",
|
36 |
+
}
|
37 |
+
|
38 |
+
# -------- Heuristic extractors from payload -------- #
|
39 |
+
|
40 |
+
def _get(d: Dict, path: List[str], default=None):
|
41 |
+
cur = d
|
42 |
+
for k in path:
|
43 |
+
if not isinstance(cur, dict) or k not in cur:
|
44 |
+
return default
|
45 |
+
cur = cur[k]
|
46 |
+
return cur
|
47 |
+
|
48 |
+
def extract_signals(p: Dict[str, Any]) -> Dict[str, Any]:
    """Pull the decision-relevant signals out of a CAIS run payload.

    Accepts either the nested layout ({"results": {...}}) or the flat one and
    returns a plain dict of booleans/values consumed by the tree helpers.
    """
    vars_ = _get(p, ["results", "variables"], {}) or _get(p, ["variables"], {}) or {}
    da = _get(p, ["results", "dataset_analysis"], {}) or _get(p, ["dataset_analysis"], {}) or {}

    treatment = vars_.get("treatment_variable")
    t_type = vars_.get("treatment_variable_type")  # "binary"/"continuous"
    is_rct = bool(vars_.get("is_rct", False))

    # Temporal / panel structure: any of the three markers counts.
    temporal_detected = bool(da.get("temporal_structure_detected", False))
    time_var = vars_.get("time_variable")
    group_var = vars_.get("group_variable")
    has_temporal = temporal_detected or bool(time_var) or bool(group_var)

    # RDD readiness requires BOTH a running variable and a cutoff.
    running_variable = vars_.get("running_variable")
    cutoff_value = vars_.get("cutoff_value")
    rdd_ready = running_variable is not None and cutoff_value is not None
    # (Some detectors raise 'discontinuities_detected', but we still require running var + cutoff.)
    # If you want permissive behavior, flip rdd_ready to also consider da.get("discontinuities_detected").

    # Instruments: an explicit instrument, or any detected potential one,
    # counts only when it is not the treatment variable itself.
    instrument = vars_.get("instrument_variable")
    pot_instr = da.get("potential_instruments") or []
    has_valid_instrument = (
        instrument is not None and instrument != treatment
    ) or any(pi and pi != treatment for pi in pot_instr)

    covariates = vars_.get("covariates") or []
    has_covariates = len(covariates) > 0

    # Frontdoor: only mark if explicitly provided (else too speculative).
    # FIX: read it from `da` so the flat payload layout (top-level
    # "dataset_analysis") is honoured too — the original only checked the
    # nested ["results"]["dataset_analysis"] path, inconsistent with every
    # other dataset_analysis signal above.
    frontdoor_ok = bool(da.get("frontdoor_satisfied", False))

    # Overlap is tri-state: True/False when known, None when unknown
    # (both PS variants then remain plausible downstream).
    overlap_assessment = da.get("overlap_assessment")
    strong_overlap = None
    if isinstance(overlap_assessment, dict):
        # accept typical keys like {"strong_overlap": true}
        strong_overlap = overlap_assessment.get("strong_overlap")

    return dict(
        treatment=treatment,
        t_type=t_type,
        is_rct=is_rct,
        has_temporal=has_temporal,
        rdd_ready=rdd_ready,
        has_valid_instrument=has_valid_instrument,
        has_covariates=has_covariates,
        frontdoor_ok=frontdoor_ok,
        strong_overlap=strong_overlap,
    )
|
101 |
+
|
102 |
+
# -------- Candidate inference (green leaves) -------- #
|
103 |
+
|
104 |
+
def infer_candidate_methods(signals: Dict[str, Any]) -> Set[str]:
    """Return every method the rule-based selector would consider plausible.

    RCT payloads short-circuit to the experimental designs; observational
    payloads collect every design whose preconditions hold.
    """
    methods: Set[str] = set()

    if signals["is_rct"]:
        # Randomized data: diff-in-means always applies; add regression when
        # covariates exist and IV when a valid instrument exists (e.g. a
        # randomized encouragement design). The observational tree is skipped.
        methods.add(DIFF_IN_MEANS)
        if signals["has_covariates"]:
            methods.add(LINEAR_REGRESSION)
        if signals["has_valid_instrument"]:
            methods.add(INSTRUMENTAL_VARIABLE)
        return methods

    # Observational branch: each (condition, method) pair stands alone.
    observational_checks = (
        (signals["has_temporal"], DIFF_IN_DIFF),
        (signals["rdd_ready"], REGRESSION_DISCONTINUITY),
        (signals["has_valid_instrument"], INSTRUMENTAL_VARIABLE),
        (signals["frontdoor_ok"], FRONTDOOR_ADJUSTMENT),
        (str(signals["t_type"]).lower() == "continuous", GENERALIZED_PROPENSITY_SCORE),
    )
    methods.update(method for satisfied, method in observational_checks if satisfied)

    # Backdoor / propensity-score designs require observed covariates.
    if signals["has_covariates"]:
        overlap = signals["strong_overlap"]
        if overlap is True:
            methods.add(PROPENSITY_SCORE_MATCHING)
        elif overlap is False:
            methods.add(PROPENSITY_SCORE_WEIGHTING)
        else:
            # Overlap unknown: keep both PS variants plausible.
            methods.update({PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING})
        methods.add(BACKDOOR_ADJUSTMENT)

    return methods
|
144 |
+
|
145 |
+
# -------- Compute the single realized path to the chosen leaf (for edge coloring) -------- #
|
146 |
+
|
147 |
+
def infer_decision_path(signals: Dict[str, Any], selected_method: Optional[str]) -> List[Tuple[str, str]]:
    """Trace the single realized route through the tree as (src, dst) edges.

    *selected_method* is accepted for interface stability; the route is
    derived purely from *signals*. Guard clauses replace the original
    nested if/else pyramid — each leaf returns immediately.
    """
    path: List[Tuple[str, str]] = [("start", "is_rct")]

    if signals["is_rct"]:
        path.append(("is_rct", "has_instr_rct"))
        if signals["has_valid_instrument"]:
            path.append(("has_instr_rct", INSTRUMENTAL_VARIABLE))
        else:
            path.append(("has_instr_rct", "has_cov_rct"))
            rct_leaf = LINEAR_REGRESSION if signals["has_covariates"] else DIFF_IN_MEANS
            path.append(("has_cov_rct", rct_leaf))
        return path

    # Observational branch: walk the guard chain, stopping at the first leaf.
    path.append(("is_rct", "has_temporal"))
    if signals["has_temporal"]:
        path.append(("has_temporal", DIFF_IN_DIFF))
        return path

    path.append(("has_temporal", "has_rv"))
    if signals["rdd_ready"]:
        path.append(("has_rv", REGRESSION_DISCONTINUITY))
        return path

    path.append(("has_rv", "has_instr"))
    if signals["has_valid_instrument"]:
        path.append(("has_instr", INSTRUMENTAL_VARIABLE))
        return path

    path.append(("has_instr", "frontdoor"))
    if signals["frontdoor_ok"]:
        path.append(("frontdoor", FRONTDOOR_ADJUSTMENT))
        return path

    path.append(("frontdoor", "t_cont"))
    if str(signals["t_type"]).lower() == "continuous":
        path.append(("t_cont", GENERALIZED_PROPENSITY_SCORE))
        return path

    path.append(("t_cont", "has_cov"))
    if signals["has_covariates"]:
        path.append(("has_cov", "overlap"))
        # If overlap is known-True pick matching; otherwise default to weighting.
        if signals["strong_overlap"] is True:
            path.append(("overlap", PROPENSITY_SCORE_MATCHING))
        else:
            path.append(("overlap", PROPENSITY_SCORE_WEIGHTING))
    else:
        path.append(("has_cov", BACKDOOR_ADJUSTMENT))  # keep original topology
    return path
|
206 |
+
|
207 |
+
# -------- Graph building -------- #
|
208 |
+
|
209 |
+
def build_graph(payload: Dict[str, Any]) -> Digraph:
    """Build the annotated CAIS decision-tree graph for one run payload.

    Plausible candidate leaves are filled green, the actually selected
    method gets a thicker border, and the traversed decision edges are
    coloured forest-green.
    """
    graph = Digraph("CAISDecisionTree", format="svg")
    graph.attr(rankdir="LR", nodesep="0.4", ranksep="0.35", fontsize="11")

    # Decision nodes (declaration order preserved — it influences layout).
    graph.node("start", "Start", shape="circle")
    decision_labels = {
        "is_rct": "Is RCT?",
        "has_instr_rct": "Instrument available?",
        "has_cov_rct": "Covariates observed?",
        "has_temporal": "Temporal structure?",
        "has_rv": "Running var & cutoff?",
        "has_instr": "Instrument available?",
        "frontdoor": "Frontdoor criterion satisfied?",
        "has_cov": "Covariates observed?",
        "overlap": "Strong overlap?\n(overlap ≥ 0.1)",
        "t_cont": "Treatment continuous?",
    }
    for node_id, node_label in decision_labels.items():
        graph.node(node_id, node_label, shape="diamond")

    def add_leaf(method_id, fill=None, bold=False):
        # Rounded box; green fill = plausible candidate, thick border = chosen.
        node_attrs = {"shape": "box", "style": "rounded"}
        if fill:
            node_attrs.update(style="rounded,filled", fillcolor=fill)
        if bold:
            node_attrs.update(penwidth="2")
        graph.node(method_id, LABEL[method_id], **node_attrs)

    # Derive signals, candidate set, and the selected method from the payload.
    signals = extract_signals(payload)
    candidates = infer_candidate_methods(signals)

    selected_method_str = (
        _get(payload, ["results", "results", "method_used"])
        or _get(payload, ["results", "method_used"])
        or _get(payload, ["method"])
    )
    method_aliases = {
        "linear_regression": LINEAR_REGRESSION,
        "diff_in_means": DIFF_IN_MEANS,
        "difference_in_differences": DIFF_IN_DIFF,
        "regression_discontinuity": REGRESSION_DISCONTINUITY,
        "instrumental_variable": INSTRUMENTAL_VARIABLE,
        "propensity_score_matching": PROPENSITY_SCORE_MATCHING,
        "propensity_score_weighting": PROPENSITY_SCORE_WEIGHTING,
        "generalized_propensity_score": GENERALIZED_PROPENSITY_SCORE,
        "backdoor_adjustment": BACKDOOR_ADJUSTMENT,
        "frontdoor_adjustment": FRONTDOOR_ADJUSTMENT,
    }
    selected_method = method_aliases.get(str(selected_method_str or "").lower())

    # Leaves, coloured by candidacy/selection.
    for method_id in (
        DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
        INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
        GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT,
    ):
        add_leaf(
            method_id,
            fill="palegreen" if method_id in candidates else None,
            bold=method_id == selected_method,
        )

    # Edges, with the realized decision path highlighted.
    path_edges = set(infer_decision_path(signals, selected_method))

    def add_edge(src, dst, label=None):
        edge_attrs = {} if label is None else {"label": label}
        if (src, dst) in path_edges:
            edge_attrs.update(color="forestgreen", penwidth="2")
        graph.edge(src, dst, **edge_attrs)

    # Topology (unchanged from the selector's rule order).
    add_edge("start", "is_rct")

    # RCT branch
    add_edge("is_rct", "has_instr_rct", label="Yes")
    add_edge("has_instr_rct", INSTRUMENTAL_VARIABLE, label="Yes")
    add_edge("has_instr_rct", "has_cov_rct", label="No")
    add_edge("has_cov_rct", LINEAR_REGRESSION, label="Yes")
    add_edge("has_cov_rct", DIFF_IN_MEANS, label="No")

    # Observational branch
    add_edge("is_rct", "has_temporal", label="No")
    add_edge("has_temporal", DIFF_IN_DIFF, label="Yes")
    add_edge("has_temporal", "has_rv", label="No")

    add_edge("has_rv", REGRESSION_DISCONTINUITY, label="Yes")
    add_edge("has_rv", "has_instr", label="No")

    add_edge("has_instr", INSTRUMENTAL_VARIABLE, label="Yes")
    add_edge("has_instr", "frontdoor", label="No")

    add_edge("frontdoor", FRONTDOOR_ADJUSTMENT, label="Yes")
    add_edge("frontdoor", "t_cont", label="No")

    add_edge("t_cont", GENERALIZED_PROPENSITY_SCORE, label="Yes")
    add_edge("t_cont", "has_cov", label="No")

    add_edge("has_cov", "overlap", label="Yes")
    add_edge("has_cov", BACKDOOR_ADJUSTMENT, label="No")

    add_edge("overlap", PROPENSITY_SCORE_MATCHING, label="Yes")
    add_edge("overlap", PROPENSITY_SCORE_WEIGHTING, label="No")

    # Optional legend
    graph.node("legend", "Legend:\nGreen = plausible candidate(s)\nBold border = method used", shape="note")
    graph.edge("legend", "start", style="dashed", arrowhead="none")

    return graph
|
311 |
+
|
312 |
+
def render_from_json(payload: Dict[str, Any], out_stem: str = "artifacts/decision_tree"):
    """Render the decision tree for *payload* to {out_stem}.dot/.svg/.png.

    FIX: create the output directory first — graphviz's save()/render()
    raise FileNotFoundError when the parent folder of *out_stem* does not
    exist yet (the default "artifacts/" is only created by the app's
    run_agent, not when this module is used stand-alone).
    """
    import os  # local import: this module does not otherwise need os
    parent = os.path.dirname(out_stem)
    if parent:
        os.makedirs(parent, exist_ok=True)

    g = build_graph(payload)
    g.save(filename=f"{out_stem}.dot")
    g.render(filename=out_stem, cleanup=True)  # writes {out_stem}.svg
    g.format = "png"
    g.render(filename=out_stem, cleanup=True)  # writes {out_stem}.png
|
318 |
+
|
319 |
+
def main():
    """CLI entry point: read a payload JSON file and render the tree.

    Usage:
        python visualise.py [payload.json]   # defaults to sample_output.json

    FIX: the CLI argument promised in the module docstring was ignored —
    the filename was hard-coded and the argv handling left commented out.
    """
    payload_path = sys.argv[1] if len(sys.argv) >= 2 else "sample_output.json"
    with open(payload_path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    render_from_json(payload)

if __name__ == "__main__":
    main()
|