|
|
|
|
|
|
|
import json
|
from pathlib import Path, PurePosixPath |
|
|
|
import torch, torch.nn.functional as F |
|
import gradio as gr |
|
import spaces |
|
from huggingface_hub import snapshot_download |
|
|
|
from bert_handler import create_handler_from_checkpoint |
|
|
|
|
|
|
|
|
|
REPO_ID = "AbstractPhil/bert-beatrix-2048" |
|
LOCAL_CKPT = "bert-beatrix-2048" |
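# Download the checkpoint snapshot next to the app so its files can be modified before loading.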
|
|
|
snapshot_download( |
|
repo_id=REPO_ID, |
|
revision="main", |
|
local_dir=LOCAL_CKPT, |
|
local_dir_use_symlinks=False, |
|
) |
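# local_dir_use_symlinks=False yields real file copies, so config.json below can be edited in place.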
|
|
|
|
|
# Patch auto_map: entries may be namespaced as "<repo>--module.Class"; strip the
# repo prefix so the custom model code resolves from the local checkpoint folder.
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

amap = cfg.get("auto_map", {})
for k, v in amap.items():
    if "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
cfg["auto_map"] = amap

with cfg_path.open("w") as f:
    json.dump(cfg, f, indent=2)
|
|
|
|
|
|
|
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT) |
|
full_model = full_model.eval().cuda() |
|
|
|
|
|
|
|
SYMBOLIC_ROLES = [ |
|
"<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>", |
|
"<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>", |
|
"<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>", |
|
"<texture>", "<pattern>", "<grid>", "<zone>", "<offset>", |
|
"<object_left>", "<object_right>", "<relation>", "<intent>", "<style>", |
|
"<fabric>", "<jewelry>", |
|
] |
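
# Resolve each symbolic role to its vocabulary id; roles that map to [UNK]
# are not in this checkpoint's vocabulary and are skipped.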
|
|
|
|
|
missing_tokens = [] |
|
symbolic_token_ids = {} |
|
for token in SYMBOLIC_ROLES: |
|
token_id = tokenizer.convert_tokens_to_ids(token) |
|
if token_id == tokenizer.unk_token_id: |
|
missing_tokens.append(token) |
|
else: |
|
symbolic_token_ids[token] = token_id |
|
|
|
if missing_tokens: |
|
print(f"β οΈ Missing symbolic tokens: {missing_tokens}") |
|
print("Available tokens will be used for classification") |
|
|
|
MASK = tokenizer.mask_token |
|
MASK_ID = tokenizer.mask_token_id |
|
|
|
print(f"β
Loaded {len(symbolic_token_ids)} symbolic tokens") |
|
|
|
|
|
|
|
|
|
|
|
def get_symbolic_predictions(input_ids, attention_mask, mask_positions, selected_roles): |
|
""" |
|
Proper MLM-based prediction for symbolic tokens at masked positions |
|
|
|
Args: |
|
input_ids: (B, S) token IDs with [MASK] at positions to classify |
|
attention_mask: (B, S) attention mask |
|
mask_positions: list of positions that are masked |
|
selected_roles: list of symbolic role tokens to consider |
|
|
|
Returns: |
|
        list of dicts, one per masked position, each holding that position and its ranked symbolic-token predictions
|
""" |
|
|
|
with torch.no_grad(): |
|
outputs = full_model(input_ids=input_ids, attention_mask=attention_mask) |
|
logits = outputs.logits |
|
|
|
|
|
selected_token_ids = [symbolic_token_ids[role] for role in selected_roles |
|
if role in symbolic_token_ids] |
|
|
|
    if not selected_token_ids:
        return []
|
|
|
results = [] |
|
|
|
for pos in mask_positions: |
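        # Batch size is 1 here: take this position's vocabulary logits, keep only
        # the selected symbolic tokens, and softmax over that reduced set.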
|
|
|
pos_logits = logits[0, pos] |
|
|
|
|
|
symbolic_logits = pos_logits[selected_token_ids] |
|
|
|
|
|
symbolic_probs = F.softmax(symbolic_logits, dim=-1) |
|
|
|
|
|
top_indices = torch.argsort(symbolic_probs, descending=True) |
|
|
|
pos_results = [] |
|
for i in top_indices: |
|
token_idx = selected_token_ids[i] |
|
token = tokenizer.convert_ids_to_tokens([token_idx])[0] |
|
prob = symbolic_probs[i].item() |
|
pos_results.append({ |
|
"token": token, |
|
"probability": prob, |
|
"token_id": token_idx |
|
}) |
|
|
|
results.append({ |
|
"position": pos, |
|
"predictions": pos_results |
|
}) |
|
|
|
return results |
|
|
|
|
|
def create_strategic_masks(text, tokenizer, strategy="content_words"): |
|
""" |
|
Create strategic mask positions based on different strategies |
|
|
|
Args: |
|
text: input text |
|
tokenizer: tokenizer |
|
strategy: masking strategy |
|
|
|
Returns: |
|
input_ids with masks, attention_mask, original_tokens, mask_positions |
|
""" |
|
|
|
batch = tokenizer(text, return_tensors="pt", add_special_tokens=True) |
|
input_ids = batch.input_ids[0] |
|
attention_mask = batch.attention_mask[0] |
|
|
|
|
|
original_tokens = tokenizer.convert_ids_to_tokens(input_ids) |
|
|
|
|
|
mask_positions = [] |
|
|
|
if strategy == "content_words": |
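        # Mask standalone content words: skip special tokens, punctuation, common
        # stopwords, very short tokens, and WordPiece continuations ("##...").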
|
|
|
skip_tokens = { |
|
tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token, |
|
".", ",", "!", "?", ":", ";", "'", '"', "-", "(", ")", "[", "]", |
|
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", |
|
"for", "of", "with", "by", "is", "are", "was", "were", "be", "been" |
|
} |
|
|
|
for i, token in enumerate(original_tokens): |
|
if (token not in skip_tokens and |
|
not token.startswith("##") and |
|
len(token) > 2 and |
|
token.isalpha()): |
|
mask_positions.append(i) |
|
|
|
elif strategy == "every_nth": |
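        # Mask every third token, leaving [CLS] and [SEP] alone.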
|
|
|
for i in range(1, len(original_tokens) - 1, 3): |
|
mask_positions.append(i) |
|
|
|
elif strategy == "random": |
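        # Mask roughly 15% of the non-special positions, chosen at random.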
|
|
|
import random |
|
candidates = list(range(1, len(original_tokens) - 1)) |
|
num_to_mask = max(1, int(len(candidates) * 0.15)) |
|
mask_positions = random.sample(candidates, min(num_to_mask, len(candidates))) |
|
mask_positions.sort() |
|
|
|
elif strategy == "manual": |
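        # "manual" expects positions to be supplied by the caller, so nothing is chosen here.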
|
|
|
|
|
pass |
|
|
|
|
|
    mask_positions = mask_positions[:10]  # cap masked positions per request to keep inference fast
|
|
|
|
|
masked_input_ids = input_ids.clone() |
|
for pos in mask_positions: |
|
masked_input_ids[pos] = MASK_ID |
|
|
|
return masked_input_ids.unsqueeze(0), attention_mask.unsqueeze(0), original_tokens, mask_positions |
|
|
|
|
|
@spaces.GPU |
|
def symbolic_classification_analysis(text, selected_roles, masking_strategy="content_words", num_predictions=5): |
|
""" |
|
Perform symbolic classification analysis using MLM prediction |
|
FIXED: Now tests what the model actually learned |
|
""" |
|
if not selected_roles: |
|
selected_roles = list(symbolic_token_ids.keys()) |
|
|
|
if not text.strip(): |
|
return "Please enter some text to analyze.", "", 0 |
|
|
|
try: |
|
|
|
if any(role in text for role in symbolic_token_ids.keys()): |
|
|
|
return test_descriptive_prediction(text, selected_roles, num_predictions) |
|
else: |
|
|
|
return test_with_context_injection(text, selected_roles, num_predictions) |
|
|
|
except Exception as e: |
|
error_msg = f"Error during analysis: {str(e)}" |
|
print(error_msg) |
|
return error_msg, "", 0 |
|
|
|
|
|
def test_descriptive_prediction(text, selected_roles, num_predictions): |
|
""" |
|
Test what descriptive words the model predicts after symbolic tokens |
|
This matches the actual training objective |
|
""" |
|
|
|
tokens = tokenizer.tokenize(text, add_special_tokens=True) |
|
token_ids = tokenizer.convert_tokens_to_ids(tokens) |
|
|
|
|
|
symbolic_positions = [] |
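    # Scan for symbolic tokens; for each one, plan to mask up to three of the
    # tokens that immediately follow it (the descriptive words it should cue).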
|
for i, token in enumerate(tokens): |
|
if token in symbolic_token_ids: |
|
|
|
for offset in range(1, min(4, len(tokens) - i)): |
|
if i + offset < len(tokens) and tokens[i + offset] not in ['[SEP]', '[PAD]']: |
|
symbolic_positions.append({ |
|
'mask_pos': i + offset, |
|
'symbolic_token': token, |
|
'original_token': tokens[i + offset] |
|
}) |
|
|
|
if not symbolic_positions: |
|
return "No symbolic tokens found in input. Try format like: '<subject> a young woman'", "", 0 |
|
|
|
|
|
results = [] |
|
for pos_info in symbolic_positions[:5]: |
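        # Mask the target position and let the whole vocabulary compete for it.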
|
masked_ids = token_ids.copy() |
|
masked_ids[pos_info['mask_pos']] = MASK_ID |
|
|
|
|
|
masked_input = torch.tensor([masked_ids]).to("cuda") |
|
attention_mask = torch.ones_like(masked_input) |
|
|
|
with torch.no_grad(): |
|
outputs = full_model(input_ids=masked_input, attention_mask=attention_mask) |
|
logits = outputs.logits[0, pos_info['mask_pos']] |
|
|
|
|
|
probs = F.softmax(logits, dim=-1) |
|
top_indices = torch.argsort(probs, descending=True)[:num_predictions] |
|
|
|
predictions = [] |
|
for idx in top_indices: |
|
token_text = tokenizer.convert_ids_to_tokens([idx.item()])[0] |
|
prob = probs[idx].item() |
|
predictions.append({ |
|
"token": token_text, |
|
"probability": prob |
|
}) |
|
|
|
results.append({ |
|
"symbolic_context": pos_info['symbolic_token'], |
|
"position": pos_info['mask_pos'], |
|
"original_token": pos_info['original_token'], |
|
"predictions": predictions |
|
}) |
|
|
|
|
|
analysis = { |
|
"input_text": text, |
|
"test_type": "descriptive_prediction", |
|
"explanation": "Testing what descriptive words model predicts after symbolic tokens", |
|
"results": results |
|
} |
|
|
|
    summary_lines = [f"🎯 Testing Descriptive Prediction (what model actually learned)\n"]
|
for result in results: |
|
ctx = result["symbolic_context"] |
|
orig = result["original_token"] |
|
top_pred = result["predictions"][0] |
|
|
|
summary_lines.append( |
|
f"After {ctx}: '{orig}' β '{top_pred['token']}' ({top_pred['probability']:.4f})" |
|
) |
|
|
|
summary = "\n".join(summary_lines) |
|
return json.dumps(analysis, indent=2), summary, len(results) |
|
|
|
|
|
def test_with_context_injection(text, selected_roles, num_predictions): |
|
""" |
|
Inject symbolic context and test what descriptive words are predicted |
|
""" |
|
results = [] |
|
|
|
|
|
for role in selected_roles[:3]: |
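        # Prepend the role token to the text, mask the first content word after it
        # (skipping articles/determiners), and read off the model's predictions there.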
|
|
|
context_text = f"{role} {text}" |
|
|
|
|
|
tokens = tokenizer.tokenize(context_text, add_special_tokens=True) |
|
token_ids = tokenizer.convert_tokens_to_ids(tokens) |
|
|
|
|
|
role_pos = None |
|
for i, token in enumerate(tokens): |
|
if token == role: |
|
role_pos = i |
|
break |
|
|
|
if role_pos is None or role_pos + 2 >= len(tokens): |
|
continue |
|
|
|
|
|
mask_pos = role_pos + 1 |
|
skip_words = {'a', 'an', 'the', 'some', 'this', 'that'} |
|
while mask_pos < len(tokens) - 1: |
|
current_token = tokens[mask_pos].lower() |
|
if current_token not in skip_words and len(current_token) > 2: |
|
break |
|
mask_pos += 1 |
|
|
|
if mask_pos >= len(tokens): |
|
continue |
|
|
|
|
|
masked_ids = token_ids.copy() |
|
original_token = tokens[mask_pos] |
|
masked_ids[mask_pos] = MASK_ID |
|
|
|
|
|
masked_input = torch.tensor([masked_ids]).to("cuda") |
|
attention_mask = torch.ones_like(masked_input) |
|
|
|
with torch.no_grad(): |
|
outputs = full_model(input_ids=masked_input, attention_mask=attention_mask) |
|
logits = outputs.logits[0, mask_pos] |
|
|
|
|
|
probs = F.softmax(logits, dim=-1) |
|
top_indices = torch.argsort(probs, descending=True)[:num_predictions] |
|
|
|
predictions = [] |
|
for idx in top_indices: |
|
token_text = tokenizer.convert_ids_to_tokens([idx.item()])[0] |
|
prob = probs[idx].item() |
|
predictions.append({ |
|
"token": token_text, |
|
"probability": prob |
|
}) |
|
|
|
results.append({ |
|
"symbolic_context": role, |
|
"position": mask_pos, |
|
"original_token": original_token, |
|
"context_text": context_text, |
|
"predictions": predictions |
|
}) |
|
|
|
|
|
analysis = { |
|
"input_text": text, |
|
"test_type": "context_injection", |
|
"explanation": "Injected symbolic tokens and tested descriptive predictions", |
|
"results": results |
|
} |
|
|
|
    summary_lines = [f"🎯 Testing with Symbolic Context Injection\n"]
|
for result in results: |
|
role = result["symbolic_context"] |
|
orig = result["original_token"] |
|
top_pred = result["predictions"][0] |
|
|
|
summary_lines.append( |
|
f"{role} context: '{orig}' β '{top_pred['token']}' ({top_pred['probability']:.4f})" |
|
) |
|
|
|
summary = "\n".join(summary_lines) |
|
return json.dumps(analysis, indent=2), summary, len(results) |
|
|
|
|
|
@spaces.GPU  # this path also runs model inference, so it needs GPU access like the automatic analysis
def create_manual_mask_analysis(text, mask_positions_str, selected_roles):
|
""" |
|
Allow manual specification of mask positions |
|
""" |
|
try: |
|
|
|
mask_positions = [int(x.strip()) for x in mask_positions_str.split(",") if x.strip().isdigit()] |
|
|
|
if not mask_positions: |
|
return "Please specify valid mask positions (comma-separated numbers)", "", 0 |
|
|
|
|
|
batch = tokenizer(text, return_tensors="pt", add_special_tokens=True) |
|
input_ids = batch.input_ids[0] |
|
attention_mask = batch.attention_mask[0] |
|
original_tokens = tokenizer.convert_ids_to_tokens(input_ids) |
|
|
|
|
|
valid_positions = [pos for pos in mask_positions if 0 <= pos < len(input_ids)] |
|
if not valid_positions: |
|
return f"Invalid positions. Text has {len(input_ids)} tokens (0-{len(input_ids)-1})", "", 0 |
|
|
|
|
|
masked_input_ids = input_ids.clone() |
|
for pos in valid_positions: |
|
masked_input_ids[pos] = MASK_ID |
|
|
|
|
|
masked_input_ids = masked_input_ids.unsqueeze(0).to("cuda") |
|
attention_mask = attention_mask.unsqueeze(0).to("cuda") |
|
|
|
predictions = get_symbolic_predictions( |
|
masked_input_ids, attention_mask, valid_positions, selected_roles |
|
) |
|
|
|
|
|
results = [] |
|
for pred_data in predictions: |
|
pos = pred_data["position"] |
|
original = original_tokens[pos] |
|
top_pred = pred_data["predictions"][0] if pred_data["predictions"] else None |
|
|
|
if top_pred: |
|
results.append( |
|
f"Pos {pos}: '{original}' β {top_pred['token']} ({top_pred['probability']:.4f})" |
|
) |
|
|
|
return "\n".join(results), f"Analyzed {len(valid_positions)} positions", len(valid_positions) |
|
|
|
except Exception as e: |
|
return f"Error: {str(e)}", "", 0 |
|
|
|
|
|
|
|
|
|
def build_interface(): |
|
    with gr.Blocks(title="🧠 MLM Symbolic Classifier", theme=gr.themes.Soft()) as demo:
|
        gr.Markdown("# 🧠 MLM-Based Symbolic Classification")
|
gr.Markdown("Analyze text using masked language modeling to predict symbolic roles at specific positions.") |
|
|
|
with gr.Tab("Automatic Analysis"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
txt_input = gr.Textbox( |
|
label="Input Text", |
|
lines=4, |
|
placeholder="Try: '<subject> a young woman wearing elegant dress' or just 'young woman wearing dress'" |
|
) |
|
|
|
with gr.Row(): |
|
masking_strategy = gr.Dropdown( |
|
choices=["content_words", "every_nth", "random"], |
|
value="content_words", |
|
label="Masking Strategy" |
|
) |
|
num_predictions = gr.Slider( |
|
minimum=1, maximum=10, value=5, step=1, |
|
label="Top Predictions per Position" |
|
) |
|
|
|
roles_selection = gr.CheckboxGroup( |
|
choices=list(symbolic_token_ids.keys()), |
|
value=list(symbolic_token_ids.keys()), |
|
label="Symbolic Roles to Consider" |
|
) |
|
|
|
                    analyze_btn = gr.Button("🔍 Analyze", variant="primary")
|
|
|
with gr.Column(): |
|
summary_output = gr.Textbox( |
|
label="Analysis Summary", |
|
lines=10, |
|
max_lines=15 |
|
) |
|
|
|
with gr.Row(): |
|
positions_analyzed = gr.Number(label="Positions Analyzed", precision=0) |
|
max_confidence = gr.Textbox(label="Best Prediction", max_lines=1) |
|
|
|
detailed_output = gr.JSON(label="Detailed Results") |
|
|
|
with gr.Tab("Manual Masking"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
manual_text = gr.Textbox( |
|
label="Input Text", |
|
lines=3, |
|
placeholder="Enter text for manual analysis..." |
|
) |
|
|
|
mask_positions_input = gr.Textbox( |
|
label="Mask Positions (comma-separated)", |
|
placeholder="e.g., 2,5,8,12", |
|
info="Specify token positions to mask (0-based indexing)" |
|
) |
|
|
|
manual_roles = gr.CheckboxGroup( |
|
choices=list(symbolic_token_ids.keys()), |
|
value=list(symbolic_token_ids.keys())[:10], |
|
label="Symbolic Roles" |
|
) |
|
|
|
                    manual_analyze_btn = gr.Button("🎯 Analyze Specific Positions")
|
|
|
with gr.Column(): |
|
manual_results = gr.Textbox( |
|
label="Manual Analysis Results", |
|
lines=8 |
|
) |
|
|
|
manual_summary = gr.Textbox(label="Summary") |
|
manual_count = gr.Number(label="Positions", precision=0) |
|
|
|
with gr.Tab("Token Inspector"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
inspect_text = gr.Textbox( |
|
label="Text to Inspect", |
|
lines=2, |
|
placeholder="Enter text to see tokenization..." |
|
) |
|
|
|
|
|
                    example_patterns = gr.Button("📋 Load Image Caption Examples")
|
|
|
                    inspect_btn = gr.Button("🔍 Inspect Tokens")
|
|
|
with gr.Column(): |
|
token_breakdown = gr.Textbox( |
|
label="Token Breakdown", |
|
lines=8, |
|
info="Shows how text is tokenized with position indices" |
|
) |
|
|
|
with gr.Tab("Caption Examples"): |
|
            gr.Markdown("### 🖼️ Test with Training-Style Patterns")
|
gr.Markdown(""" |
|
**The model was trained to predict descriptive words AFTER symbolic tokens.** |
|
|
|
Test with patterns like: |
|
- `<subject> a young woman wearing elegant dress` |
|
- `<lighting> soft natural illumination on the scene` |
|
- `<emotion> happy expression while posing confidently` |
|
""") |
|
|
|
example_captions = [ |
|
"<subject> a young woman wearing a blue dress", |
|
"<lighting> soft natural illumination in the scene", |
|
"<emotion> happy expression while posing confidently", |
|
"<pose> standing gracefully near the window", |
|
"<upper_body_clothing> elegant silk blouse with intricate patterns", |
|
"<material> luxurious velvet fabric with rich texture", |
|
"<accessory> delicate silver jewelry catching the light", |
|
"<surface> polished marble floor reflecting ambient glow" |
|
] |
|
|
|
            for caption in example_captions:
                with gr.Row():
                    gr.Textbox(value=caption, label="Training-Style Example", interactive=False, scale=3)
                    copy_btn = gr.Button("📋 Test This", scale=1)
                    # Copy this example caption into the Automatic Analysis textbox for quick testing
                    copy_btn.click(lambda c=caption: c, outputs=[txt_input])
|
|
|
|
|
analyze_btn.click( |
|
symbolic_classification_analysis, |
|
inputs=[txt_input, roles_selection, masking_strategy, num_predictions], |
|
outputs=[detailed_output, summary_output, positions_analyzed] |
|
) |
|
|
|
manual_analyze_btn.click( |
|
create_manual_mask_analysis, |
|
inputs=[manual_text, mask_positions_input, manual_roles], |
|
outputs=[manual_results, manual_summary, manual_count] |
|
) |
|
|
|
def load_examples(): |
|
return "a young woman wearing a blue dress" |
|
|
|
def inspect_tokens(text): |
|
if not text.strip(): |
|
return "Enter text to inspect tokenization" |
|
|
|
tokens = tokenizer.tokenize(text, add_special_tokens=True) |
|
result_lines = [] |
|
|
|
for i, token in enumerate(tokens): |
|
result_lines.append(f"{i:2d}: '{token}'") |
|
|
|
return "\n".join(result_lines) |
|
|
|
|
|
example_patterns.click( |
|
load_examples, |
|
outputs=[inspect_text] |
|
) |
|
|
|
inspect_btn.click( |
|
inspect_tokens, |
|
inputs=[inspect_text], |
|
outputs=[token_breakdown] |
|
) |
|
|
|
return demo |
|
|
|
|
|
if __name__ == "__main__": |
|
print("π Starting MLM Symbolic Classifier...") |
|
print(f"β
Model loaded with {len(symbolic_token_ids)} symbolic tokens") |
|
print(f"π― Available symbolic roles: {list(symbolic_token_ids.keys())[:5]}...") |
|
|
|
build_interface().launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
share=True |
|
) |