import os
import argparse
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from transformers import AutoTokenizer
import gradio as gr
from resnet50 import build_model
from utils import generate_similiarity_map, get_transform, post_process, load_tokenizer, build_transform_R50
from utils import IMAGENET_MEAN, IMAGENET_STD
from internvl.train.dataset import dynamic_preprocess
from internvl.model.internvl_chat import InternVLChatModel
import spaces
# Model configuration
CHECKPOINTS = {
"TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
"TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
}
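# NOTE: the R50 branches in load_model expect 'R50' / 'R50_siglip' checkpoint
# paths to be added here before those model types can be used.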
# Global variables
HF_TOKEN = os.getenv("HF_TOKEN")
def load_model(check_type):
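    """Load the selected checkpoint and return (model, tokenizer, transform, device)."""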
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if check_type in ('R50', 'R50_siglip'):
        tokenizer = load_tokenizer('tokenizer_path')  # placeholder: supply the real tokenizer directory
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS[check_type], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')
elif 'TokenFD' in check_type:
model_path = CHECKPOINTS[check_type]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, token=HF_TOKEN)
        model = InternVLChatModel.from_pretrained(model_path, low_cpu_mem_usage=True,
                                                  torch_dtype=torch.bfloat16).eval()
transform = get_transform(is_train=False, image_size=model.config.force_image_size)
return model.to(device), tokenizer, transform, device
def process_image(model, tokenizer, transform, device, check_type, image, text):
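    """Embed the text tokens and image patches, compute per-token similarity maps,
    and return (original image, heatmap list, decoded BPE string list)."""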
src_size = image.size
if 'TokenFD' in check_type:
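        # Tile the image into up to 12 crops at the model's native input size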
images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
image_size=model.config.force_image_size,
use_thumbnail=model.config.use_thumbnail,
return_ratio=True)
pixel_values = torch.stack([transform(img) for img in images]).to(device)
    else:
        images = [image]  # wrap in a list so generate_similiarity_map sees the same shape as the TokenFD branch
        pixel_values = torch.stack([transform(image)]).to(device)
        target_ratio = (1, 1)
    # Text processing: prepend a space only when the text starts with a letter,
    # so the first word tokenizes with a word boundary; [1:] drops the BOS token.
    text_input = text
    if text_input[0] in '!"#$%&\'()*+,-./0123456789:;<=>?@^_{|}~':
        input_ids = tokenizer(text_input)['input_ids'][1:]
    else:
        input_ids = tokenizer(' ' + text_input)['input_ids'][1:]
input_ids = torch.tensor(input_ids, device=device)
    # Get text and vision embeddings
with torch.no_grad():
if 'R50' in check_type:
text_embeds = model.language_embedding(input_ids)
else:
text_embeds = model.tok_embeddings(input_ids).clone()
vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
    # Cosine similarity: normalize both embeddings, then take the dot product
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
similarity = text_embeds @ vit_embeds.T
resized_size = size1 if size1 is not None else size2
    # Generate visualization: one (H, W) similarity map per text token
attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
    current_bpe = [tokenizer.decode([i]) for i in input_ids]
    current_vis = generate_similiarity_map(images, attn_map, current_bpe,
                                           [], target_ratio, src_size)
return image, current_vis, current_bpe
# Event handler functions
# Previous and Next buttons
def update_index(direction, current_vis, current_bpe, current_index):
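    """Step the index by `direction` (+1 or -1), clamped to the valid range."""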
    # Compute the new index, clamped to [0, len(current_vis) - 1]
new_index = max(0, min(current_index + direction, len(current_vis) - 1))
    # Return the visualization for the new index
return (
current_vis[new_index],
format_bpe_display(current_bpe[new_index]),
        new_index  # updated index
)
def format_bpe_display(bpe):
    # Use HTML to set font size and color, bold the text, and center it
return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
# Gradio interface
with gr.Blocks(title="BPE Visualization Demo") as demo:
    gr.Markdown("## BPE Visualization Demo - visualizing the TokenFD foundation model's capabilities")
with gr.Row():
        with gr.Column(scale=1):
model_type = gr.Dropdown(
choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg"],
label="Select model type",
                value="TokenFD_4096_English_seg"  # default to the first option
)
image_input = gr.Image(label="Upload images", type="pil")
text_input = gr.Textbox(label="Input text on the image")
run_btn = gr.Button("RUN")
gr.Examples(
examples=[
[os.path.join("examples", "examples0.jpg"), "and website not "],
[os.path.join("examples", "examples1.jpg"), "STARBUCKS Refreshers 12 90 "],
[os.path.join("examples", "examples2.png"), "Vision Transformer "]
],
inputs=[image_input, text_input],
label="Sample input"
)
        with gr.Column(scale=4):
            gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text does not appear in the image</span>, the attention map will show a lot of noise (the actual response values are very low), since we normalize the attention map by its relative values. Because public data contain fewer Chinese images than English ones, <span style='color:red;'>we recommend the TokenFD_4096_English_seg version.</span></p>")
with gr.Row():
orig_img = gr.Image(label="Original picture", interactive=False)
heatmap = gr.Image(label="BPE visualization", interactive=False)
with gr.Row() as controls:
                prev_btn = gr.Button("⬅ Previous", visible=False)
next_btn = gr.Button("⮕ Next", visible=False)
bpe_display = gr.Markdown("Current BPE: ")
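            # Per-session state: all heatmaps, decoded BPE strings, and the current index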
current_vis_state = gr.State([])
current_bpe_state = gr.State([])
current_index_state = gr.State(0)
    # Event handling
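    # @spaces.GPU requests a GPU for the duration of the call on ZeroGPU Spaces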
@spaces.GPU
def on_run_clicked(model_type, image, text):
        current_index = 0  # reset the index when a new image is processed
image, current_vis, current_bpe = process_image(*load_model(model_type), model_type, image, text)
return (
image,
current_vis[current_index],
format_bpe_display(current_bpe[current_index]),
gr.update(visible=True),
gr.update(visible=True),
            current_vis,   # store the full list of heatmaps
            current_bpe,   # store the full list of BPE strings
            current_index  # store the current index
)
run_btn.click(
on_run_clicked,
inputs=[model_type, image_input, text_input],
outputs=[orig_img, heatmap, bpe_display, prev_btn, next_btn, current_vis_state, current_bpe_state, current_index_state]
)
prev_btn.click(
update_index,
inputs=[gr.State(-1), current_vis_state, current_bpe_state, current_index_state],
outputs=[heatmap, bpe_display, current_index_state]
)
next_btn.click(
update_index,
inputs=[gr.State(1), current_vis_state, current_bpe_state, current_index_state],
outputs=[heatmap, bpe_display, current_index_state]
)
if __name__ == "__main__":
demo.launch()