# Hugging Face Space (status when this page was captured: Sleeping)
from functools import lru_cache, partial

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from dinov2 import DinoVisionTransformer
# Define the model (using vit_small as an example)
def create_model():
    """Build a ViT-Small sized ``DinoVisionTransformer`` and put it in eval mode.

    No checkpoint is loaded here, so the returned model carries whatever
    initial weights the constructor assigns; that is sufficient for
    inspecting tensor shapes.

    Returns:
        The constructed model after ``model.eval()`` has been called.
    """
    # ``Block`` and ``MemEffAttention`` were referenced without ever being
    # imported, which raised NameError on the first call.
    # NOTE(review): assumes the local ``dinov2`` module exposes both names
    # alongside ``DinoVisionTransformer`` - confirm against that module.
    from dinov2 import Block, MemEffAttention

    model = DinoVisionTransformer(
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=4,  # 4 register tokens, for testing
    )
    model.eval()
    return model
# Image preprocessing
def preprocess_image(image):
    """Turn a PIL image into a normalized batch tensor for the model.

    The image is converted to RGB, resized to 224x224, converted to a
    float tensor, normalized per channel and given a leading batch
    dimension of size 1.

    Args:
        image: A ``PIL.Image.Image`` in any mode that ``convert("RGB")``
            accepts.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``.
    """
    # The normalization statistics (and the model's in_chans=3) are
    # 3-channel, so RGBA, grayscale or palette images must be converted
    # first; otherwise Normalize fails on the channel mismatch.
    rgb_image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the batch dimension.
    return transform(rgb_image).unsqueeze(0)
# Variant of forward_features that also reports the intermediate shapes
def forward_with_shapes(model, x, masks=None):
    """Run the model's feature pipeline step by step, logging tensor shapes.

    Args:
        model: Object exposing ``patch_embed``, ``cls_token``,
            ``mask_token``, ``interpolate_pos_encoding``,
            ``num_register_tokens``, ``register_tokens``, ``blocks`` and
            ``norm``.
        x: 4-D input batch. Its last two dimensions are forwarded to
            ``model.interpolate_pos_encoding``.
        masks: Optional boolean tensor over patch positions. Where it is
            True the patch embedding is replaced by ``model.mask_token``.

    Returns:
        A ``(output, shapes)`` tuple. ``output`` is a dict with the keys
        ``x_norm_clstoken``, ``x_norm_regtokens``, ``x_norm_patchtokens``,
        ``x_prenorm`` and ``masks``. ``shapes`` is a list of strings, one
        per pipeline step, describing the tensor shape at that step.

    Raises:
        ValueError: If ``x`` does not have exactly four dimensions.
    """
    shapes = []

    # The positional encoding must be interpolated for the *input* size,
    # so read it before patch_embed replaces x with the token sequence.
    # (Previously this was hard-coded to 224x224.)
    _, _, w, h = x.shape

    # 1. Patch embedding
    x = model.patch_embed(x)
    shapes.append(f"After patch_embed: {x.shape}")

    # 2. Optionally substitute the mask token, then prepend the cls token
    if masks is not None:
        x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
    x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    shapes.append(f"After adding cls_token: {x.shape}")

    # 3. Positional encoding
    x = x + model.interpolate_pos_encoding(x, w, h)
    shapes.append(f"After adding pos_embed: {x.shape}")

    # 4. Register tokens are inserted between the cls token and the patches
    if model.num_register_tokens > 0:
        x = torch.cat(
            (
                x[:, :1],  # cls token
                model.register_tokens.expand(x.shape[0], -1, -1),  # register tokens
                x[:, 1:],  # patch tokens
            ),
            dim=1,
        )
        shapes.append(f"After adding register_tokens: {x.shape}")

    # 5. Transformer blocks
    for i, blk in enumerate(model.blocks):
        x = blk(x)
        shapes.append(f"After block {i+1}: {x.shape}")

    # 6. Final normalization
    x_norm = model.norm(x)
    shapes.append(f"After norm: {x_norm.shape}")

    # 7. Split the normalized sequence into cls / register / patch tokens
    output = {
        "x_norm_clstoken": x_norm[:, 0],
        "x_norm_regtokens": x_norm[:, 1 : model.num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:, model.num_register_tokens + 1 :],
        "x_prenorm": x,
        "masks": masks,
    }
    shapes.append(f"x_norm_clstoken: {output['x_norm_clstoken'].shape}")
    shapes.append(f"x_norm_regtokens: {output['x_norm_regtokens'].shape}")
    shapes.append(f"x_norm_patchtokens: {output['x_norm_patchtokens'].shape}")
    shapes.append(f"x_prenorm: {output['x_prenorm'].shape}")
    return output, shapes
# Main processing function
@lru_cache(maxsize=1)
def _get_model():
    """Build the model on first use and reuse it for later requests."""
    return create_model()


def process_image(image):
    """Return a newline-separated report of feature-map shapes for ``image``.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was
            uploaded.

    Returns:
        One line per pipeline step as produced by ``forward_with_shapes``,
        or a short prompt asking for an image when ``image`` is ``None``.
    """
    if image is None:
        # Submitting without an upload would otherwise fail inside the
        # preprocessing transform.
        return "Please upload an image first."

    img_tensor = preprocess_image(image)

    # Only the shapes are reported, so gradients are not needed; the model
    # is cached because rebuilding it per request changes nothing in the
    # reported shapes.
    with torch.no_grad():
        _, shapes = forward_with_shapes(_get_model(), img_tensor)

    return "\n".join(shapes)
# Gradio interface
# Input and output widgets are built first and then wired to process_image.
_image_input = gr.Image(type="pil", label="Upload an Image")
_shapes_output = gr.Textbox(label="Feature Map Shapes")

demo = gr.Interface(
    fn=process_image,
    inputs=_image_input,
    outputs=_shapes_output,
    title="DinoVisionTransformer Feature Map Shapes",
    description="Upload an image to see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens).",
)

# Start the web app.
demo.launch()