Update app.py
app.py CHANGED
@@ -6,10 +6,11 @@ import torchvision.transforms as transforms
 from functools import partial
 from dinov2 import DinoVisionTransformer
 
+
 # Define the model (using vit_small as an example)
-def create_model():
+def create_model(height, width):
     model = DinoVisionTransformer(
-        img_size=
+        img_size=(height, width),  # dynamically adjust the input size
         patch_size=16,
         in_chans=3,
         embed_dim=384,
@@ -22,14 +23,13 @@ def create_model():
     model.eval()
     return model
 
-#
-def
-
-
-
-
-
-    img_tensor = transform(image).unsqueeze(0)  # add a batch dimension
+# Generate a random input tensor
+def generate_input(batch_size, channels, height, width):
+    # Make sure the input size is a multiple of patch_size
+    patch_size = 16
+    height = (height // patch_size) * patch_size
+    width = (width // patch_size) * patch_size
+    img_tensor = torch.randn(batch_size, channels, height, width)
     return img_tensor
 
 # Modify forward_features to report intermediate shapes
@@ -41,7 +41,7 @@ def forward_with_shapes(model, x, masks=None):
     shapes.append(f"After patch_embed: {x.shape}")
 
     # 2. Prepare tokens with masks
-    B, nc, w, h = x.shape[0], 3,
+    B, nc, w, h = x.shape[0], 3, x.shape[-2], x.shape[-1]  # use the dynamic input size
     if masks is not None:
         x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
     x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
@@ -88,24 +88,32 @@ def forward_with_shapes(model, x, masks=None):
     return output, shapes
 
 # Main processing function
-def
-
-    img_tensor =
+def process_tensor(batch_size, channels, height, width):
+    # Generate a random input
+    img_tensor = generate_input(batch_size, channels, height, width)
+
+    # Create the model with a dynamically adjusted img_size
+    model = create_model(height=img_tensor.shape[2], width=img_tensor.shape[3])
 
     # Forward pass and collect the shapes
     output, shapes = forward_with_shapes(model, img_tensor)
 
     # Turn the list of shapes into a string
-    shapes_text = "\n".join(shapes)
+    shapes_text = f"Input shape: {img_tensor.shape}\n" + "\n".join(shapes)
     return shapes_text
 
 # Gradio interface
 demo = gr.Interface(
-    fn=
-    inputs=
+    fn=process_tensor,
+    inputs=[
+        gr.Slider(minimum=1, maximum=8, step=1, value=1, label="Batch Size"),
+        gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Channels"),
+        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Height"),
+        gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Width"),
+    ],
     outputs=gr.Textbox(label="Feature Map Shapes"),
     title="DinoVisionTransformer Feature Map Shapes",
-    description="
+    description="Adjust the sliders to set the input tensor dimensions (B, C, H, W) and see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens). Height and Width will be adjusted to be multiples of 16.",
 )
 
-demo.launch()
+demo.launch()
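
For reference, the shape arithmetic behind the reported feature maps can be checked by hand. The sketch below is not part of app.py; it assumes only the values visible in the diff (patch_size=16, embed_dim=384, one CLS token) plus the 4 register tokens mentioned in the description, and exactly where the register tokens join the sequence depends on the DinoVisionTransformer implementation.

# Hypothetical helper: rough shape arithmetic for the default slider values.
def expected_shapes(batch_size=1, height=224, width=224,
                    patch_size=16, embed_dim=384, num_register_tokens=4):
    # generate_input() snaps H and W down to multiples of patch_size.
    height = (height // patch_size) * patch_size
    width = (width // patch_size) * patch_size
    num_patches = (height // patch_size) * (width // patch_size)
    # One CLS token plus the register tokens join the patch tokens.
    seq_len = num_patches + 1 + num_register_tokens
    return {
        "input": (batch_size, 3, height, width),
        "after_patch_embed": (batch_size, num_patches, embed_dim),
        "with_cls_and_registers": (batch_size, seq_len, embed_dim),
    }

print(expected_shapes())
# {'input': (1, 3, 224, 224), 'after_patch_embed': (1, 196, 384),
#  'with_cls_and_registers': (1, 201, 384)}

With the default sliders (1, 3, 224, 224) this predicts 14 x 14 = 196 patch tokens and a 201-token sequence of width 384, which is roughly what the textbox output should show.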
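
The middle of forward_with_shapes (lines 48-87) is elided from the hunks above. Assuming the usual DINOv2 layout of model.blocks followed by model.norm, the elided portion presumably loops over the transformer blocks along these lines; this is a sketch under that assumption, not the file's actual code.

# Hypothetical sketch of the elided block loop inside forward_with_shapes.
def run_blocks_with_shapes(model, x, shapes):
    for i, blk in enumerate(model.blocks):
        x = blk(x)
        shapes.append(f"After block {i}: {x.shape}")
    x = model.norm(x)
    shapes.append(f"After final norm: {x.shape}")
    return x, shapes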