zzzzzeee commited on
Commit
c441492
·
verified ·
1 Parent(s): e41a656

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ from functools import partial
7
+ from dinov2 import DinoVisionTransformer
8
+
9
+ # 定义模型(使用 vit_small 作为示例)
10
+ def create_model():
11
+ model = DinoVisionTransformer(
12
+ img_size=224,
13
+ patch_size=16,
14
+ in_chans=3,
15
+ embed_dim=384,
16
+ depth=12,
17
+ num_heads=6,
18
+ mlp_ratio=4,
19
+ block_fn=partial(Block, attn_class=MemEffAttention),
20
+ num_register_tokens=4, # 设置 4 个 register tokens 用于测试
21
+ )
22
+ model.eval()
23
+ return model
24
+
25
+ # 图像预处理
26
+ def preprocess_image(image):
27
+ transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
+ ])
32
+ img_tensor = transform(image).unsqueeze(0) # 添加 batch 维度
33
+ return img_tensor
34
+
35
+ # 修改 forward_features 以输出中间形状
36
+ def forward_with_shapes(model, x, masks=None):
37
+ shapes = []
38
+
39
+ # 1. Patch Embedding
40
+ x = model.patch_embed(x)
41
+ shapes.append(f"After patch_embed: {x.shape}")
42
+
43
+ # 2. Prepare tokens with masks
44
+ B, nc, w, h = x.shape[0], 3, 224, 224 # 原始图像尺寸
45
+ if masks is not None:
46
+ x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
47
+ x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
48
+ shapes.append(f"After adding cls_token: {x.shape}")
49
+
50
+ # 3. Position Encoding
51
+ x = x + model.interpolate_pos_encoding(x, w, h)
52
+ shapes.append(f"After adding pos_embed: {x.shape}")
53
+
54
+ # 4. Register Tokens
55
+ if model.num_register_tokens > 0:
56
+ x = torch.cat(
57
+ (
58
+ x[:, :1], # cls_token
59
+ model.register_tokens.expand(x.shape[0], -1, -1), # register tokens
60
+ x[:, 1:], # patch tokens
61
+ ),
62
+ dim=1,
63
+ )
64
+ shapes.append(f"After adding register_tokens: {x.shape}")
65
+
66
+ # 5. Transformer Blocks
67
+ for i, blk in enumerate(model.blocks):
68
+ x = blk(x)
69
+ shapes.append(f"After block {i+1}: {x.shape}")
70
+
71
+ # 6. Normalization
72
+ x_norm = model.norm(x)
73
+ shapes.append(f"After norm: {x_norm.shape}")
74
+
75
+ # 7. Extract outputs
76
+ output = {
77
+ "x_norm_clstoken": x_norm[:, 0],
78
+ "x_norm_regtokens": x_norm[:, 1 : model.num_register_tokens + 1],
79
+ "x_norm_patchtokens": x_norm[:, model.num_register_tokens + 1 :],
80
+ "x_prenorm": x,
81
+ "masks": masks,
82
+ }
83
+ shapes.append(f"x_norm_clstoken: {output['x_norm_clstoken'].shape}")
84
+ shapes.append(f"x_norm_regtokens: {output['x_norm_regtokens'].shape}")
85
+ shapes.append(f"x_norm_patchtokens: {output['x_norm_patchtokens'].shape}")
86
+ shapes.append(f"x_prenorm: {output['x_prenorm'].shape}")
87
+
88
+ return output, shapes
89
+
90
+ # 主处理函数
91
+ def process_image(image):
92
+ model = create_model()
93
+ img_tensor = preprocess_image(image)
94
+
95
+ # 前向传播并获取形状
96
+ output, shapes = forward_with_shapes(model, img_tensor)
97
+
98
+ # 将形状列表转换为字符串
99
+ shapes_text = "\n".join(shapes)
100
+ return shapes_text
101
+
102
+ # Gradio 界面
103
+ demo = gr.Interface(
104
+ fn=process_image,
105
+ inputs=gr.Image(type="pil", label="Upload an Image"),
106
+ outputs=gr.Textbox(label="Feature Map Shapes"),
107
+ title="DinoVisionTransformer Feature Map Shapes",
108
+ description="Upload an image to see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens).",
109
+ )
110
+
111
+ demo.launch()