import os import argparse import numpy as np from PIL import Image import torch import torchvision.transforms as T from transformers import AutoTokenizer import gradio as gr from resnet50 import build_model # from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50 from utils import generate_similiarity_map, get_transform, post_process, load_tokenizer, build_transform_R50 from utils import IMAGENET_MEAN, IMAGENET_STD from internvl.train.dataset import dynamic_preprocess from internvl.model.internvl_chat import InternVLChatModel import spaces # 模型配置 CHECKPOINTS = { "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg", } # 全局变量 HF_TOKEN = os.getenv("HF_TOKEN") def load_model(check_type): # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda") if check_type == 'R50': tokenizer = load_tokenizer('tokenizer_path') model = build_model(argparse.Namespace()).eval() model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model']) transform = build_transform_R50(normalize_type='imagenet') elif check_type == 'R50_siglip': tokenizer = load_tokenizer('tokenizer_path') model = build_model(argparse.Namespace()).eval() model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model']) transform = build_transform_R50(normalize_type='imagenet') elif 'TokenFD' in check_type: model_path = CHECKPOINTS[check_type] tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN) # model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval() model = InternVLChatModel.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 ,load_in_8bit=False, load_in_4bit=False).eval() transform = get_transform(is_train=False, image_size=model.config.force_image_size) return model.to(device), 
tokenizer, transform, device def process_image(model, tokenizer, transform, device, check_type, image, text): src_size = image.size if 'TokenFD' in check_type: images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12, image_size=model.config.force_image_size, use_thumbnail=model.config.use_thumbnail, return_ratio=True) pixel_values = torch.stack([transform(img) for img in images]).to(device) else: pixel_values = torch.stack([transform(image)]).to(device) target_ratio = (1, 1) # 文本处理 text_input = text if text_input[0] in '!"#$%&\'()*+,-./0123456789:;<=>?@^_{|}~0123456789': input_ids = tokenizer(text_input)['input_ids'][1:] else: input_ids = tokenizer(' '+text_input)['input_ids'][1:] input_ids = torch.tensor(input_ids, device=device) # 获取嵌入 with torch.no_grad(): if 'R50' in check_type: text_embeds = model.language_embedding(input_ids) else: text_embeds = model.tok_embeddings(input_ids).clone() vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device)) # print("vit_embeds",vit_embeds) # print("vit_embeds,shape",vit_embeds.shape) # print("target_ratio",target_ratio) print("check_type",check_type) vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type) # 计算相似度 text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True) similarity = text_embeds @ vit_embeds.T resized_size = size1 if size1 is not None else size2 # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192 # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944 # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912 # 生成可视化 attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1]) # attn_map = 
similarity.reshape(len(text_embeds), *target_ratio) all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids] current_vis = generate_similiarity_map(images, attn_map, [tokenizer.decode([i]) for i in input_ids], [], target_ratio, src_size) current_bpe = [tokenizer.decode([i]) for i in input_ids] # current_bpe[-1] = 'Input text' # current_bpe.append(text) return image, current_vis, current_bpe # 事件处理函数 # 上一项和下一项按钮 def update_index(direction, current_vis, current_bpe, current_index): # 计算新的索引 new_index = max(0, min(current_index + direction, len(current_vis) - 1)) # 更新可视化内容 return ( current_vis[new_index], format_bpe_display(current_bpe[new_index]), new_index # 更新索引 ) def format_bpe_display(bpe): # 使用HTML标签来设置字体大小、颜色,加粗,并居中 return f"
If the input text does not appear in the image, the attention map will look noisy (the actual response values are very low), because the map is normalized by relative values. Since public data contains fewer Chinese images than English ones, we recommend using the TokenFD_4096_English_seg model.
") with gr.Row(): orig_img = gr.Image(label="Original picture", interactive=False) heatmap = gr.Image(label="BPE visualization", interactive=False) with gr.Row() as controls: prev_btn = gr.Button("⬅ Last", visible=False) next_btn = gr.Button("⮕ Next", visible=False) bpe_display = gr.Markdown("Current BPE: ") current_vis_state = gr.State([]) current_bpe_state = gr.State([]) current_index_state = gr.State(0) # 事件处理 @spaces.GPU def on_run_clicked(model_type, image, text): current_index = 0 # Reset index when new image is processed image, current_vis, current_bpe = process_image(*load_model(model_type), model_type, image, text) bpe_text = format_bpe_display(current_bpe) print("current_vis",len(current_vis)) print("current_bpe",len(current_bpe)) # return image, current_vis[0],f"