import json
import sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
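
# Download the checkpoint into a local directory (no symlinks) so its
# files can be inspected and patched on disk before loading.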
LOCAL_CKPT = snapshot_download(
    repo_id="AbstractPhil/bert-beatrix-2048",
    revision="main",
    local_dir="bert-beatrix-2048",
    local_dir_use_symlinks=False,
)
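
# Rewrite "repo--module.Class" style auto_map entries in config.json to plain
# module paths, so the custom code is resolved from the local checkpoint.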
cfg_path = Path(LOCAL_CKPT) / "config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
changed = False
for k, v in auto_map.items():
    if "--" in v:
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        changed = True

if changed:
    cfg["auto_map"] = auto_map
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    print("🔧 Patched auto_map → now points to local modules only")
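
# Purge any previously imported dynamic modules for this repo so the
# patched config is picked up on the next import.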
for name in list(sys.modules):
    if name.startswith("transformers_modules.AbstractPhil.bert-beatrix-2048"):
        del sys.modules[name]
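
# Load the handler, full model, and tokenizer from the local checkpoint.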
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()
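
# Sub-modules used to run the encoder forward pass step by step.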
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
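
# Symbolic role tokens the checkpoint's tokenizer is expected to know.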
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
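
# Fail fast if any role token is missing from the tokenizer vocabulary.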
missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    raise RuntimeError(f"Tokenizer is missing special tokens: {missing}")
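

# Encode the input on the GPU and report statistics for the selected
# symbolic role tokens found in it.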
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask

        # Manual forward pass: embeddings -> LayerNorm -> dropout -> encoder.
        x = emb_drop(emb_ln(embeddings(ids)))
        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        enc = encoder(x, attention_mask=ext_mask)

        # Mark the positions whose token id belongs to a selected role.
        want = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        token_ids = ids[0].tolist()
        keep = torch.tensor([tid in want for tid in token_ids], device=enc.device)

        found = [tokenizer.convert_ids_to_tokens([tid])[0] for tid in token_ids if tid in want]
        if keep.any():
            vec = enc[0][keep].mean(0)
            norm = f"{vec.norm().item():.4f}"
        else:
            norm = "0.0000"

        # One value per output component: tokens found, mean norm, token count.
        return ", ".join(found) or "(none)", norm, str(int(keep.sum().item()))
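

# Gradio UI: text input plus role checkboxes on the left, traced results on the right.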
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
                chk = gr.CheckboxGroup(choices=SYMBOLIC_ROLES, label="Trace these roles")
                btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_cnt = gr.Textbox(label="Token Count")
        btn.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
    return demo


if __name__ == "__main__":
    build_interface().launch()