TongkunGuan's picture
Update app.py
b6f1806 verified
raw
history blame
16.4 kB
# import os
# import argparse
# import numpy as np
# from PIL import Image
# import torch
# import torchvision.transforms as T
# from transformers import AutoTokenizer
# import gradio as gr
# from resnet50 import build_model
# from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
# from utils import IMAGENET_MEAN, IMAGENET_STD
# from internvl.train.dataset import dynamic_preprocess
# from internvl.model.internvl_chat import InternVLChatModel
# import spaces
# # 模型配置
# CHECKPOINTS = {
# "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
# "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
# }
# # 全局变量
# HF_TOKEN = os.getenv("HF_TOKEN")
# current_vis = []
# current_bpe = []
# current_index = 0
# def load_model(check_type):
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda")
# if check_type == 'R50':
# tokenizer = load_tokenizer('tokenizer_path')
# model = build_model(argparse.Namespace()).eval()
# model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
# transform = build_transform_R50(normalize_type='imagenet')
# elif check_type == 'R50_siglip':
# tokenizer = load_tokenizer('tokenizer_path')
# model = build_model(argparse.Namespace()).eval()
# model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
# transform = build_transform_R50(normalize_type='imagenet')
# elif 'TokenFD' in check_type:
# model_path = CHECKPOINTS[check_type]
# tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
# model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
# transform = T.Compose([
# T.Lambda(lambda img: img.convert('RGB')),
# T.Resize((224, 224)),
# T.ToTensor(),
# T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
# ])
# return model.to(device), tokenizer, transform, device
# def process_image(model, tokenizer, transform, device, check_type, image, text):
# global current_vis, current_bpe, current_index
# src_size = image.size
# if 'TokenOCR' in check_type:
# images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
# image_size=model.config.force_image_size,
# use_thumbnail=model.config.use_thumbnail,
# return_ratio=True)
# pixel_values = torch.stack([transform(img) for img in images]).to(device)
# else:
# pixel_values = torch.stack([transform(image)]).to(device)
# target_ratio = (1, 1)
# # 文本处理
# text += ' '
# input_ids = tokenizer(text)['input_ids'][1:]
# input_ids = torch.tensor(input_ids, device=device)
# # 获取嵌入
# with torch.no_grad():
# if 'R50' in check_type:
# text_embeds = model.language_embedding(input_ids)
# else:
# text_embeds = model.tok_embeddings(input_ids)
# vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
# print("vit_embeds",vit_embeds)
# print("vit_embeds,shape",vit_embeds.shape)
# print("target_ratio",target_ratio)
# print("check_type",check_type)
# vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
# # 计算相似度
# text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
# vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
# similarity = text_embeds @ vit_embeds.T
# resized_size = size1 if size1 is not None else size2
# # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192
# # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
# # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912
# # 生成可视化
# attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
# # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
# all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
# current_vis = generate_similiarity_map([image], attn_map,
# [tokenizer.decode([i]) for i in input_ids],
# [], target_ratio, src_size)
# current_bpe = [tokenizer.decode([i]) for i in input_ids]
# # current_bpe[-1] = 'Input text'
# current_bpe[-1] = text
# print("current_vis",len(current_vis))
# print("current_bpe",len(current_bpe))
# return image, current_vis[0], current_bpe[0]
# # 事件处理函数
# def update_index(change):
# global current_vis, current_bpe, current_index
# current_index = max(0, min(len(current_vis) - 1, current_index + change))
# return current_vis[current_index], format_bpe_display(current_bpe[current_index])
# def format_bpe_display(bpe):
# # 使用HTML标签来设置字体大小、颜色,加粗,并居中
# return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
# def update_slider_index(x):
# print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
# if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
# return current_vis[x], format_bpe_display(current_bpe[x])
# else:
# return None, "索引超出范围"
# # Gradio界面
# with gr.Blocks(title="BPE Visualization Demo") as demo:
# gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
# with gr.Row():
# with gr.Column(scale=0.5):
# model_type = gr.Dropdown(
# choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
# label="Select model type",
# value="TokenOCR_4096_English_seg" # 设置默认值为第一个选项
# )
# image_input = gr.Image(label="Upload images", type="pil")
# text_input = gr.Textbox(label="Input text")
# run_btn = gr.Button("RUN")
# gr.Examples(
# examples=[
# [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
# [os.path.join("examples", "examples1.jpg"), "Refreshers"],
# [os.path.join("examples", "examples2.png"), "Vision Transformer"]
# ],
# inputs=[image_input, text_input],
# label="Sample input"
# )
# with gr.Column(scale=2):
# gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
# with gr.Row():
# orig_img = gr.Image(label="Original picture", interactive=False)
# heatmap = gr.Image(label="BPE visualization", interactive=False)
# with gr.Row() as controls:
# prev_btn = gr.Button("⬅ Last", visible=False)
# index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
# next_btn = gr.Button("⮕ Next", visible=False)
# bpe_display = gr.Markdown("Current BPE: ", visible=False)
# # 事件处理
# @spaces.GPU
# def on_run_clicked(model_type, image, text):
# global current_vis, current_bpe, current_index
# current_index = 0 # Reset index when new image is processed
# image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
# # Update the slider range and set value to 0
# slider_max_val = len(current_bpe) - 1
# bpe_text = format_bpe_display(bpe)
# print("current_vis",len(current_vis))
# print("current_bpe",len(current_bpe))
# return image, vis, bpe_text, slider_max_val
# run_btn.click(
# on_run_clicked,
# inputs=[model_type, image_input, text_input],
# outputs=[orig_img, heatmap, bpe_display, index_slider],
# ).then(
# lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
# inputs=index_slider,
# outputs=[prev_btn, index_slider, next_btn, bpe_display],
# )
# prev_btn.click(
# lambda: (*update_index(-1), current_index),
# outputs=[heatmap, bpe_display, index_slider]
# )
# next_btn.click(
# lambda: (*update_index(1), current_index),
# outputs=[heatmap, bpe_display, index_slider]
# )
# # index_slider.change(
# # lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0<=x<len(current_vis else (None,"Invaild")
# # inputs=index_slider,
# # outputs=[heatmap, bpe_display]
# # )
# index_slider.change(
# update_slider_index,
# inputs=index_slider,
# outputs=[heatmap, bpe_display]
# )
# if __name__ == "__main__":
# demo.launch()
import os
import argparse
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from transformers import AutoTokenizer
import gradio as gr
from resnet50 import build_model
from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
from utils import IMAGENET_MEAN, IMAGENET_STD
from internvl.train.dataset import dynamic_preprocess
from internvl.model.internvl_chat import InternVLChatModel
import spaces
# Model configuration: selectable model name -> path handed to
# `from_pretrained` / `torch.load` when that model is loaded.
CHECKPOINTS = dict(
    TokenFD_4096_English_seg="TongkunGuan/TokenFD_4096_English_seg",
    TokenFD_2048_Bilingual_seg="TongkunGuan/TokenFD_2048_Bilingual_seg",
)

# Global state.
# Access token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Outputs of the most recent process_image() call (similarity maps and the
# token strings they belong to). Shared by every request served by this process.
current_vis, current_bpe = [], []
def load_model(check_type):
    """Build the model, tokenizer and image transform for ``check_type``.

    Args:
        check_type: Model name. ``'R50'`` and ``'R50_siglip'`` select the
            ResNet-50 path; any name containing ``'TokenFD'`` selects the
            InternVL path. The name must also be a key of ``CHECKPOINTS``.

    Returns:
        Tuple ``(model, tokenizer, transform, device)``. The model is in eval
        mode and has already been moved to ``device`` (CUDA when available,
        otherwise CPU).

    Raises:
        ValueError: If ``check_type`` belongs to no known model family, or if
            no checkpoint is configured for it in ``CHECKPOINTS``.
    """
    is_r50 = check_type in ('R50', 'R50_siglip')
    if not is_r50 and 'TokenFD' not in check_type:
        # Without this guard an unknown name fell through every branch and
        # crashed with UnboundLocalError at the return statement.
        raise ValueError(
            f"Unsupported model type {check_type!r}: expected 'R50', "
            "'R50_siglip' or a name containing 'TokenFD'."
        )
    if check_type not in CHECKPOINTS:
        # Replaces the bare KeyError raised by the checkpoint lookup.
        raise ValueError(
            f"No checkpoint is configured in CHECKPOINTS for model type {check_type!r}."
        )
    checkpoint = CHECKPOINTS[check_type]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if is_r50:
        # Both ResNet-50 variants share the same construction; only the
        # checkpoint differs.
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        state = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(state['model'])
        transform = build_transform_R50(normalize_type='imagenet')
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            trust_remote_code=True,
            use_fast=False,
            use_auth_token=HF_TOKEN,
        )
        model = InternVLChatModel.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).eval()
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
    return model.to(device), tokenizer, transform, device
def process_image(model, tokenizer, transform, device, check_type, image, text):
    """Compute one similarity map per token of ``text`` over ``image``.

    Token embeddings and image embeddings are L2-normalised and multiplied, so
    every map holds cosine similarities between one token and every image
    position. The results are also stored in the module-level ``current_vis``
    and ``current_bpe``.

    Args:
        model: Model returned by ``load_model``.
        tokenizer: Tokenizer returned by ``load_model``.
        transform: Callable turning a PIL image into a tensor.
        device: Torch device the inputs are moved to.
        check_type: Model name; selects the preprocessing and the text
            embedding layer.
        image: Input image; must expose ``.size`` (the UI supplies a PIL image).
        text: Query text.

    Returns:
        Tuple ``(image, current_vis, current_bpe)``: the untouched input image,
        the output of ``generate_similiarity_map``, and the decoded token
        strings whose last entry is replaced by the full query text.
    """
    global current_vis, current_bpe
    src_size = image.size

    # Every configured checkpoint is named "TokenFD_*". The previous check for
    # 'TokenOCR' alone therefore never matched and the tiling path was dead
    # code; the old spelling is still accepted for compatibility.
    # NOTE(review): confirm post_process handles multi-tile input for TokenFD names.
    if 'TokenFD' in check_type or 'TokenOCR' in check_type:
        images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
                                                  image_size=model.config.force_image_size,
                                                  use_thumbnail=model.config.use_thumbnail,
                                                  return_ratio=True)
        pixel_values = torch.stack([transform(img) for img in images]).to(device)
    else:
        pixel_values = torch.stack([transform(image)]).to(device)
        target_ratio = (1, 1)

    # Text processing: a trailing space is appended and the first token id is
    # dropped (presumably the BOS token - TODO confirm).
    text += ' '
    input_ids = tokenizer(text)['input_ids'][1:]
    input_ids = torch.tensor(input_ids, device=device)

    # Embeddings.
    with torch.no_grad():
        if 'R50' in check_type:
            text_embeds = model.language_embedding(input_ids)
        else:
            text_embeds = model.tok_embeddings(input_ids)
        vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
        vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)

    # Cosine similarity between every token and every image position.
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
    similarity = text_embeds @ vit_embeds.T
    resized_size = size1 if size1 is not None else size2

    # Visualisation: one (height, width) map per token.
    attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
    # Decode once and reuse the strings for both the maps and the labels.
    bpe_strings = [tokenizer.decode([i]) for i in input_ids]
    current_vis = generate_similiarity_map([image], attn_map, bpe_strings,
                                           [], target_ratio, src_size)
    # Copy so the list handed to generate_similiarity_map is left untouched.
    current_bpe = list(bpe_strings)
    current_bpe[-1] = text
    return image, current_vis, current_bpe
def format_bpe_display(bpe):
    """Return the centred HTML snippet that labels the current BPE token.

    Args:
        bpe: Value to display; it is converted with ``str()``.

    Returns:
        An HTML ``<div>`` string with the value rendered in red.
    """
    import html

    # Tokens such as "<s>" and free-form user text would otherwise be parsed
    # as markup. quote=False leaves quotes alone, so text without <, > or &
    # produces exactly the same string as before.
    shown = html.escape(str(bpe), quote=False)
    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{shown}</span></strong></div>"
# Gradio UI
with gr.Blocks(title="BPE Visualization Demo") as demo:
    gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
    with gr.Row():
        # Gradio expects integer scales; 1 and 4 keep the original 0.5 : 2 ratio.
        with gr.Column(scale=1):
            model_type = gr.Dropdown(
                # NOTE(review): "R50" and "R50_siglip" have no entry in
                # CHECKPOINTS, so selecting them cannot load a model yet.
                choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
                label="Select model type",
                # The default must be one of `choices`; the previous
                # "TokenOCR_4096_English_seg" value was not in the list.
                value="TokenFD_4096_English_seg",
            )
            image_input = gr.Image(label="Upload images", type="pil")
            text_input = gr.Textbox(label="Input text")
            run_btn = gr.Button("RUN")
            gr.Examples(
                examples=[
                    [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
                    [os.path.join("examples", "examples1.jpg"), "Refreshers"],
                    [os.path.join("examples", "examples2.png"), "Vision Transformer"],
                ],
                inputs=[image_input, text_input],
                label="Sample input",
            )
        with gr.Column(scale=4):
            gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
            with gr.Row():
                orig_img = gr.Image(label="Original picture", interactive=False)
                heatmap = gr.Image(label="BPE visualization", interactive=False)
            # Hidden until the first run has produced a result.
            bpe_display = gr.Markdown("Current BPE: ", visible=False)

    # Event handling
    @spaces.GPU
    def on_run_clicked(model_type, image, text):
        """Load the selected model, run it, and return the values for the UI.

        Returns the original image, the first similarity map and the HTML
        label, in the order of the `outputs` list wired to the RUN button.
        """
        if image is None:
            # Without an upload, process_image would fail on `image.size`.
            raise gr.Error("Please upload an image before clicking RUN.")
        image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
        # NOTE(review): `bpe` is the whole token list while only vis[0] is
        # shown - confirm whether the label should be bpe[0] instead.
        bpe_text = format_bpe_display(bpe)
        return image, vis[0], bpe_text

    run_btn.click(
        on_run_clicked,
        inputs=[model_type, image_input, text_input],
        outputs=[orig_img, heatmap, bpe_display],
    ).then(
        # Reveal the label once the run has finished.
        lambda: gr.update(visible=True),
        outputs=[bpe_display],
    )

if __name__ == "__main__":
    demo.launch()