AbstractPhil commited on
Commit
8a2e372
·
verified ·
1 Parent(s): 323ce30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -40
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py
2
 
3
  from bert_handler import create_handler_from_checkpoint
4
  import torch
@@ -6,40 +6,48 @@ import gradio as gr
6
  import re
7
  from pathlib import Path
8
  import spaces
 
9
 
10
- @spaces.GPU
11
- def mask_and_predict(text: str, selected_roles: list[str]):
12
- MASK_TOKEN = tokenizer.mask_token or "[MASK]"
13
- results = []
14
- masked_text = text
15
- token_ids = tokenizer.encode(text, return_tensors="pt").cuda()
 
 
 
16
 
17
- for role in selected_roles:
18
- role_pattern = re.escape(role)
19
- masked_text = re.sub(role_pattern, MASK_TOKEN, masked_text)
 
 
20
 
21
- masked_ids = tokenizer.encode(masked_text, return_tensors="pt").cuda()
 
22
  with torch.no_grad():
23
- outputs = model(input_ids=masked_ids)
24
- logits = outputs.logits[0]
25
- predictions = torch.argmax(logits, dim=-1)
26
 
27
- original_ids = tokenizer.convert_ids_to_tokens(token_ids[0])
28
- predicted_ids = tokenizer.convert_ids_to_tokens(predictions)
29
- masked_ids_tokens = tokenizer.convert_ids_to_tokens(masked_ids[0])
 
 
30
 
31
- for i, token in enumerate(masked_ids_tokens):
32
- if token == MASK_TOKEN:
33
- results.append({
34
- "Position": i,
35
- "Masked Token": MASK_TOKEN,
36
- "Predicted": predicted_ids[i],
37
- "Original": original_ids[i] if i < len(original_ids) else "",
38
- "Match": "✅" if predicted_ids[i] == original_ids[i] else "❌"
39
- })
40
 
41
- accuracy = sum(1 for r in results if r["Match"] == "✅") / max(len(results), 1)
42
- return results, f"Accuracy: {accuracy:.1%}"
 
 
 
 
 
 
43
 
44
  symbolic_roles = [
45
  "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -50,27 +58,23 @@ symbolic_roles = [
50
  "<fabric>", "<jewelry>"
51
  ]
52
 
53
- # Load from official hosted checkpoint
54
- checkpoint_path = "./bert-beatrix-2048"
55
- handler, model, tokenizer = create_handler_from_checkpoint(checkpoint_path)
56
- model = model.eval().cuda()
57
-
58
  def build_interface():
59
  with gr.Blocks() as demo:
60
- gr.Markdown("## 🔎 Symbolic BERT Inference Test")
61
  with gr.Row():
62
  with gr.Column():
63
- input_text = gr.Textbox(label="Symbolic Input Caption", lines=3)
64
  selected_roles = gr.CheckboxGroup(
65
  choices=symbolic_roles,
66
- label="Mask these symbolic roles"
67
  )
68
- run_btn = gr.Button("Run Mask Inference")
69
  with gr.Column():
70
- output_table = gr.Dataframe(headers=["Position", "Masked Token", "Predicted", "Original", "Match"], interactive=False)
71
- accuracy_score = gr.Textbox(label="Mask Accuracy")
 
72
 
73
- run_btn.click(fn=mask_and_predict, inputs=[input_text, selected_roles], outputs=[output_table, accuracy_score])
74
 
75
  return demo
76
 
 
1
+ # Updating the app to use only the encoder from the model, ensuring symbolic support
2
 
3
  from bert_handler import create_handler_from_checkpoint
4
  import torch
 
6
  import re
7
  from pathlib import Path
8
  import spaces
9
+ from huggingface_hub import snapshot_download
10
 
11
+ # Load checkpoint using BERTHandler (loads tokenizer and full model)
12
+ checkpoint_path = snapshot_download(
13
+ repo_id="AbstractPhil/bert-beatrix-2048",
14
+ revision="main",
15
+ local_dir="bert-beatrix-2048",
16
+ local_dir_use_symlinks=False
17
+ )
18
+ handler, model, tokenizer = create_handler_from_checkpoint(checkpoint_path)
19
+ model = model.eval().cuda()
20
 
21
+ # Extract encoder only (NomicBertModel -> encoder)
22
+ encoder = model.bert.encoder
23
+ embeddings = model.bert.embeddings
24
+ emb_ln = model.bert.emb_ln
25
+ emb_drop = model.bert.emb_drop
26
 
27
+ @spaces.GPU
28
+ def encode_and_predict(text: str, selected_roles: list[str]):
29
  with torch.no_grad():
30
+ inputs = tokenizer(text, return_tensors="pt").to("cuda")
31
+ input_ids = inputs.input_ids
32
+ attention_mask = inputs.attention_mask
33
 
34
+ # Run embedding + encoder pipeline
35
+ x = embeddings(input_ids)
36
+ x = emb_ln(x)
37
+ x = emb_drop(x)
38
+ encoded = encoder(x, attention_mask=attention_mask.bool())
39
 
40
+ symbolic_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in selected_roles]
41
+ symbolic_mask = torch.isin(input_ids, torch.tensor(symbolic_ids, device=input_ids.device))
 
 
 
 
 
 
 
42
 
43
+ masked_tokens = [tokenizer.convert_ids_to_tokens([tid])[0] for tid in input_ids[0] if tid in symbolic_ids]
44
+ role_reprs = encoded[symbolic_mask].mean(dim=0) if symbolic_mask.any() else torch.zeros_like(encoded[0, 0])
45
+
46
+ return {
47
+ "Symbolic Tokens": masked_tokens,
48
+ "Embedding Norm": f"{role_reprs.norm().item():.4f}",
49
+ "Symbolic Token Count": symbolic_mask.sum().item(),
50
+ }
51
 
52
  symbolic_roles = [
53
  "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
 
58
  "<fabric>", "<jewelry>"
59
  ]
60
 
 
 
 
 
 
61
  def build_interface():
62
  with gr.Blocks() as demo:
63
+ gr.Markdown("## 🧠 Symbolic Encoder Inspector")
64
  with gr.Row():
65
  with gr.Column():
66
+ input_text = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
67
  selected_roles = gr.CheckboxGroup(
68
  choices=symbolic_roles,
69
+ label="Which symbolic tokens should be traced?"
70
  )
71
+ run_btn = gr.Button("Encode & Trace")
72
  with gr.Column():
73
+ symbolic_tokens = gr.Textbox(label="Symbolic Tokens Found")
74
+ embedding_norm = gr.Textbox(label="Mean Norm of Symbolic Embeddings")
75
+ token_count = gr.Textbox(label="Count of Symbolic Tokens")
76
 
77
+ run_btn.click(fn=encode_and_predict, inputs=[input_text, selected_roles], outputs=[symbolic_tokens, embedding_norm, token_count])
78
 
79
  return demo
80