AbstractPhil committed
Commit 5e20a2a · verified · 1 Parent(s): 0a14990

Update app.py

Files changed (1)
  1. app.py +112 -88
app.py CHANGED
@@ -1,6 +1,7 @@
-# app.py – encoder-only + masking accuracy demo for bert-beatrix-2048
-# -----------------------------------------------------------------
-# launch: python app.py   (UI at http://localhost:7860)
 
 import json, re, sys
 from pathlib import Path, PurePosixPath
@@ -9,37 +10,48 @@ import gradio as gr
 import spaces
 import torch
 from huggingface_hub import snapshot_download
 from bert_handler import create_handler_from_checkpoint
 
 # ------------------------------------------------------------------
-# 0. download repo + patch auto_map --------------------------------
-REPO_ID = "AbstractPhil/bert-beatrix-2048"
-LOCAL_CK = "bert-beatrix-2048"
-snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_CK, local_dir_use_symlinks=False)
-
-cfg_p = Path(LOCAL_CK) / "config.json"
-with cfg_p.open() as f:
-    cfg = json.load(f)
-for k, v in cfg.get("auto_map", {}).items():
     if "--" in v:
-        cfg["auto_map"][k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
-with cfg_p.open("w") as f:
-    json.dump(cfg, f, indent=2)
 
 # ------------------------------------------------------------------
-# 1. load model / tokenizer ---------------------------------------
-handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CK)
 full_model = full_model.eval().cuda()
 
-encoder = full_model.bert.encoder
-embeddings = full_model.bert.embeddings
-emb_ln = full_model.bert.emb_ln
-emb_drop = full_model.bert.emb_drop
 
-MASK = tokenizer.mask_token or "[MASK]"
 
 # ------------------------------------------------------------------
-# 2. symbolic role list -------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
     "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
@@ -48,96 +60,108 @@ SYMBOLIC_ROLES = [
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
-miss = [t for t in SYMBOLIC_ROLES
-        if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
-if miss:
-    sys.exit(f"❌ Tokenizer missing {miss}")
 
 # ------------------------------------------------------------------
-# 3. inference util ----------------------------------------------
 @spaces.GPU
-def encode_and_trace(text: str, selected_roles: list[str]):
-    # ----- 3-A. build masked version & encode original --------------
-    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
 
-    # tokenised “plain” text
-    plain = tokenizer(text, return_tensors="pt").to("cuda")
-    ids_plain = plain.input_ids
 
-    # make masked string (regex to avoid partial hits)
-    masked_txt = text
-    for tok in selected_roles:
-        masked_txt = re.sub(re.escape(tok), MASK, masked_txt)
 
-    masked = tokenizer(masked_txt, return_tensors="pt").to("cuda")
-    ids_masked = masked.input_ids
 
-    # ----- 3-B. run model on masked text ----------------------------
-    with torch.no_grad():
-        logits = full_model(**masked).logits[0]   # (S, V)
-    preds = logits.argmax(-1)                     # (S,)
-
-    # ----- 3-C. gather stats per masked role ------------------------
-    found_tokens, correct = [], 0
-    role_flags = []
-    for i, (orig_id, pred_id) in enumerate(zip(ids_plain[0], preds)):
-        if orig_id.item() in sel_ids and ids_masked[0, i].item() == tokenizer.mask_token_id:
-            found_tokens.append(tokenizer.convert_ids_to_tokens([orig_id])[0])
-            correct += int(orig_id.item() == pred_id.item())
-            role_flags.append(i)
-
-    total = len(role_flags)
-    acc = correct / total if total else 0.0
-
-    # ----- 3-D. encoder rep pooling for *all* selected roles --------
-    with torch.no_grad():
-        # embeddings -> normed reps
-        x = emb_drop(emb_ln(embeddings(ids_plain)))
-        attn = full_model.bert.get_extended_attention_mask(
-            plain.attention_mask, x.shape[:-1]
-        )
-        enc = encoder(x, attention_mask=attn)     # (1,S,H)
-        mask_vec = torch.tensor(
-            [tid in sel_ids for tid in ids_plain[0].tolist()], device=enc.device
-        )
-        if mask_vec.any():
-            pooled = enc[0][mask_vec].mean(0)
-            norm = f"{pooled.norm().item():.4f}"
         else:
-            norm = "0.0000"
 
-    tokens_str = ", ".join(found_tokens) or "(none)"
-    return tokens_str, norm, f"{acc*100:.1f}%"
 
 # ------------------------------------------------------------------
-# 4. gradio UI ----------------------------------------------------
-def app():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
-            "## 🧠 Symbolic Encoder Inspector \n"
-            "1. Model side: we *mask* every chosen role token, run the LM, and report how often it recovers the original. \n"
-            "2. Encoder side: we also pool hidden-state vectors for those roles and give their mean L2-norm."
         )
         with gr.Row():
             with gr.Column():
                 txt = gr.Textbox(
                     label="Input with Symbolic Tokens",
-                    lines=3,
                     placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                 )
                 roles = gr.CheckboxGroup(
                     choices=SYMBOLIC_ROLES,
-                    value=SYMBOLIC_ROLES,   # <- all pre-selected
-                    label="Roles to mask & trace",
                 )
-                run = gr.Button("Run")
             with gr.Column():
-                o_tok = gr.Textbox(label="Masked-role tokens found")
-                o_norm = gr.Textbox(label="Mean hidden-state L2-norm")
-                o_acc = gr.Textbox(label="Recovery accuracy")
 
-        run.click(encode_and_trace, [txt, roles], [o_tok, o_norm, o_acc])
     return demo
 
 if __name__ == "__main__":
-    app().launch()
 
+# app.py – encoder-only demo for bert-beatrix-2048
+# ------------------------------------------------------------------
+# launch: python app.py
+# ------------------------------------------------------------------
 
 import json, re, sys
 from pathlib import Path, PurePosixPath
 
 import spaces
 import torch
 from huggingface_hub import snapshot_download
+
 from bert_handler import create_handler_from_checkpoint
 
+
 # ------------------------------------------------------------------
+# 0. Download & patch config.json --------------------------------
+# ------------------------------------------------------------------
+REPO_ID = "AbstractPhil/bert-beatrix-2048"
+LOCAL_DIR = "bert-beatrix-2048"
+
+snapshot_download(REPO_ID, revision="main",
+                  local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
+
+cfg_path = Path(LOCAL_DIR) / "config.json"
+cfg = json.loads(cfg_path.read_text())
+
+auto_map, changed = cfg.get("auto_map", {}), False
+for k, v in auto_map.items():
     if "--" in v:
+        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
+        changed = True
+if changed:
+    cfg["auto_map"] = auto_map
+    cfg_path.write_text(json.dumps(cfg, indent=2))
+    print("🛠️ Patched config.json → auto_map now points at local modules")
 
+
+# ------------------------------------------------------------------
+# 1. Model / tokenizer -------------------------------------------
 # ------------------------------------------------------------------
+handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
 full_model = full_model.eval().cuda()
 
+encoder = full_model.bert.encoder
+embeddings = full_model.bert.embeddings
+emb_ln = full_model.bert.emb_ln
+emb_drop = full_model.bert.emb_drop
 
 
 # ------------------------------------------------------------------
+# 2. Symbolic token set ------------------------------------------
+# ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
     "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
 
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
+
+missing = [t for t in SYMBOLIC_ROLES
+           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
+if missing:
+    sys.exit(f"❌ Tokenizer is missing {missing}")
+
 
 # ------------------------------------------------------------------
+# 3. Encoder + *mask-inference* util ------------------------------
+# ------------------------------------------------------------------
+MASK = tokenizer.mask_token or "[MASK]"
+
 @spaces.GPU
+def encode_and_trace(text: str, _ignored):   # all roles auto-selected
+    """
+    1. run encoder pass → cosine report (as before)
+    2. mask **every** symbolic token one-at-a-time
+       and ask the full model to predict it back.
+       Accuracy over those positions is returned.
+    """
+    if not text.strip():
+        return "(empty)", "0.0000", 0, "0 / 0 (0.0%)"
 
+    with torch.no_grad():
+        # -------- ENCODER PROBE (unchanged) ------------------
+        batch = tokenizer(text, return_tensors="pt").to("cuda")
+        ids, mask = batch.input_ids, batch.attention_mask
 
+        x = emb_drop(emb_ln(embeddings(ids)))
+        am = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
+        enc = encoder(x, attention_mask=am)   # (1,S,H)
 
+        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in SYMBOLIC_ROLES}
+        flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
+                             device=enc.device)
 
+        found = [tokenizer.convert_ids_to_tokens([tid])[0]
+                 for tid in ids[0].tolist() if tid in sel_ids]
+        tokens_str = ", ".join(found) if found else "(none)"
+
+        if flags.any():
+            vec = enc[0][flags].mean(0)
+            norm = f"{vec.norm().item():.4f}"
         else:
+            norm = "0.0000"
+
+        # -------- MASK-AND-PREDICT ACCURACY ------------------
+        correct, total = 0, 0
+        for pos, tid in enumerate(ids[0].tolist()):
+            if tid in sel_ids:   # symbolic
+                total += 1
+                masked_ids = ids.clone()
+                masked_ids[0, pos] = tokenizer.mask_token_id
+                out = full_model(input_ids=masked_ids,
+                                 attention_mask=mask).logits   # (1,S,V)
+                pred = out[0, pos].argmax(-1).item()
+                if pred == tid:
+                    correct += 1
+
+    acc_str = f"{correct} / {total} ({(correct/total*100 if total else 0):.1f}%)"
+
+    return tokens_str, norm, len(found), acc_str
 
 
 # ------------------------------------------------------------------
+# 4. Gradio UI ----------------------------------------------------
+# ------------------------------------------------------------------
+def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
+            "## 🧠 Symbolic Encoder Inspector\n"
+            "Enter text containing the `<role>` tokens.\n"
+            "Cosine probe **and** real mask-prediction accuracy are shown."
         )
+
         with gr.Row():
             with gr.Column():
                 txt = gr.Textbox(
                     label="Input with Symbolic Tokens",
                     placeholder="Example: A <subject> wearing <upper_body_clothing> …",
+                    lines=3,
                 )
+                # checkbox group kept (pre-checked, disabled)
                 roles = gr.CheckboxGroup(
                     choices=SYMBOLIC_ROLES,
+                    label="(all roles auto-selected)",
+                    value=SYMBOLIC_ROLES,
+                    interactive=False,
                 )
+                btn = gr.Button("Run probe + MLM check")
             with gr.Column():
+                out_tok = gr.Textbox(label="Symbolic Tokens Found")
+                out_norm = gr.Textbox(label="Vector-norm (mean)")
+                out_cnt = gr.Textbox(label="Token Count")
+                out_acc = gr.Textbox(label="Mask-prediction accuracy")
+
+        btn.click(encode_and_trace,
+                  inputs=[txt, roles],
+                  outputs=[out_tok, out_norm, out_cnt, out_acc])
 
     return demo
 
+
 if __name__ == "__main__":
+    build_interface().launch()
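Note: the mask-and-predict accuracy introduced in this commit is the standard one-position-at-a-time MLM probe: clone the input ids, replace a single position with the mask token, and check whether the model's top-1 prediction restores the original id. The sketch below shows that loop in isolation; it is illustrative only, using a stock "bert-base-uncased" masked-LM from transformers as a stand-in for the project's bert_handler checkpoint (custom <role> tokens and CUDA placement omitted).

# Illustrative sketch (not part of the commit): same mask-and-predict loop,
# run against a stock Hugging Face masked-LM instead of bert-beatrix-2048.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
mlm = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").eval()

def mask_and_predict_accuracy(text: str, target_ids: set[int]) -> tuple[int, int]:
    """Mask each target token one at a time; count exact top-1 recoveries."""
    batch = tok(text, return_tensors="pt")
    ids = batch.input_ids
    correct = total = 0
    with torch.no_grad():
        for pos, tid in enumerate(ids[0].tolist()):
            if tid not in target_ids:
                continue
            total += 1
            masked = ids.clone()
            masked[0, pos] = tok.mask_token_id        # hide only this position
            logits = mlm(input_ids=masked,
                         attention_mask=batch.attention_mask).logits
            correct += int(logits[0, pos].argmax(-1).item() == tid)
    return correct, total

# Example: how often does the LM recover "paris" when it is masked?
print(mask_and_predict_accuracy("paris is the capital of france.",
                                {tok.convert_tokens_to_ids("paris")}))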