import os
import argparse

import numpy as np
import torch
import torchvision.transforms as T
import gradio as gr
import spaces
from PIL import Image
from transformers import AutoTokenizer

from resnet50 import build_model
from utils import generate_similiarity_map, get_transform, post_process, load_tokenizer, build_transform_R50
from utils import IMAGENET_MEAN, IMAGENET_STD
from internvl.train.dataset import dynamic_preprocess
from internvl.model.internvl_chat import InternVLChatModel


CHECKPOINTS = {
    "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
    "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
}

HF_TOKEN = os.getenv("HF_TOKEN")


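# load_model() returns (model, tokenizer, transform, device). The 'R50' and 'R50_siglip'
# branches are kept for completeness but assume extra CHECKPOINTS entries and a local
# tokenizer path that are not provided here; the demo UI only exposes the TokenFD checkpoints.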
def load_model(check_type):
    device = torch.device("cuda")

    if check_type == 'R50':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')

    elif check_type == 'R50_siglip':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')

    elif 'TokenFD' in check_type:
        model_path = CHECKPOINTS[check_type]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True,
                                                  use_fast=False, use_auth_token=HF_TOKEN)
        model = InternVLChatModel.from_pretrained(model_path, low_cpu_mem_usage=True,
                                                  torch_dtype=torch.bfloat16,
                                                  load_in_8bit=False, load_in_4bit=False).eval()
        transform = get_transform(is_train=False, image_size=model.config.force_image_size)

    return model.to(device), tokenizer, transform, device


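# Given an image and a query string, compute a similarity map between every BPE token of
# the query and the visual tokens of the image, and return the per-token heatmaps together
# with the decoded BPE strings.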
def process_image(model, tokenizer, transform, device, check_type, image, text):
    src_size = image.size
    if 'TokenFD' in check_type:
        images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
                                                  image_size=model.config.force_image_size,
                                                  use_thumbnail=model.config.use_thumbnail,
                                                  return_ratio=True)
        pixel_values = torch.stack([transform(img) for img in images]).to(device)
    else:
        images = [image]  # keep a list so generate_similiarity_map receives the same type in both branches
        pixel_values = torch.stack([transform(image)]).to(device)
        target_ratio = (1, 1)

    text_input = text

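    # Tokenize the query text. The leading token from the tokenizer (typically BOS) is dropped;
    # a space is prepended unless the text starts with punctuation or a digit, so that
    # word-initial BPE pieces are tokenized the same way as inside running text.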
    if text_input[0] in '!"#$%&\'()*+,-./0123456789:;<=>?@^_{|}~':
        input_ids = tokenizer(text_input)['input_ids'][1:]
    else:
        input_ids = tokenizer(' ' + text_input)['input_ids'][1:]
    input_ids = torch.tensor(input_ids, device=device)

    with torch.no_grad():
        if 'R50' in check_type:
            text_embeds = model.language_embedding(input_ids)
        else:
            text_embeds = model.tok_embeddings(input_ids).clone()

        vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))

        print("check_type", check_type)
        vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)

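        # L2-normalize both sets of embeddings so the matrix product below yields
        # cosine similarities between every BPE token and every visual token.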
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
        similarity = text_embeds @ vit_embeds.T
        resized_size = size1 if size1 is not None else size2

        attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])

    current_bpe = [tokenizer.decode([i]) for i in input_ids]
    current_vis = generate_similiarity_map(images, attn_map, current_bpe,
                                           [], target_ratio, src_size)

    return image, current_vis, current_bpe


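# Navigation helper: clamp the new index into range and return the heatmap,
# formatted BPE string, and index for the selected token.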
def update_index(direction, current_vis, current_bpe, current_index):
    new_index = max(0, min(current_index + direction, len(current_vis) - 1))
    return (
        current_vis[new_index],
        format_bpe_display(current_bpe[new_index]),
        new_index
    )


def format_bpe_display(bpe):
    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"


with gr.Blocks(title="BPE Visualization Demo") as demo:
    gr.Markdown("## BPE Visualization Demo - visualizing the capabilities of the TokenFD base model")

    with gr.Row():
        with gr.Column(scale=0.5):
            model_type = gr.Dropdown(
                choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg"],
                label="Select model type",
                value="TokenFD_4096_English_seg"
            )
            image_input = gr.Image(label="Upload an image", type="pil")
            text_input = gr.Textbox(label="Text that appears in the image")

            run_btn = gr.Button("RUN")

            gr.Examples(
                examples=[
                    [os.path.join("examples", "examples0.jpg"), "and website not "],
                    [os.path.join("examples", "examples1.jpg"), "STARBUCKS Refreshers 12 90 "],
                    [os.path.join("examples", "examples2.png"), "Vision Transformer "]
                ],
                inputs=[image_input, text_input],
                label="Sample input"
            )

        with gr.Column(scale=2):
            gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text does not appear in the image</span>, the attention map will show a lot of noise (the actual response values are very low), because the attention map is normalized by its relative values. Since public data contains fewer Chinese images than English ones, <span style='color:red;'>we recommend the TokenFD_4096_English_seg version.</span></p>")

            with gr.Row():
                orig_img = gr.Image(label="Original picture", interactive=False)
                heatmap = gr.Image(label="BPE visualization", interactive=False)

            with gr.Row() as controls:
                prev_btn = gr.Button("⬅ Previous", visible=False)
                next_btn = gr.Button("⮕ Next", visible=False)

            bpe_display = gr.Markdown("Current BPE: ")

    current_vis_state = gr.State([])
    current_bpe_state = gr.State([])
    current_index_state = gr.State(0)

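    # spaces.GPU requests GPU time for this call on Hugging Face ZeroGPU Spaces.
    # Note that load_model() runs on every click, so the checkpoint is re-loaded per request.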
    @spaces.GPU
    def on_run_clicked(model_type, image, text):
        current_index = 0
        image, current_vis, current_bpe = process_image(*load_model(model_type), model_type, image, text)

        print("current_vis", len(current_vis))
        print("current_bpe", len(current_bpe))

        return (
            image,
            current_vis[current_index],
            format_bpe_display(current_bpe[current_index]),
            gr.update(visible=True),
            gr.update(visible=True),
            current_vis,
            current_bpe,
            current_index
        )

    run_btn.click(
        on_run_clicked,
        inputs=[model_type, image_input, text_input],
        outputs=[orig_img, heatmap, bpe_display, prev_btn, next_btn, current_vis_state, current_bpe_state, current_index_state]
    )

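    # The Previous/Next buttons reuse update_index; the inline gr.State(-1) / gr.State(1)
    # values supply the step direction.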
    prev_btn.click(
        update_index,
        inputs=[gr.State(-1), current_vis_state, current_bpe_state, current_index_state],
        outputs=[heatmap, bpe_display, current_index_state]
    )

    next_btn.click(
        update_index,
        inputs=[gr.State(1), current_vis_state, current_bpe_state, current_index_state],
        outputs=[heatmap, bpe_display, current_index_state]
    )


if __name__ == "__main__":
    demo.launch()