FireShadow committed on
Commit
efbcc96
·
1 Parent(s): b7dc123

added decision tree visualization

Browse files
Files changed (2) hide show
  1. app.py +103 -23
  2. visualise.py +328 -0
app.py CHANGED
@@ -6,6 +6,45 @@ import gradio as gr
6
  import time
7
  import smtplib
8
  from email.message import EmailMessage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Make your repo importable (expecting a folder named causal-agent at repo root)
11
  sys.path.append(str(Path(__file__).parent / "causal-agent"))
@@ -120,18 +159,33 @@ def _ok_html(text):
120
  return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
121
 
122
  # --- Email support ---
123
- def send_email(recipient: str, subject: str, body_text: str, attachment_name: str = None, attachment_json: dict = None) -> str:
124
- """Returns '' on success, or error message string."""
125
- host = os.getenv("SMTP_HOST")
126
- port = int(os.getenv("SMTP_PORT", "587"))
127
- user = os.getenv("SMTP_USER")
128
- pwd = os.getenv("SMTP_PASS")
129
- from_addr = os.getenv("EMAIL_FROM")
130
 
131
- if not all([host, port, user, pwd, from_addr]):
132
- return "Email is not configured (set SMTP_HOST, SMTP_PORT, SMTP_USER, SMTP_PASS, EMAIL_FROM)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  try:
 
135
  msg = EmailMessage()
136
  msg["From"] = from_addr
137
  msg["To"] = recipient
@@ -142,10 +196,17 @@ def send_email(recipient: str, subject: str, body_text: str, attachment_name: st
142
  payload = json.dumps(attachment_json, indent=2).encode("utf-8")
143
  msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)
144
 
145
- with smtplib.SMTP(host, port, timeout=30) as s:
146
- s.starttls()
147
- s.login(user, pwd)
148
- s.send_message(msg)
 
 
 
 
 
 
 
149
  return ""
150
  except Exception as e:
151
  return f"Email send failed: {e}"
@@ -154,35 +215,43 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
154
  start = time.time()
155
 
156
  processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
157
- yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})
158
 
159
  if not os.getenv("OPENAI_API_KEY"):
160
- yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
161
  return
162
  if not csv_path:
163
- yield (_warn_html("Please upload a CSV dataset."), "", "", {})
164
  return
165
 
166
  try:
167
  step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
168
- yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})
169
 
170
  result = run_causal_analysis(
171
  query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
172
  dataset_path=csv_path,
173
  dataset_description=(dataset_description or "").strip(),
174
  )
175
-
176
  llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
177
- yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}})
178
 
179
  except Exception as e:
180
- yield (_err_html(str(e)), "", "", {})
181
  return
182
 
183
  try:
184
  payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
 
185
  method = payload.get("method_used", "N/A")
 
 
 
 
 
 
 
186
 
187
  method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
188
 
@@ -199,7 +268,7 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
199
  explanation_html = _warn_html(f"LLM summary failed: {e}")
200
 
201
  except Exception as e:
202
- yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
203
  return
204
 
205
  # Optional email send (best-effort)
@@ -225,8 +294,12 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
225
  explanation_html += _warn_html(email_err)
226
  else:
227
  explanation_html += _ok_html(f"Results emailed to {email.strip()}")
 
228
 
229
- yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
 
 
 
230
 
231
  with gr.Blocks() as demo:
232
  gr.Markdown("# Causal AI Scientist")
@@ -310,16 +383,23 @@ with gr.Blocks() as demo:
310
  with gr.Row():
311
  explanation_out = gr.HTML(label="Detailed Explanation")
312
 
 
 
 
 
 
313
  with gr.Accordion("Raw Results (Advanced)", open=False):
314
  raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
315
 
316
  run_btn.click(
317
  fn=run_agent,
318
  inputs=[query, csv_file, dataset_description, email],
319
- outputs=[method_out, effects_out, explanation_out, raw_results],
320
  show_progress=True
321
  )
322
 
 
 
323
 
324
 
325
  if __name__ == "__main__":
 
6
  import time
7
  import smtplib
8
  from email.message import EmailMessage
9
+ from visualise import render_from_json
10
+ from pathlib import Path
11
+ import time
12
+ import os, json, time, tempfile
13
+ from huggingface_hub import HfApi, HfFileSystem, create_repo
14
+
15
+ REPO = "CausalNLP/cais-demo-cache" # dataset repo id
16
+ TOKEN = os.environ["HF_WRITE_TOKEN"] # set as Space secret
17
+ api = HfApi(token=TOKEN)
18
+ fs = HfFileSystem(token=TOKEN)
19
+
20
+ # 1) ensure repo exists
21
+ create_repo(REPO, repo_type="dataset", private=True, exist_ok=True, token=TOKEN)
22
+
23
def cache_run(query, payload, artifacts=None):
    """Append one run record to ``logs.jsonl`` in the HF dataset cache repo.

    Downloads the current log file (if any), appends a single JSON line for
    this run, and pushes the whole file back as one commit.
    NOTE(review): this read-modify-write is not atomic — concurrent Space
    replicas could race and drop each other's rows; confirm acceptable.
    """
    stamp = time.strftime("%Y-%m-%dT%H:%M:%S")
    record = {
        "timestamp": stamp,
        "query": query,
        "payload": payload,
        "artifacts": artifacts or {},
    }

    hub_path = f"datasets/{REPO}/logs.jsonl"
    # Stage the file in a throwaway directory, append, then upload in one commit.
    with tempfile.TemporaryDirectory() as workdir:
        log_path = os.path.join(workdir, "logs.jsonl")
        try:
            # Pull down the existing log, if one has been pushed before.
            with fs.open(hub_path, "rb") as remote, open(log_path, "wb") as staged:
                staged.write(remote.read())
        except FileNotFoundError:
            # First run: start from an empty local file.
            open(log_path, "w").close()

        with open(log_path, "a", encoding="utf-8") as log:
            log.write(json.dumps(record, ensure_ascii=False) + "\n")

        api.upload_file(
            path_or_fileobj=log_path,
            path_in_repo="logs.jsonl",
            repo_id=REPO,
            repo_type="dataset",
            commit_message=f"append log {stamp}"
        )
47
+
48
 
49
  # Make your repo importable (expecting a folder named causal-agent at repo root)
50
  sys.path.append(str(Path(__file__).parent / "causal-agent"))
 
159
  return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
160
 
161
  # --- Email support ---
162
+ import base64, json, requests
163
+ from email.message import EmailMessage
 
 
 
 
 
164
 
165
def _gmail_access_token() -> str:
    """Exchange the stored OAuth refresh token for a Gmail API access token.

    Reads GMAIL_CLIENT_ID / GMAIL_CLIENT_SECRET / GMAIL_REFRESH_TOKEN from the
    environment; raises ``requests.HTTPError`` if Google rejects the exchange.
    """
    form = {
        "grant_type": "refresh_token",
        "client_id": os.getenv("GMAIL_CLIENT_ID"),
        "client_secret": os.getenv("GMAIL_CLIENT_SECRET"),
        "refresh_token": os.getenv("GMAIL_REFRESH_TOKEN"),
    }
    resp = requests.post("https://oauth2.googleapis.com/token", data=form, timeout=20)
    resp.raise_for_status()
    return resp.json()["access_token"]
176
+
177
+ def send_email(recipient: str, subject: str, body_text: str,
178
+ attachment_name: str = None, attachment_json: dict = None) -> str:
179
+ """
180
+ Sends via Gmail API. Returns '' on success, or an error string.
181
+ """
182
+ from_addr = os.getenv("EMAIL_FROM")
183
+ if not all([os.getenv("GMAIL_CLIENT_ID"), os.getenv("GMAIL_CLIENT_SECRET"),
184
+ os.getenv("GMAIL_REFRESH_TOKEN"), from_addr]):
185
+ return "Gmail API not configured (set GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET, GMAIL_REFRESH_TOKEN, EMAIL_FROM)."
186
 
187
  try:
188
+ # Build MIME message
189
  msg = EmailMessage()
190
  msg["From"] = from_addr
191
  msg["To"] = recipient
 
196
  payload = json.dumps(attachment_json, indent=2).encode("utf-8")
197
  msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)
198
 
199
+ # Base64url encode the raw RFC822 message
200
+ raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
201
+
202
+ # Get access token and send
203
+ access_token = _gmail_access_token()
204
+ api_url = "https://gmail.googleapis.com/gmail/v1/users/me/messages/send"
205
+ headers = {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}
206
+ r = requests.post(api_url, headers=headers, json={"raw": raw}, timeout=20)
207
+
208
+ if r.status_code >= 400:
209
+ return f"Gmail API error {r.status_code}: {r.text[:300]}"
210
  return ""
211
  except Exception as e:
212
  return f"Email send failed: {e}"
 
215
  start = time.time()
216
 
217
  processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
218
+ yield (processing_html, processing_html, processing_html, {"status": "Processing started..."}, None, None)
219
 
220
  if not os.getenv("OPENAI_API_KEY"):
221
+ yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {}, None, None)
222
  return
223
  if not csv_path:
224
+ yield (_warn_html("Please upload a CSV dataset."), "", "", {}, None, None)
225
  return
226
 
227
  try:
228
  step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
229
+ yield (step_html, step_html, step_html, {"status": "Running causal analysis..."}, None, None)
230
 
231
  result = run_causal_analysis(
232
  query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
233
  dataset_path=csv_path,
234
  dataset_description=(dataset_description or "").strip(),
235
  )
236
+ cache_run(query, result)
237
  llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
238
+ yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}, None, None)
239
 
240
  except Exception as e:
241
+ yield (_err_html(str(e)), "", "", {}, None, None)
242
  return
243
 
244
  try:
245
  payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
246
+
247
  method = payload.get("method_used", "N/A")
248
+ # --- Decision tree render ---
249
+ artifacts_dir = Path("artifacts")
250
+ artifacts_dir.mkdir(exist_ok=True)
251
+ ts = time.strftime("%Y%m%d-%H%M%S")
252
+ out_stem = str(artifacts_dir / f"decision_tree_{ts}")
253
+
254
+ # This creates: out_stem.dot, out_stem.svg, out_stem.png
255
 
256
  method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
257
 
 
268
  explanation_html = _warn_html(f"LLM summary failed: {e}")
269
 
270
  except Exception as e:
271
+ yield (_err_html(f"Failed to parse results: {e}"), "", "", {}, "", None)
272
  return
273
 
274
  # Optional email send (best-effort)
 
294
  explanation_html += _warn_html(email_err)
295
  else:
296
  explanation_html += _ok_html(f"Results emailed to {email.strip()}")
297
+ render_from_json(result, out_stem)
298
 
299
+ tree_png = f"{out_stem}.png"
300
+ tree_svg = f"{out_stem}.svg"
301
+ tree_dot = f"{out_stem}.dot"
302
+ yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {}, tree_png, [tree_svg, tree_dot, tree_png])
303
 
304
  with gr.Blocks() as demo:
305
  gr.Markdown("# Causal AI Scientist")
 
383
  with gr.Row():
384
  explanation_out = gr.HTML(label="Detailed Explanation")
385
 
386
+ with gr.Row():
387
+ tree_img = gr.Image(label="Decision Tree", type="filepath")
388
+ with gr.Row():
389
+ tree_files = gr.Files(label="Download decision tree artifacts (.svg / .dot / .png)")
390
+
391
  with gr.Accordion("Raw Results (Advanced)", open=False):
392
  raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
393
 
394
  run_btn.click(
395
  fn=run_agent,
396
  inputs=[query, csv_file, dataset_description, email],
397
+ outputs=[method_out, effects_out, explanation_out, raw_results, tree_img, tree_files],
398
  show_progress=True
399
  )
400
 
401
+
402
+
403
 
404
 
405
  if __name__ == "__main__":
visualise.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Render a JSON-aware visualization of CAIS's rule-based method selector.
4
+ - Parses a CAIS run payload (dict) and highlights ALL plausible candidates (green).
5
+ - The actually selected method receives a thicker border.
6
+ - The traversed decision path edges are colored.
7
+
8
+ Usage:
9
+ render_from_json(payload_dict, out_stem="artifacts/decision_tree")
10
+
11
+ (Optional) CLI:
12
+ python decision_tree.py payload.json
13
+ """
14
+
15
+ from graphviz import Digraph
16
+ import json, sys
17
+ from typing import Dict, Any, List, Set, Tuple, Optional
18
+
19
+ from auto_causal.components.decision_tree import (
20
+ DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
21
+ INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
22
+ GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT
23
+ )
24
+
25
+ LABEL = {
26
+ DIFF_IN_MEANS: "Diff-in-Means (RCT)",
27
+ LINEAR_REGRESSION: "Linear Regression",
28
+ DIFF_IN_DIFF: "Difference-in-Differences",
29
+ REGRESSION_DISCONTINUITY: "Regression Discontinuity",
30
+ INSTRUMENTAL_VARIABLE: "Instrumental Variables",
31
+ PROPENSITY_SCORE_MATCHING: "PS Matching",
32
+ PROPENSITY_SCORE_WEIGHTING: "PS Weighting",
33
+ GENERALIZED_PROPENSITY_SCORE: "Generalized PS (continuous T)",
34
+ BACKDOOR_ADJUSTMENT: "Backdoor Adjustment",
35
+ FRONTDOOR_ADJUSTMENT: "Frontdoor Adjustment",
36
+ }
37
+
38
+ # -------- Heuristic extractors from payload -------- #
39
+
40
+ def _get(d: Dict, path: List[str], default=None):
41
+ cur = d
42
+ for k in path:
43
+ if not isinstance(cur, dict) or k not in cur:
44
+ return default
45
+ cur = cur[k]
46
+ return cur
47
+
48
def extract_signals(p: Dict[str, Any]) -> Dict[str, Any]:
    """Distill a CAIS run payload into the signals the method selector uses.

    Returns a dict with keys: treatment, t_type, is_rct, has_temporal,
    rdd_ready, has_valid_instrument, has_covariates, frontdoor_ok,
    strong_overlap (tri-state: True / False / None = unknown).
    """

    def dig(node, keys, default=None):
        # Nested-dict lookup; mirrors module-level _get but keeps this
        # function self-contained.
        for key in keys:
            if not isinstance(node, dict) or key not in node:
                return default
            node = node[key]
        return node

    variables = dig(p, ["results", "variables"], {}) or dig(p, ["variables"], {}) or {}
    analysis = dig(p, ["results", "dataset_analysis"], {}) or dig(p, ["dataset_analysis"], {}) or {}

    treatment = variables.get("treatment_variable")
    treatment_type = variables.get("treatment_variable_type")  # "binary"/"continuous"
    randomized = bool(variables.get("is_rct", False))

    # Temporal / panel structure: explicit flag OR presence of time/group vars.
    temporal = (
        bool(analysis.get("temporal_structure_detected", False))
        or bool(variables.get("time_variable"))
        or bool(variables.get("group_variable"))
    )

    # RDD requires BOTH a running variable and a cutoff value.
    # (Some detectors raise 'discontinuities_detected'; deliberately ignored
    # here to stay conservative.)
    rdd_ready = (
        variables.get("running_variable") is not None
        and variables.get("cutoff_value") is not None
    )

    # An instrument only counts if it exists and is NOT the treatment itself.
    declared_instrument = variables.get("instrument_variable")
    potential = analysis.get("potential_instruments") or []
    instrument_ok = (
        declared_instrument is not None and declared_instrument != treatment
    ) or any(cand and cand != treatment for cand in potential)

    covariates_present = len(variables.get("covariates") or []) > 0

    # Frontdoor is only trusted when the payload asserts it explicitly.
    frontdoor = bool(dig(p, ["results", "dataset_analysis", "frontdoor_satisfied"], False))

    # Overlap: use it if explicitly reported, else leave unknown (None).
    overlap_info = analysis.get("overlap_assessment")
    overlap = overlap_info.get("strong_overlap") if isinstance(overlap_info, dict) else None

    return dict(
        treatment=treatment,
        t_type=treatment_type,
        is_rct=randomized,
        has_temporal=temporal,
        rdd_ready=rdd_ready,
        has_valid_instrument=instrument_ok,
        has_covariates=covariates_present,
        frontdoor_ok=frontdoor,
        strong_overlap=overlap,
    )
101
+
102
+ # -------- Candidate inference (green leaves) -------- #
103
+
104
def infer_candidate_methods(signals: Dict[str, Any]) -> Set[str]:
    """Return the set of all plausibly-applicable methods (rendered green)."""
    plausible: Set[str] = set()

    if signals["is_rct"]:
        # Randomized data: diff-in-means is always a valid analysis; add LR
        # when covariates allow adjustment, IV for encouragement designs.
        plausible.add(DIFF_IN_MEANS)
        if signals["has_covariates"]:
            plausible.add(LINEAR_REGRESSION)
        if signals["has_valid_instrument"]:
            plausible.add(INSTRUMENTAL_VARIABLE)
        # The observational part of the tree is irrelevant for RCTs.
        return plausible

    # Observational data: quasi-experimental designs first.
    if signals["has_temporal"]:
        plausible.add(DIFF_IN_DIFF)
    if signals["rdd_ready"]:
        plausible.add(REGRESSION_DISCONTINUITY)
    if signals["has_valid_instrument"]:
        plausible.add(INSTRUMENTAL_VARIABLE)
    if signals["frontdoor_ok"]:
        plausible.add(FRONTDOOR_ADJUSTMENT)

    # Continuous treatments get the generalized propensity score.
    if str(signals["t_type"]).lower() == "continuous":
        plausible.add(GENERALIZED_PROPENSITY_SCORE)

    # Covariate-adjustment family (needs observed covariates).
    if signals["has_covariates"]:
        overlap = signals["strong_overlap"]
        if overlap is True:
            plausible.add(PROPENSITY_SCORE_MATCHING)
        elif overlap is False:
            plausible.add(PROPENSITY_SCORE_WEIGHTING)
        else:
            # Overlap unknown: both PS variants stay on the table.
            plausible.update({PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING})
        plausible.add(BACKDOOR_ADJUSTMENT)

    return plausible
144
+
145
+ # -------- Compute the single realized path to the chosen leaf (for edge coloring) -------- #
146
+
147
def infer_decision_path(signals: Dict[str, Any], selected_method: Optional[str]) -> List[Tuple[str, str]]:
    """Trace the single root-to-leaf edge list the selector actually followed.

    ``selected_method`` is accepted for interface compatibility but the path
    is derived purely from ``signals``.
    """
    edges: List[Tuple[str, str]] = [("start", "is_rct")]

    # --- RCT branch ---
    if signals["is_rct"]:
        edges.append(("is_rct", "has_instr_rct"))
        if signals["has_valid_instrument"]:
            edges.append(("has_instr_rct", INSTRUMENTAL_VARIABLE))
        else:
            edges.append(("has_instr_rct", "has_cov_rct"))
            leaf = LINEAR_REGRESSION if signals["has_covariates"] else DIFF_IN_MEANS
            edges.append(("has_cov_rct", leaf))
        return edges

    # --- Observational branch: each check either terminates at a leaf or
    # falls through to the next decision node. ---
    edges.append(("is_rct", "has_temporal"))
    if signals["has_temporal"]:
        edges.append(("has_temporal", DIFF_IN_DIFF))
        return edges
    edges.append(("has_temporal", "has_rv"))

    if signals["rdd_ready"]:
        edges.append(("has_rv", REGRESSION_DISCONTINUITY))
        return edges
    edges.append(("has_rv", "has_instr"))

    if signals["has_valid_instrument"]:
        edges.append(("has_instr", INSTRUMENTAL_VARIABLE))
        return edges
    edges.append(("has_instr", "frontdoor"))

    if signals["frontdoor_ok"]:
        edges.append(("frontdoor", FRONTDOOR_ADJUSTMENT))
        return edges
    edges.append(("frontdoor", "t_cont"))

    if str(signals["t_type"]).lower() == "continuous":
        edges.append(("t_cont", GENERALIZED_PROPENSITY_SCORE))
        return edges
    edges.append(("t_cont", "has_cov"))

    if signals["has_covariates"]:
        edges.append(("has_cov", "overlap"))
        # If overlap known, pick the branch; else default to weighting.
        if signals["strong_overlap"] is True:
            edges.append(("overlap", PROPENSITY_SCORE_MATCHING))
        else:
            edges.append(("overlap", PROPENSITY_SCORE_WEIGHTING))
    else:
        edges.append(("has_cov", BACKDOOR_ADJUSTMENT))  # keep original topology
    return edges
206
+
207
+ # -------- Graph building -------- #
208
+
209
def build_graph(payload: Dict[str, Any]) -> Digraph:
    """Build the annotated Graphviz decision tree for one CAIS run payload.

    Plausible candidate leaves are filled green, the actually-selected method
    gets a thicker border, and the traversed decision path's edges are drawn
    in green (see legend node).
    """
    g = Digraph("CAISDecisionTree", format="svg")
    g.attr(rankdir="LR", nodesep="0.4", ranksep="0.35", fontsize="11")

    # Decision (diamond) nodes — fixed topology, independent of the payload.
    g.node("start", "Start", shape="circle")
    g.node("is_rct", "Is RCT?", shape="diamond")
    g.node("has_instr_rct", "Instrument available?", shape="diamond")
    g.node("has_cov_rct", "Covariates observed?", shape="diamond")
    g.node("has_temporal", "Temporal structure?", shape="diamond")
    g.node("has_rv", "Running var & cutoff?", shape="diamond")
    g.node("has_instr", "Instrument available?", shape="diamond")
    g.node("frontdoor", "Frontdoor criterion satisfied?", shape="diamond")
    g.node("has_cov", "Covariates observed?", shape="diamond")
    g.node("overlap", "Strong overlap?\n(overlap ≥ 0.1)", shape="diamond")
    g.node("t_cont", "Treatment continuous?", shape="diamond")

    # Leaves
    def leaf(name_const, fill=None, bold=False):
        # Emit one method leaf; fill marks a candidate, bold the chosen method.
        attrs = {"shape": "box", "style": "rounded"}
        if fill:
            attrs.update(style="rounded,filled", fillcolor=fill)
        if bold:
            attrs.update(penwidth="2")
        g.node(name_const, LABEL[name_const], **attrs)

    # Compute signals, candidates, path
    signals = extract_signals(payload)
    candidates = infer_candidate_methods(signals)

    # The method name can live at several depths depending on payload shape.
    selected_method_str = _get(payload, ["results", "results", "method_used"]) \
        or _get(payload, ["results", "method_used"]) \
        or _get(payload, ["method"])
    # Map the payload's string name onto the tree's method constants;
    # unknown/missing names yield None (no leaf gets a bold border).
    selected_method = {
        "linear_regression": LINEAR_REGRESSION,
        "diff_in_means": DIFF_IN_MEANS,
        "difference_in_differences": DIFF_IN_DIFF,
        "regression_discontinuity": REGRESSION_DISCONTINUITY,
        "instrumental_variable": INSTRUMENTAL_VARIABLE,
        "propensity_score_matching": PROPENSITY_SCORE_MATCHING,
        "propensity_score_weighting": PROPENSITY_SCORE_WEIGHTING,
        "generalized_propensity_score": GENERALIZED_PROPENSITY_SCORE,
        "backdoor_adjustment": BACKDOOR_ADJUSTMENT,
        "frontdoor_adjustment": FRONTDOOR_ADJUSTMENT,
    }.get(str(selected_method_str or "").lower())

    # Add leaves with coloring
    for m in [
        DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
        INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
        GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT
    ]:
        leaf(m,
             fill=("palegreen" if m in candidates else None),
             bold=(m == selected_method))

    # Edges with optional path highlighting
    path_edges = set(infer_decision_path(signals, selected_method))
    def e(u, v, label=None):
        # Edge helper: edges on the realized decision path are drawn green/thick.
        attrs = {}
        if (u, v) in path_edges:
            attrs.update(color="forestgreen", penwidth="2")
        g.edge(u, v, **({} if label is None else {"label": label}) | attrs)

    # Topology (unchanged)
    e("start", "is_rct")

    # RCT branch
    e("is_rct", "has_instr_rct", label="Yes")
    e("has_instr_rct", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr_rct", "has_cov_rct", label="No")
    e("has_cov_rct", LINEAR_REGRESSION, label="Yes")
    e("has_cov_rct", DIFF_IN_MEANS, label="No")

    # Observational branch
    e("is_rct", "has_temporal", label="No")
    e("has_temporal", DIFF_IN_DIFF, label="Yes")
    e("has_temporal", "has_rv", label="No")

    e("has_rv", REGRESSION_DISCONTINUITY, label="Yes")
    e("has_rv", "has_instr", label="No")

    e("has_instr", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr", "frontdoor", label="No")

    e("frontdoor", FRONTDOOR_ADJUSTMENT, label="Yes")
    e("frontdoor", "t_cont", label="No")

    e("t_cont", GENERALIZED_PROPENSITY_SCORE, label="Yes")
    e("t_cont", "has_cov", label="No")

    e("has_cov", "overlap", label="Yes")
    e("has_cov", BACKDOOR_ADJUSTMENT, label="No")

    e("overlap", PROPENSITY_SCORE_MATCHING, label="Yes")
    e("overlap", PROPENSITY_SCORE_WEIGHTING, label="No")

    # Optional legend
    g.node("legend", "Legend:\nGreen = plausible candidate(s)\nBold border = method used", shape="note")
    g.edge("legend", "start", style="dashed", arrowhead="none")

    return g
311
+
312
def render_from_json(payload: Dict[str, Any], out_stem: str = "artifacts/decision_tree"):
    """Render *payload*'s decision tree to ``<out_stem>.dot/.svg/.png``."""
    graph = build_graph(payload)
    # Keep a plain-text DOT copy alongside the rendered images.
    graph.save(filename=f"{out_stem}.dot")
    for fmt in ("svg", "png"):
        graph.format = fmt
        graph.render(filename=out_stem, cleanup=True)
318
+
319
def main():
    """CLI entry point: render the tree for a payload JSON file.

    Uses ``sys.argv[1]`` as the payload path when given (matching the module
    docstring's ``python visualise.py payload.json`` usage); falls back to
    the previous hard-coded ``sample_output.json`` for compatibility.
    """
    payload_path = sys.argv[1] if len(sys.argv) >= 2 else "sample_output.json"
    with open(payload_path, "r") as f:
        payload = json.load(f)
    render_from_json(payload)
326
+
327
+ if __name__ == "__main__":
328
+ main()