Spaces:
Running
Running
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from functools import partial | |
from dinov2 import DinoVisionTransformer | |
from dinov2layers.block import Block | |
from dinov2layers.attention import MemEffAttention | |
# Define the model (using vit_small as an example)
def create_model(height, width, in_chans=3):
    """Build a ViT-S/16 ``DinoVisionTransformer`` in eval mode.

    No pretrained weights are loaded here; the model is only used to
    inspect tensor shapes.

    Args:
        height: Input image height passed to the model as ``img_size[0]``.
        width: Input image width passed to the model as ``img_size[1]``.
        in_chans: Number of input image channels. Defaults to 3, which is
            the value that used to be hard-coded.

    Returns:
        The constructed model, already switched to ``eval()`` mode.
    """
    # NOTE(review): upstream DINOv2 position-embedding interpolation assumes a
    # square pretraining grid; confirm the local ``dinov2`` module handles a
    # non-square ``img_size`` before relying on height != width.
    model = DinoVisionTransformer(
        img_size=(height, width),  # follow the requested input size
        patch_size=16,
        in_chans=in_chans,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=4,  # 4 register tokens, for testing
    )
    model.eval()
    return model
# Generate a random input tensor
def generate_input(batch_size, channels, height, width, patch_size=16):
    """Return a random image batch whose spatial size is a multiple of the patch size.

    Height and width are rounded *down* to the nearest multiple of
    ``patch_size``.

    Args:
        batch_size: Number of images in the batch.
        channels: Number of channels per image.
        height: Requested image height in pixels.
        width: Requested image width in pixels.
        patch_size: Patch edge length the spatial dims must be divisible by.
            Defaults to 16.

    Returns:
        A tensor of shape ``(batch_size, channels, H, W)`` drawn from
        ``torch.randn``.

    Raises:
        ValueError: If ``patch_size`` is not positive, or if ``height`` or
            ``width`` is smaller than one patch (which would otherwise yield
            a zero-sized tensor).
    """
    # UI widgets may deliver whole numbers as floats (e.g. 224.0);
    # torch.randn only accepts integer sizes.
    batch_size, channels = int(batch_size), int(channels)
    height, width, patch_size = int(height), int(width), int(patch_size)

    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if height < patch_size or width < patch_size:
        raise ValueError(
            f"height and width must be at least patch_size={patch_size}, "
            f"got height={height}, width={width}"
        )

    # Round down to a whole number of patches.
    height -= height % patch_size
    width -= width % patch_size
    return torch.randn(batch_size, channels, height, width)
# Variant of forward_features that also reports intermediate shapes
def forward_with_shapes(model, x, masks=None):
    """Run the model's feature-extraction steps one by one, recording shapes.

    Args:
        model: Object exposing ``patch_embed``, ``cls_token``, ``mask_token``,
            ``interpolate_pos_encoding``, ``num_register_tokens``,
            ``register_tokens``, ``blocks`` and ``norm``.
        x: Input image batch; its last two dimensions are taken as the
            spatial size.
        masks: Optional boolean tensor; where it is True the patch embedding
            is replaced by ``model.mask_token``. Presumably shaped
            ``(B, num_patches)`` -- confirm against callers.

    Returns:
        A tuple ``(output, shapes)`` where ``output`` is a dict with keys
        ``x_norm_clstoken``, ``x_norm_regtokens``, ``x_norm_patchtokens``,
        ``x_prenorm`` and ``masks``, and ``shapes`` is a list of
        human-readable strings describing the tensor shape after each step.
    """
    shapes = []

    # The spatial size must be read from the *image* tensor, before patch
    # embedding flattens it into tokens. This mirrors upstream DINOv2's
    # prepare_tokens_with_masks, which unpacks (B, nc, w, h) from the input
    # first. Reading it afterwards would pass (num_tokens, embed_dim) to
    # interpolate_pos_encoding.
    w, h = x.shape[-2], x.shape[-1]

    # 1. Patch embedding
    x = model.patch_embed(x)
    shapes.append(f"After patch_embed: {x.shape}")

    # 2. Optional masking, then prepend the class token
    if masks is not None:
        x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
    x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    shapes.append(f"After adding cls_token: {x.shape}")

    # 3. Position encoding
    x = x + model.interpolate_pos_encoding(x, w, h)
    shapes.append(f"After adding pos_embed: {x.shape}")

    # 4. Register tokens are inserted between the class token and the patches
    if model.num_register_tokens > 0:
        x = torch.cat(
            (
                x[:, :1],  # cls_token
                model.register_tokens.expand(x.shape[0], -1, -1),  # register tokens
                x[:, 1:],  # patch tokens
            ),
            dim=1,
        )
    shapes.append(f"After adding register_tokens: {x.shape}")

    # 5. Transformer blocks
    for i, blk in enumerate(model.blocks):
        x = blk(x)
        shapes.append(f"After block {i+1}: {x.shape}")

    # 6. Final normalization
    x_norm = model.norm(x)
    shapes.append(f"After norm: {x_norm.shape}")

    # 7. Split the normalized sequence into class / register / patch tokens
    output = {
        "x_norm_clstoken": x_norm[:, 0],
        "x_norm_regtokens": x_norm[:, 1 : model.num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:, model.num_register_tokens + 1 :],
        "x_prenorm": x,
        "masks": masks,
    }
    shapes.append(f"x_norm_clstoken: {output['x_norm_clstoken'].shape}")
    shapes.append(f"x_norm_regtokens: {output['x_norm_regtokens'].shape}")
    shapes.append(f"x_norm_patchtokens: {output['x_norm_patchtokens'].shape}")
    shapes.append(f"x_prenorm: {output['x_prenorm'].shape}")
    return output, shapes
# Main processing function
def process_tensor(batch_size, channels, height, width):
    """Build a random input and a matching model, and report all intermediate shapes.

    Args:
        batch_size: Number of images in the random batch.
        channels: Number of channels per image.
        height: Requested image height; adjusted by ``generate_input``.
        width: Requested image width; adjusted by ``generate_input``.

    Returns:
        A newline-separated string: the input shape followed by the shape
        recorded after each step of ``forward_with_shapes``.

    Raises:
        gr.Error: If ``channels`` differs from the channel count the model
            is built with.
    """
    # create_model() is called without a channel override, so the patch
    # embedding is built for 3 input channels. Fail early with a readable
    # message instead of an opaque error from inside the forward pass.
    model_channels = 3
    if int(channels) != model_channels:
        raise gr.Error(
            f"The model is built for {model_channels} input channels, "
            f"but Channels={channels} was requested."
        )

    # Random input, with the spatial size adjusted by generate_input
    img_tensor = generate_input(batch_size, channels, height, width)

    # Build the model for the adjusted spatial size
    model = create_model(height=img_tensor.shape[2], width=img_tensor.shape[3])

    # Only shapes are reported, so skip autograd bookkeeping.
    with torch.no_grad():
        _, shapes = forward_with_shapes(model, img_tensor)

    return f"Input shape: {img_tensor.shape}\n" + "\n".join(shapes)
# Gradio interface
# Four sliders (batch, channels, height, width) feed process_tensor; its
# textual shape report is shown in a single textbox.
demo = gr.Interface(
    fn=process_tensor,
    inputs=[
        gr.Slider(minimum=1, maximum=8, step=1, value=1, label="Batch Size"),
        gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Channels"),
        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Height"),
        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Width"),
    ],
    outputs=gr.Textbox(label="Feature Map Shapes"),
    title="DinoVisionTransformer Feature Map Shapes",
    description="Adjust the sliders to set the input tensor dimensions (B, C, H, W) and see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens). Height and Width will be adjusted to be multiples of 16.",
)

# Start the server only when run as a script, so importing this module
# (e.g. to reuse the functions above) does not launch the web app.
if __name__ == "__main__":
    demo.launch()