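"""Symbolic Encoder Inspector.

Gradio Space that loads the AbstractPhil/bert-beatrix-2048 checkpoint, aligns
each token of a prompt with its closest symbolic role embedding (<subject>,
<pose>, <lighting>, ...), and probes which half of the tokens lets the MLM
head best reconstruct the masked remainder.
"""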
import json
import sys
from pathlib import Path, PurePosixPath

import torch
import torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint

# Download the checkpoint into a local folder.
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

# Patch config.json: strip any "<repo>--" prefix from auto_map entries so the
# dynamic modules are imported from the local checkpoint directory.
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
for key, value in auto_map.items():
    if "--" in value:
        auto_map[key] = PurePosixPath(value.split("--", 1)[1]).as_posix()
cfg["auto_map"] = auto_map

with cfg_path.open("w") as f:
    json.dump(cfg, f, indent=2)
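# Illustrative only: a hypothetical auto_map entry before and after the patch.
# The module and class names are placeholders, not the checkpoint's actual ones.
#   "AutoModel": "AbstractPhil/bert-beatrix-2048--modeling_foo.FooModel"
#   "AutoModel": "modeling_foo.FooModel"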

# Load the full model and tokenizer, then grab the sub-modules needed for a
# manual forward pass.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
mlm_head = full_model.cls

# Symbolic role tokens expected to exist in the tokenizer's vocabulary.
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# Abort early if the tokenizer does not know every role token.
if any(
    tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id
    for t in SYMBOLIC_ROLES
):
    sys.exit("❌ tokenizer missing special tokens")
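# The role tokens are assumed to have been registered when the checkpoint was
# built (e.g. tokenizer.add_tokens(SYMBOLIC_ROLES) plus
# model.resize_token_embeddings(len(tokenizer))); this script only verifies
# that they resolve to known ids.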

MASK = tokenizer.mask_token  # mask token string; the masking loop below uses MASK_ID


def cosine(a, b):
    """Cosine similarity along the last dimension."""
    return F.cosine_similarity(a, b, dim=-1)


def pool_accuracy(ids, logits, pool_mask):
    """
    ids       : (S,)   gold token ids
    logits    : (S, V) MLM logits
    pool_mask : (S,)   bool, True at the positions to score
    Returns prediction accuracy over those positions (0.0 if the mask is empty).
    """
    idx = pool_mask.nonzero(as_tuple=False).flatten()
    if idx.numel() == 0:
        return 0.0
    preds = logits.argmax(-1)[idx]
    gold = ids[idx]
    return (preds == gold).float().mean().item()
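# Worked example for pool_accuracy (hypothetical tensors, not model output):
#   ids    = torch.tensor([5, 7, 9, 11])
#   logits = F.one_hot(torch.tensor([5, 7, 9, 2]), num_classes=13).float()
#   mask   = torch.tensor([False, True, False, True])
#   pool_accuracy(ids, logits, mask)  # -> 0.5 (position 1 correct, position 3 wrong)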


@spaces.GPU
def encode_and_trace(text, selected_roles):
    """Align every token of `text` with its closest symbolic role, then greedily
    search the high- and low-similarity halves for a token pool from which the
    MLM head can best reconstruct the masked remainder."""
    if not selected_roles:
        selected_roles = SYMBOLIC_ROLES
    sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
    sel_ids_tensor = torch.tensor(sel_ids, device="cuda")

    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, attn = batch.input_ids, batch.attention_mask
    S = ids.shape[1]

    def encode(input_ids, attn_mask):
        """Embeddings -> LayerNorm -> dropout -> encoder, run manually."""
        x = embeddings(input_ids)
        if emb_ln is not None:
            x = emb_ln(x)
        if emb_drop is not None:
            x = emb_drop(x)
        ext = full_model.bert.get_extended_attention_mask(attn_mask, x.shape[:-1])
        return encoder(x, attention_mask=ext)[0]

    encoded = encode(ids, attn)  # (1, S, H)

    # Cosine similarity of every token embedding against every selected role
    # embedding, keeping each token's best-matching role.
    symbolic_embeds = embeddings.word_embeddings(sel_ids_tensor)         # (R, H)
    sim = cosine(encoded[0].unsqueeze(1), symbolic_embeds.unsqueeze(0))  # (S, R)
    maxcos, argrole = sim.max(-1)
    top_roles = [selected_roles[i] for i in argrole.tolist()]

    # Token positions split into a high-similarity half and a low-similarity half.
    sort_idx = maxcos.argsort(descending=True)
    hi_idx = sort_idx[: S // 2]
    lo_idx = sort_idx[S // 2 :]

    MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")

    def evaluate_pool(idx_order, ids):
        """Greedily grow a pool of visible token positions, two at a time, in the
        given order; every position outside the pool is replaced by [MASK] and the
        MLM head must reconstruct it. Returns the pool with the best reconstruction
        accuracy, stopping early once accuracy reaches 0.5."""
        best_pool, best_acc = [], 0.0
        ptr = 0
        while ptr < len(idx_order):
            cand = idx_order[ptr:ptr + 2]
            pool = best_pool + cand.tolist()
            ptr += 2

            # Keep the pool positions, mask everything else.
            mask_flags = torch.zeros_like(ids, dtype=torch.bool)
            mask_flags[0, pool] = True
            masked_input = torch.where(mask_flags, ids, torch.full_like(ids, MASK_ID))

            encoded_m = encode(masked_input, attn)
            logits = mlm_head(encoded_m)[0]
            preds = logits.argmax(-1)

            masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
            if masked_positions.numel() == 0:
                continue

            gold = ids[0][masked_positions]
            acc = (preds[masked_positions] == gold).float().mean().item()

            if acc > best_acc:
                best_pool, best_acc = pool, acc
            if acc >= 0.5:
                break

        return best_pool, best_acc

    pool_hi, acc_hi = evaluate_pool(hi_idx, ids)
    pool_lo, acc_lo = evaluate_pool(lo_idx, ids)

    # Per-token report: token -> best-matching role and its cosine score.
    decoded_tokens = tokenizer.convert_ids_to_tokens(ids[0])
    role_trace = [
        f"{tok:<15} → {role} cos={score:.4f}"
        for tok, role, score in zip(decoded_tokens, top_roles, maxcos.tolist())
    ]

    res_json = {
        "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
        "High accuracy": f"{acc_hi:.3f}",
        "Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
        "Low accuracy": f"{acc_lo:.3f}",
        "Token–Symbolic Role Alignment": role_trace,
    }

    return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)


def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")

        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Prompt", lines=3)
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles",
                    value=SYMBOLIC_ROLES,
                )
                btn = gr.Button("Run")
            with gr.Column():
                out_json = gr.Textbox(label="Result JSON")
                out_max = gr.Textbox(label="Max cos")
                out_cnt = gr.Textbox(label="# roles")

        btn.click(encode_and_trace, [txt, roles], [out_json, out_max, out_cnt])
    return demo


if __name__ == "__main__":
    build_interface().launch()
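# Minimal non-UI sketch (assumes a CUDA device is available; the prompt string
# below is only an example):
#   result_json, max_cos, n_roles = encode_and_trace(
#       "a knight in shining armor stands on a cliff", SYMBOLIC_ROLES
#   )
#   print(result_json)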