import os
import argparse
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from transformers import AutoTokenizer
import gradio as gr
from resnet50 import build_model
from utils import generate_similiarity_map, get_transform, post_process, load_tokenizer, build_transform_R50
from utils import IMAGENET_MEAN, IMAGENET_STD
from internvl.train.dataset import dynamic_preprocess
from internvl.model.internvl_chat import InternVLChatModel
import spaces

# Model configuration: dropdown name -> Hugging Face Hub repo id.
CHECKPOINTS = {
    "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
    "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
}

# Global variables
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# If the input text starts with one of these characters it is tokenized as-is;
# otherwise a single space is prepended before tokenizing.
_NO_SPACE_PREFIX_CHARS = '!"#$%&\'()*+,-./0123456789:;<=>?@^_{|}~'


def load_model(check_type):
    """Load the model, tokenizer and image transform selected by ``check_type``.

    Args:
        check_type: Model selector. ``'R50'`` / ``'R50_siglip'`` build a ResNet-50
            model from a local checkpoint; any name containing ``'TokenFD'`` is
            loaded from the Hub repo listed in ``CHECKPOINTS``.

    Returns:
        ``(model, tokenizer, transform, device)`` with the model already moved
        to ``device``.

    Raises:
        KeyError: if ``check_type`` has no entry in ``CHECKPOINTS``.
        ValueError: if ``check_type`` matches none of the supported families.
    """
    # CUDA is used unconditionally; this function is called from the
    # @spaces.GPU-decorated click handler below.
    device = torch.device("cuda")
    if check_type in ('R50', 'R50_siglip'):
        # NOTE(review): CHECKPOINTS currently has no 'R50'/'R50_siglip' keys, so
        # this branch raises KeyError as written; it is not offered in the UI.
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        checkpoint = torch.load(CHECKPOINTS[check_type], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        transform = build_transform_R50(normalize_type='imagenet')
    elif 'TokenFD' in check_type:
        model_path = CHECKPOINTS[check_type]
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
        model = InternVLChatModel.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
            load_in_8bit=False,
            load_in_4bit=False,
        ).eval()
        transform = get_transform(is_train=False, image_size=model.config.force_image_size)
    else:
        # Previously fell through and failed on an unbound `model`.
        raise ValueError(f"Unsupported model type: {check_type!r}")
    return model.to(device), tokenizer, transform, device


def process_image(model, tokenizer, transform, device, check_type, image, text):
    """Compute one image-similarity map per text token.

    Args:
        model, tokenizer, transform, device: As returned by ``load_model``.
        check_type: The same selector that was passed to ``load_model``.
        image: Input PIL image (the UI uses ``gr.Image(type="pil")``).
        text: Non-empty text to locate in the image.

    Returns:
        ``(image, current_vis, current_bpe)``. These are the input image, the
        result of ``generate_similiarity_map``, and the decoded string of each
        text token.

    Raises:
        ValueError: if ``text`` is empty.
    """
    if not text:
        # Previously `text[0]` raised IndexError on empty input.
        raise ValueError("text must be a non-empty string")
    src_size = image.size

    if 'TokenFD' in check_type:
        images, target_ratio = dynamic_preprocess(
            image,
            min_num=1,
            max_num=12,
            image_size=model.config.force_image_size,
            use_thumbnail=model.config.use_thumbnail,
            return_ratio=True,
        )
        pixel_values = torch.stack([transform(img) for img in images]).to(device)
    else:
        # `images` used to be unbound on this path, which made the
        # generate_similiarity_map call below raise NameError.
        # NOTE(review): assumes the helper accepts the untiled image as a
        # one-element list - confirm against utils.generate_similiarity_map.
        images = [image]
        pixel_values = torch.stack([transform(image)]).to(device)
        target_ratio = (1, 1)

    # Text processing. Prepend a space unless the text starts with
    # punctuation or a digit, then drop the first token id (presumably a
    # BOS/special token - verify against the tokenizer).
    prefix = '' if text[0] in _NO_SPACE_PREFIX_CHARS else ' '
    input_ids = tokenizer(prefix + text)['input_ids'][1:]
    input_ids = torch.tensor(input_ids, device=device)

    # Compute embeddings and similarity.
    with torch.no_grad():
        if 'R50' in check_type:
            text_embeds = model.language_embedding(input_ids)
        else:
            text_embeds = model.tok_embeddings(input_ids).clone()
        vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
        print("check_type", check_type)
        vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)

        # Cosine similarity between every text token and every visual embedding.
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
        similarity = text_embeds @ vit_embeds.T

    # Prefer the spatial size reported by forward_tokenocr; fall back to post_process's.
    resized_size = size1 if size1 is not None else size2

    # Generate the visualization: one (H, W) map per token.
    attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
    current_bpe = [tokenizer.decode([i]) for i in input_ids]
    current_vis = generate_similiarity_map(images, attn_map, current_bpe, [], target_ratio, src_size)
    return image, current_vis, current_bpe


# Event handlers
def update_index(direction, current_vis, current_bpe, current_index):
    """Move the shown token by ``direction`` (-1 for previous, +1 for next).

    The new index is clamped to the valid range.

    Returns:
        ``(heatmap, bpe_markdown, new_index)`` for the heatmap, BPE label and
        index-state outputs.
    """
    if not current_vis:
        # Nothing computed yet: leave every output unchanged instead of indexing [].
        return gr.update(), gr.update(), current_index
    new_index = max(0, min(current_index + direction, len(current_vis) - 1))
    return (
        current_vis[new_index],
        format_bpe_display(current_bpe[new_index]),
        new_index,
    )


def format_bpe_display(bpe):
    """Return the Markdown text shown for the currently selected BPE token."""
    # NOTE(review): the original comment says HTML tags set the font size,
    # color, bold and centering. That markup is absent from this copy of the
    # file; restore it if the styling is wanted.
    return f"\nCurrent BPE: {bpe}\n"


# Gradio UI
with gr.Blocks(title="BPE Visualization Demo") as demo:
    gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")

    with gr.Row():
        with gr.Column(scale=0.5):
            model_type = gr.Dropdown(
                choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg"],
                label="Select model type",
                value="TokenFD_4096_English_seg",  # default to the first option
            )
            image_input = gr.Image(label="Upload images", type="pil")
            text_input = gr.Textbox(label="Input text on the image")
            run_btn = gr.Button("RUN")
            gr.Examples(
                examples=[
                    [os.path.join("examples", "examples0.jpg"), "and website not "],
                    [os.path.join("examples", "examples1.jpg"), "STARBUCKS Refreshers 12 90 "],
                    [os.path.join("examples", "examples2.png"), "Vision Transformer "],
                ],
                inputs=[image_input, text_input],
                label="Sample input",
            )

        with gr.Column(scale=2):
            gr.Markdown(
                "\n\n"
                "If the input text is not included in the image, the attention map will show "
                "a lot of noise (the actual response value is very low), since we normalize "
                "the attention map according to the relative value. Since there are fewer "
                "Chinese images in public data than English, we recommend you use the "
                "TokenFD-4096-English-seg version."
                "\n\n"
            )
            with gr.Row():
                orig_img = gr.Image(label="Original picture", interactive=False)
                heatmap = gr.Image(label="BPE visualization", interactive=False)
            with gr.Row():
                prev_btn = gr.Button("⬅ Last", visible=False)
                next_btn = gr.Button("⮕ Next", visible=False)
            bpe_display = gr.Markdown("Current BPE: ")

    # Per-session state: every token's visualization, its BPE string, and the shown index.
    current_vis_state = gr.State([])
    current_bpe_state = gr.State([])
    current_index_state = gr.State(0)

    # Event handling
    @spaces.GPU
    def on_run_clicked(model_type, image, text):
        """Run the selected model and show the first token's visualization.

        Returns values for the outputs wired in ``run_btn.click``: the original
        image, the first heatmap, its BPE label, visibility updates for the
        prev/next buttons, and the three session states.

        Raises:
            gr.Error: if the image or text is missing, or no tokens were produced.
        """
        if image is None:
            raise gr.Error("Please upload an image.")
        if not text:
            raise gr.Error("Please input the text on the image.")
        current_index = 0  # reset to the first token on every new run
        image, current_vis, current_bpe = process_image(*load_model(model_type), model_type, image, text)
        print("current_vis", len(current_vis))
        print("current_bpe", len(current_bpe))
        if not current_vis:
            raise gr.Error("The input text produced no tokens to visualize.")
        return (
            image,
            current_vis[current_index],
            format_bpe_display(current_bpe[current_index]),
            gr.update(visible=True),
            gr.update(visible=True),
            current_vis,  # store the whole list
            current_bpe,  # store the whole list
            current_index,  # store the current index
        )

    run_btn.click(
        on_run_clicked,
        inputs=[model_type, image_input, text_input],
        outputs=[orig_img, heatmap, bpe_display, prev_btn, next_btn,
                 current_vis_state, current_bpe_state, current_index_state],
    )
    prev_btn.click(
        update_index,
        inputs=[gr.State(-1), current_vis_state, current_bpe_state, current_index_state],
        outputs=[heatmap, bpe_display, current_index_state],
    )
    next_btn.click(
        update_index,
        inputs=[gr.State(1), current_vis_state, current_bpe_state, current_index_state],
        outputs=[heatmap, bpe_display, current_index_state],
    )

if __name__ == "__main__":
    demo.launch()