# import os # import argparse # import numpy as np # from PIL import Image # import torch # import torchvision.transforms as T # from transformers import AutoTokenizer # import gradio as gr # from resnet50 import build_model # from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50 # from utils import IMAGENET_MEAN, IMAGENET_STD # from internvl.train.dataset import dynamic_preprocess # from internvl.model.internvl_chat import InternVLChatModel # import spaces # # 模型配置 # CHECKPOINTS = { # "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg", # "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg", # } # # 全局变量 # HF_TOKEN = os.getenv("HF_TOKEN") # current_vis = [] # current_bpe = [] # current_index = 0 # def load_model(check_type): # # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = torch.device("cuda") # if check_type == 'R50': # tokenizer = load_tokenizer('tokenizer_path') # model = build_model(argparse.Namespace()).eval() # model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model']) # transform = build_transform_R50(normalize_type='imagenet') # elif check_type == 'R50_siglip': # tokenizer = load_tokenizer('tokenizer_path') # model = build_model(argparse.Namespace()).eval() # model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model']) # transform = build_transform_R50(normalize_type='imagenet') # elif 'TokenFD' in check_type: # model_path = CHECKPOINTS[check_type] # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN) # model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval() # transform = T.Compose([ # T.Lambda(lambda img: img.convert('RGB')), # T.Resize((224, 224)), # T.ToTensor(), # T.Normalize(IMAGENET_MEAN, IMAGENET_STD) # ]) # return model.to(device), tokenizer, transform, device # def process_image(model, tokenizer, transform, device, check_type, image, text): # global current_vis, current_bpe, current_index # src_size = image.size # if 'TokenOCR' in check_type: # images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12, # image_size=model.config.force_image_size, # use_thumbnail=model.config.use_thumbnail, # return_ratio=True) # pixel_values = torch.stack([transform(img) for img in images]).to(device) # else: # pixel_values = torch.stack([transform(image)]).to(device) # target_ratio = (1, 1) # # 文本处理 # text += ' ' # input_ids = tokenizer(text)['input_ids'][1:] # input_ids = torch.tensor(input_ids, device=device) # # 获取嵌入 # with torch.no_grad(): # if 'R50' in check_type: # text_embeds = model.language_embedding(input_ids) # else: # text_embeds = model.tok_embeddings(input_ids) # vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device)) # print("vit_embeds",vit_embeds) # print("vit_embeds,shape",vit_embeds.shape) # print("target_ratio",target_ratio) # print("check_type",check_type) # vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type) # # 计算相似度 # text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) # vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True) # similarity = text_embeds @ vit_embeds.T # resized_size = size1 if size1 is not None else size2 # # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192 # # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944 # # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912 # # 生成可视化 # attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1]) # # attn_map = similarity.reshape(len(text_embeds), *target_ratio) # all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids] # current_vis = generate_similiarity_map([image], attn_map, # [tokenizer.decode([i]) for i in input_ids], # [], target_ratio, src_size) # current_bpe = [tokenizer.decode([i]) for i in input_ids] # # current_bpe[-1] = 'Input text' # current_bpe[-1] = text # print("current_vis",len(current_vis)) # print("current_bpe",len(current_bpe)) # return image, current_vis[0], current_bpe[0] # # 事件处理函数 # def update_index(change): # global current_vis, current_bpe, current_index # current_index = max(0, min(len(current_vis) - 1, current_index + change)) # return current_vis[current_index], format_bpe_display(current_bpe[current_index]) # def format_bpe_display(bpe): # # 使用HTML标签来设置字体大小、颜色,加粗,并居中 # return f"
Current BPE: {bpe}
" # def update_slider_index(x): # print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}") # if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe): # return current_vis[x], format_bpe_display(current_bpe[x]) # else: # return None, "索引超出范围" # # Gradio界面 # with gr.Blocks(title="BPE Visualization Demo") as demo: # gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化") # with gr.Row(): # with gr.Column(scale=0.5): # model_type = gr.Dropdown( # choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"], # label="Select model type", # value="TokenOCR_4096_English_seg" # 设置默认值为第一个选项 # ) # image_input = gr.Image(label="Upload images", type="pil") # text_input = gr.Textbox(label="Input text") # run_btn = gr.Button("RUN") # gr.Examples( # examples=[ # [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"], # [os.path.join("examples", "examples1.jpg"), "Refreshers"], # [os.path.join("examples", "examples2.png"), "Vision Transformer"] # ], # inputs=[image_input, text_input], # label="Sample input" # ) # with gr.Column(scale=2): # gr.Markdown("

If the input text is not included in the image, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.

") # with gr.Row(): # orig_img = gr.Image(label="Original picture", interactive=False) # heatmap = gr.Image(label="BPE visualization", interactive=False) # with gr.Row() as controls: # prev_btn = gr.Button("⬅ Last", visible=False) # index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False) # next_btn = gr.Button("⮕ Next", visible=False) # bpe_display = gr.Markdown("Current BPE: ", visible=False) # # 事件处理 # @spaces.GPU # def on_run_clicked(model_type, image, text): # global current_vis, current_bpe, current_index # current_index = 0 # Reset index when new image is processed # image, vis, bpe = process_image(*load_model(model_type), model_type, image, text) # # Update the slider range and set value to 0 # slider_max_val = len(current_bpe) - 1 # bpe_text = format_bpe_display(bpe) # print("current_vis",len(current_vis)) # print("current_bpe",len(current_bpe)) # return image, vis, bpe_text, slider_max_val # run_btn.click( # on_run_clicked, # inputs=[model_type, image_input, text_input], # outputs=[orig_img, heatmap, bpe_display, index_slider], # ).then( # lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)), # inputs=index_slider, # outputs=[prev_btn, index_slider, next_btn, bpe_display], # ) # prev_btn.click( # lambda: (*update_index(-1), current_index), # outputs=[heatmap, bpe_display, index_slider] # ) # next_btn.click( # lambda: (*update_index(1), current_index), # outputs=[heatmap, bpe_display, index_slider] # ) # # index_slider.change( # # lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0<=xCurrent BPE: {bpe}" # Gradio界面 with gr.Blocks(title="BPE Visualization Demo") as demo: gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化") with gr.Row(): with gr.Column(scale=0.5): model_type = gr.Dropdown( choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"], label="Select model type", value="TokenOCR_4096_English_seg" ) image_input = gr.Image(label="Upload images", type="pil") text_input = gr.Textbox(label="Input text") run_btn = gr.Button("RUN") gr.Examples( examples=[ [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"], [os.path.join("examples", "examples1.jpg"), "Refreshers"], [os.path.join("examples", "examples2.png"), "Vision Transformer"] ], inputs=[image_input, text_input], label="Sample input" ) with gr.Column(scale=2): gr.Markdown("

If the input text is not included in the image, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.

") with gr.Row(): orig_img = gr.Image(label="Original picture", interactive=False) heatmap = gr.Image(label="BPE visualization", interactive=False) bpe_display = gr.Markdown("Current BPE: ", visible=False) # 事件处理 @spaces.GPU def on_run_clicked(model_type, image, text): global current_vis, current_bpe image, vis, bpe = process_image(*load_model(model_type), model_type, image, text) bpe_text = format_bpe_display(bpe) return image, vis[0], bpe_text run_btn.click( on_run_clicked, inputs=[model_type, image_input, text_input], outputs=[orig_img, heatmap, bpe_display], ).then( lambda: (gr.update(visible=True)), outputs=[bpe_display], ) if __name__ == "__main__": demo.launch()