import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from functools import partial
from dinov2 import DinoVisionTransformer
from dinov2layers.block import Block
from dinov2layers.attention import MemEffAttention
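# Note: MemEffAttention is the DINOv2 attention variant that uses xFormers'
# memory-efficient attention kernel when xFormers is available; the exact
# fallback behavior depends on the local dinov2layers copy bundled here.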
# Define the model (the vit_small configuration is used as an example)
def create_model(height, width, in_chans=3):
    model = DinoVisionTransformer(
        img_size=(height, width),  # adjust the input size dynamically
        patch_size=16,
        in_chans=in_chans,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=4,  # use 4 register tokens for this test
    )
    model.eval()
    return model
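# For reference: embed_dim=384, depth=12 and num_heads=6 correspond to a
# ViT-Small backbone; with patch_size=16, a 224x224 input is split into a
# 14x14 grid of patches, i.e. 196 patch tokens of dimension 384.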
# Generate a random input tensor
def generate_input(batch_size, channels, height, width):
    # Make sure the input size is an integer multiple of patch_size
    patch_size = 16
    height = (height // patch_size) * patch_size
    width = (width // patch_size) * patch_size
    img_tensor = torch.randn(batch_size, channels, height, width)
    return img_tensor
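# Example: a requested size of 230x230 is rounded down to 224x224
# (230 // 16 = 14, 14 * 16 = 224), so
#   generate_input(1, 3, 230, 230).shape == torch.Size([1, 3, 224, 224])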
# A modified forward_features that records the intermediate tensor shapes
def forward_with_shapes(model, x, masks=None):
    shapes = []
    # Record the input image dimensions before patch embedding;
    # interpolate_pos_encoding needs the original spatial size, not the token grid.
    B, nc, w, h = x.shape
    # 1. Patch Embedding
    x = model.patch_embed(x)
    shapes.append(f"After patch_embed: {x.shape}")
    # 2. Prepare tokens with masks
    if masks is not None:
        x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
    x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    shapes.append(f"After adding cls_token: {x.shape}")
    # 3. Position Encoding
    x = x + model.interpolate_pos_encoding(x, w, h)
    shapes.append(f"After adding pos_embed: {x.shape}")
    # 4. Register Tokens
    if model.num_register_tokens > 0:
        x = torch.cat(
            (
                x[:, :1],  # cls token
                model.register_tokens.expand(x.shape[0], -1, -1),  # register tokens
                x[:, 1:],  # patch tokens
            ),
            dim=1,
        )
        shapes.append(f"After adding register_tokens: {x.shape}")
    # 5. Transformer Blocks
    for i, blk in enumerate(model.blocks):
        x = blk(x)
        shapes.append(f"After block {i+1}: {x.shape}")
    # 6. Normalization
    x_norm = model.norm(x)
    shapes.append(f"After norm: {x_norm.shape}")
    # 7. Extract outputs
    output = {
        "x_norm_clstoken": x_norm[:, 0],
        "x_norm_regtokens": x_norm[:, 1 : model.num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:, model.num_register_tokens + 1 :],
        "x_prenorm": x,
        "masks": masks,
    }
    shapes.append(f"x_norm_clstoken: {output['x_norm_clstoken'].shape}")
    shapes.append(f"x_norm_regtokens: {output['x_norm_regtokens'].shape}")
    shapes.append(f"x_norm_patchtokens: {output['x_norm_patchtokens'].shape}")
    shapes.append(f"x_prenorm: {output['x_prenorm'].shape}")
    return output, shapes
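# Worked example for a 1x3x224x224 input: patch_embed yields 14*14 = 196 patch
# tokens, and adding the cls token plus 4 register tokens gives a sequence of
# 201 tokens of dimension 384, so the reported output shapes are:
#   x_norm_clstoken:    (1, 384)
#   x_norm_regtokens:   (1, 4, 384)
#   x_norm_patchtokens: (1, 196, 384)
#   x_prenorm:          (1, 201, 384)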
# Main processing function for the Gradio interface
def process_tensor(batch_size, channels, height, width):
    # Generate a random input tensor
    img_tensor = generate_input(batch_size, channels, height, width)
    # Create the model with img_size (and in_chans) matching the generated input
    model = create_model(
        height=img_tensor.shape[2],
        width=img_tensor.shape[3],
        in_chans=img_tensor.shape[1],
    )
    # Forward pass (no gradients needed) and collect the intermediate shapes
    with torch.no_grad():
        output, shapes = forward_with_shapes(model, img_tensor)
    # Join the recorded shapes into a single text block
    shapes_text = f"Input shape: {img_tensor.shape}\n" + "\n".join(shapes)
    return shapes_text
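# Quick local sanity check (kept commented out so it does not run on Space startup):
# print(process_tensor(1, 3, 224, 224))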
# Gradio interface
demo = gr.Interface(
    fn=process_tensor,
    inputs=[
        gr.Slider(minimum=1, maximum=8, step=1, value=1, label="Batch Size"),
        gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Channels"),
        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Height"),
        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Width"),
    ],
    outputs=gr.Textbox(label="Feature Map Shapes"),
    title="DinoVisionTransformer Feature Map Shapes",
    description="Adjust the sliders to set the input tensor dimensions (B, C, H, W) and see the shapes of the feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens). Height and Width are rounded down to multiples of 16.",
)
demo.launch()
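# On Hugging Face Spaces the default launch() arguments are sufficient; when
# running locally, demo.launch(share=True) can create a temporary public URL.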