import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from functools import partial
from dinov2 import DinoVisionTransformer
from dinov2layers.block import Block
from dinov2layers.attention import MemEffAttention

# Patch size shared by the model and the input generator. Input H and W are
# rounded to multiples of it so the patch grid tiles the image exactly.
PATCH_SIZE = 16


# Build the model (vit_small configuration as an example).
def create_model(height, width, in_chans=3):
    """Build a ViT-Small DinoVisionTransformer sized for one input resolution.

    Args:
        height: Input image height in pixels (expected multiple of PATCH_SIZE).
        width: Input image width in pixels (expected multiple of PATCH_SIZE).
        in_chans: Number of input channels. Defaults to 3, the value that was
            previously hard-coded.

    Returns:
        The model, switched to eval mode.
    """
    model = DinoVisionTransformer(
        img_size=(height, width),  # size the model to this particular input
        patch_size=PATCH_SIZE,
        in_chans=in_chans,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=4,  # 4 register tokens, for testing
    )
    model.eval()
    return model


# Generate a random input tensor.
def generate_input(batch_size, channels, height, width, patch_size=PATCH_SIZE):
    """Return a random (B, C, H, W) tensor whose H and W are multiples of patch_size.

    H and W are rounded down to the nearest multiple of ``patch_size`` but never
    below one patch, so the result is never a zero-size tensor. Values are cast
    to ``int`` because UI widgets may deliver numbers as floats.
    """
    height = max(patch_size, (int(height) // patch_size) * patch_size)
    width = max(patch_size, (int(width) // patch_size) * patch_size)
    img_tensor = torch.randn(int(batch_size), int(channels), height, width)
    return img_tensor


# Step-by-step variant of forward_features that records intermediate shapes.
def forward_with_shapes(model, x, masks=None):
    """Run the model's token pipeline stage by stage, recording each output shape.

    Args:
        model: A DinoVisionTransformer. Uses its patch_embed, mask_token,
            cls_token, interpolate_pos_encoding, num_register_tokens,
            register_tokens, blocks and norm.
        x: Input image batch as a (B, C, H, W) tensor.
        masks: Optional boolean mask; where True, the corresponding patch token
            is replaced by ``model.mask_token``. Presumably (B, num_patches);
            TODO confirm.

    Returns:
        A tuple ``(output, shapes)``:
            output: dict with keys "x_norm_clstoken", "x_norm_regtokens",
                "x_norm_patchtokens", "x_prenorm" and "masks".
            shapes: list of human-readable "stage: shape" strings.
    """
    shapes = []

    # Pixel-space size of the input, captured BEFORE patch embedding. After
    # patch_embed, x is a token sequence (the cls_token concat on dim=1 below
    # needs a 3-D tensor). Its last two dims are then (tokens, embed_dim),
    # not the image size that positional-embedding interpolation needs.
    _, _, img_h, img_w = x.shape

    # 1. Patch embedding
    x = model.patch_embed(x)
    shapes.append(f"After patch_embed: {x.shape}")

    # 2. Optional masking of patch tokens, then prepend the cls token
    if masks is not None:
        x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
    x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    shapes.append(f"After adding cls_token: {x.shape}")

    # 3. Positional encoding, interpolated to this input's pixel size. The
    # dims are passed in NCHW order (dim -2, then dim -1), as upstream DINOv2's
    # prepare_tokens_with_masks does. TODO confirm against the local dinov2 module.
    x = x + model.interpolate_pos_encoding(x, img_h, img_w)
    shapes.append(f"After adding pos_embed: {x.shape}")

    # 4. Register tokens go between the cls token and the patch tokens
    if model.num_register_tokens > 0:
        x = torch.cat(
            (
                x[:, :1],  # cls_token
                model.register_tokens.expand(x.shape[0], -1, -1),  # register tokens
                x[:, 1:],  # patch tokens
            ),
            dim=1,
        )
        shapes.append(f"After adding register_tokens: {x.shape}")

    # 5. Transformer blocks
    for i, blk in enumerate(model.blocks):
        x = blk(x)
        shapes.append(f"After block {i+1}: {x.shape}")

    # 6. Final normalization
    x_norm = model.norm(x)
    shapes.append(f"After norm: {x_norm.shape}")

    # 7. Split the token sequence: [cls | registers | patches]
    output = {
        "x_norm_clstoken": x_norm[:, 0],
        "x_norm_regtokens": x_norm[:, 1 : model.num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:, model.num_register_tokens + 1 :],
        "x_prenorm": x,
        "masks": masks,
    }
    shapes.append(f"x_norm_clstoken: {output['x_norm_clstoken'].shape}")
    shapes.append(f"x_norm_regtokens: {output['x_norm_regtokens'].shape}")
    shapes.append(f"x_norm_patchtokens: {output['x_norm_patchtokens'].shape}")
    shapes.append(f"x_prenorm: {output['x_prenorm'].shape}")

    return output, shapes


# Main Gradio callback.
def process_tensor(batch_size, channels, height, width):
    """Build a random input, run the model, and return a text report of shapes.

    Returns:
        A newline-separated string: the input shape followed by each stage's shape.
    """
    # Random input with H/W rounded to multiples of the patch size
    img_tensor = generate_input(batch_size, channels, height, width)
    _, in_chans, img_h, img_w = img_tensor.shape

    # Size the model to the rounded input and match its channel count, so
    # Channels slider values other than 3 do not break the patch embedding.
    model = create_model(height=img_h, width=img_w, in_chans=in_chans)

    # Shape inspection only: no autograd graph is needed.
    with torch.no_grad():
        _, shapes = forward_with_shapes(model, img_tensor)

    shapes_text = f"Input shape: {img_tensor.shape}\n" + "\n".join(shapes)
    return shapes_text


# Gradio interface
demo = gr.Interface(
    fn=process_tensor,
    inputs=[
        gr.Slider(minimum=1, maximum=8, step=1, value=1, label="Batch Size"),
        gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Channels"),
        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Height"),
        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Width"),
    ],
    outputs=gr.Textbox(label="Feature Map Shapes"),
    title="DinoVisionTransformer Feature Map Shapes",
    description="Adjust the sliders to set the input tensor dimensions (B, C, H, W) and see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens). Height and Width will be adjusted to be multiples of 16.",
)

if __name__ == "__main__":
    demo.launch()