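"""
Symbolic Encoder Inspector: a Gradio Space that probes the bert-beatrix-2048
encoder. For each selected symbolic <role>, it reports the input token whose
contextual hidden state is most similar (cosine) to that role's embedding,
plus diagnostics about explicit <role> markers present in the prompt.
"""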
import json
import sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
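

# ------------------------------------------------------------------
# 0.  Download the checkpoint and patch config.json
# ------------------------------------------------------------------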
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_DIR = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=False,  # deprecated and ignored by recent huggingface_hub; harmless to pass
)

cfg_path = Path(LOCAL_DIR) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

# Strip any "<repo>--" prefix from auto_map entries so they point at the
# local module paths instead of the remote repo.
auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if "--" in v:
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True

if patched:
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️ Patched config.json → auto_map fixed.")
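

# ------------------------------------------------------------------
# 1.  Load the checkpoint (model + tokenizer)
# ------------------------------------------------------------------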
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
full_model = full_model.eval().cuda()

# Sub-modules needed to run the encoder by hand in encode_and_trace().
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
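

# ------------------------------------------------------------------
# 2.  Symbolic role tokens
# ------------------------------------------------------------------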
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

ROLE_ID = {tok: tokenizer.convert_tokens_to_ids(tok) for tok in SYMBOLIC_ROLES}

# Abort early if any role token is unknown to the tokenizer.
missing = [tok for tok, tid in ROLE_ID.items() if tid == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")
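

# ------------------------------------------------------------------
# 3.  Encoder probe
# ------------------------------------------------------------------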
@spaces.GPU  # allocate a GPU for the duration of each call (ZeroGPU Spaces)
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    For each *selected* role:
      • find the contextual token whose hidden state is most similar to that
        role's own embedding (cosine similarity)
      • return "role → token (sim)", using regular tokens even when the prompt
        contains no <role> markers at all.
    Also keeps the older diagnostics (explicit role tokens and their mean norm).
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask  # (1, S)

        # Manual forward pass: embeddings → LayerNorm → dropout → encoder.
        x = emb_drop(emb_ln(embeddings(ids)))
        msk = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        h = encoder(x, attention_mask=msk).squeeze(0)  # (S, H)

        h_norm = F.normalize(h, dim=-1)

        # Cosine similarity between every contextual hidden state and each
        # selected role's static embedding vector.
        matches = []
        for role in selected_roles:
            role_vec = embeddings.word_embeddings.weight[ROLE_ID[role]].to(h.device)
            role_vec = F.normalize(role_vec, dim=-1)

            sims = h_norm @ role_vec  # (S,)
            best_idx = int(sims.argmax().item())
            best_sim = float(sims[best_idx])

            match_tok = tokenizer.convert_ids_to_tokens(int(ids[0, best_idx]))
            matches.append(f"{role} → {match_tok} ({best_sim:.2f})")

        match_str = ", ".join(matches) if matches else "(no roles selected)"

        # Older diagnostics: explicit <role> markers typed into the prompt.
        tokens = tokenizer.convert_ids_to_tokens(ids[0].tolist())
        present = [tok for tok in tokens if tok in selected_roles]
        present_str = ", ".join(present) or "(none)"
        count = len(present)

        # Mean hidden-state norm over all symbolic role tokens in the prompt.
        if count:
            role_ids = set(ROLE_ID.values())
            exp_mask = torch.tensor(
                [tid in role_ids for tid in ids[0].tolist()], device=h.device
            )
            norm_val = f"{h[exp_mask].mean(0).norm().item():.4f}"
        else:
            norm_val = "0.0000"

        return present_str, match_str, norm_val, count
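

# ------------------------------------------------------------------
# 4.  Gradio interface
# ------------------------------------------------------------------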
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector\n"
            "Select one or more symbolic *roles* on the left. "
            "The tool shows which regular tokens (if any) the model thinks "
            "best fit each role — even when your text doesn’t contain the "
            "explicit `<role>` marker."
        )

        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input text",
                    lines=3,
                    placeholder="Example: A small child in bright red boots jumps over a muddy puddle…",
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles to probe",
                )
                btn = gr.Button("Run encoder probe")
            with gr.Column():
                out_present = gr.Textbox(label="Explicit role tokens found")
                out_match = gr.Textbox(label="Role → Best-Match Token (cos θ)")
                out_norm = gr.Textbox(label="Mean hidden-state norm (explicit)")
                out_count = gr.Textbox(label="# explicit role tokens")

        btn.click(
            encode_and_trace,
            inputs=[txt, roles],
            outputs=[out_present, out_match, out_norm, out_count],
        )

    return demo
if __name__ == "__main__":
    build_interface().launch()