# app.py — Gradio Space by zzzzzeee (commit c441492, 3.57 kB)
# Demo: visualize DinoVisionTransformer feature-map shapes at each forward step.
from functools import partial

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from dinov2 import DinoVisionTransformer
# NOTE(review): Block / MemEffAttention are required by create_model() but were
# never imported. In the upstream DINOv2 repo they live in dinov2.layers
# (NestedTensorBlock as Block) — adjust the path if this local package differs.
from dinov2 import Block, MemEffAttention
# Build the model (vit_small configuration used as the example).
def create_model():
    """Instantiate a DinoVisionTransformer (ViT-S/16) with 4 register tokens.

    Returns:
        The model switched to eval mode (disables dropout/train-only behavior).
    """
    model = DinoVisionTransformer(
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=384,  # vit_small width
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        # Block and MemEffAttention must be imported at module level;
        # the original file referenced them without importing them.
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=4,  # 4 register tokens for this demo
    )
    model.eval()
    return model
# Image preprocessing.
def preprocess_image(image):
    """Convert a PIL image into a normalized (1, 3, 224, 224) float tensor.

    Uses the standard ImageNet mean/std normalization expected by DINOv2.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    # unsqueeze(0) adds the batch dimension the model expects.
    return pipeline(image).unsqueeze(0)
# Re-implementation of forward_features that records intermediate shapes.
def forward_with_shapes(model, x, masks=None):
    """Run a DinoVisionTransformer forward pass, logging tensor shapes per stage.

    Args:
        model: DinoVisionTransformer-like module exposing patch_embed,
            cls_token, mask_token, interpolate_pos_encoding,
            num_register_tokens, register_tokens, blocks and norm.
        x: input image batch of shape (B, C, H, W).
        masks: optional boolean mask over patch tokens; masked positions are
            replaced with model.mask_token.

    Returns:
        Tuple of (output dict with cls/register/patch token splits,
        list of human-readable shape strings).
    """
    shapes = []
    # Derive the spatial size from the actual input instead of hard-coding
    # 224x224, so the trace is correct for any input resolution.
    _, _, w, h = x.shape
    # 1. Patch embedding: (B, C, H, W) -> (B, N_patches, D)
    x = model.patch_embed(x)
    shapes.append(f"After patch_embed: {x.shape}")
    # 2. Optionally replace masked patch tokens with the learned mask token.
    if masks is not None:
        x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
    # 3. Prepend the class token.
    x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    shapes.append(f"After adding cls_token: {x.shape}")
    # 4. Positional encoding, interpolated to the input resolution.
    x = x + model.interpolate_pos_encoding(x, w, h)
    shapes.append(f"After adding pos_embed: {x.shape}")
    # 5. Register tokens are inserted between the cls token and patch tokens.
    if model.num_register_tokens > 0:
        x = torch.cat(
            (
                x[:, :1],  # cls_token
                model.register_tokens.expand(x.shape[0], -1, -1),  # register tokens
                x[:, 1:],  # patch tokens
            ),
            dim=1,
        )
        shapes.append(f"After adding register_tokens: {x.shape}")
    # 6. Transformer blocks.
    for i, blk in enumerate(model.blocks):
        x = blk(x)
        shapes.append(f"After block {i+1}: {x.shape}")
    # 7. Final layer norm.
    x_norm = model.norm(x)
    shapes.append(f"After norm: {x_norm.shape}")
    # 8. Split normalized tokens into cls / register / patch groups.
    output = {
        "x_norm_clstoken": x_norm[:, 0],
        "x_norm_regtokens": x_norm[:, 1 : model.num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:, model.num_register_tokens + 1 :],
        "x_prenorm": x,
        "masks": masks,
    }
    shapes.append(f"x_norm_clstoken: {output['x_norm_clstoken'].shape}")
    shapes.append(f"x_norm_regtokens: {output['x_norm_regtokens'].shape}")
    shapes.append(f"x_norm_patchtokens: {output['x_norm_patchtokens'].shape}")
    shapes.append(f"x_prenorm: {output['x_prenorm'].shape}")
    return output, shapes
# Main handler wired into the Gradio interface.
def process_image(image):
    """Return a newline-joined trace of feature-map shapes for *image*.

    The model is built once and cached on the function object: the original
    code rebuilt the full ViT on every upload, which is needlessly expensive.
    """
    if not hasattr(process_image, "_model"):
        process_image._model = create_model()
    model = process_image._model
    img_tensor = preprocess_image(image)
    # Inference only — disable gradient tracking to save memory.
    with torch.no_grad():
        output, shapes = forward_with_shapes(model, img_tensor)
    # One shape description per line for the textbox output.
    return "\n".join(shapes)
# Gradio UI wiring.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Feature Map Shapes"),
    title="DinoVisionTransformer Feature Map Shapes",
    description="Upload an image to see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens).",
)

# Launch only when executed as a script, so the module can be imported
# (e.g. by tests or a Space runner that picks up `demo`) without starting
# a server as a side effect.
if __name__ == "__main__":
    demo.launch()