zzzzzeee commited on
Commit
f9ff13b
·
verified ·
1 Parent(s): 77b3a26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -19
app.py CHANGED
@@ -6,10 +6,11 @@ import torchvision.transforms as transforms
6
  from functools import partial
7
  from dinov2 import DinoVisionTransformer
8
 
 
9
  # 定义模型(使用 vit_small 作为示例)
10
- def create_model():
11
  model = DinoVisionTransformer(
12
- img_size=224,
13
  patch_size=16,
14
  in_chans=3,
15
  embed_dim=384,
@@ -22,14 +23,13 @@ def create_model():
22
  model.eval()
23
  return model
24
 
25
- # 图像预处理
26
- def preprocess_image(image):
27
- transform = transforms.Compose([
28
- transforms.Resize((224, 224)),
29
- transforms.ToTensor(),
30
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
- ])
32
- img_tensor = transform(image).unsqueeze(0) # 添加 batch 维度
33
  return img_tensor
34
 
35
  # 修改 forward_features 以输出中间形状
@@ -41,7 +41,7 @@ def forward_with_shapes(model, x, masks=None):
41
  shapes.append(f"After patch_embed: {x.shape}")
42
 
43
  # 2. Prepare tokens with masks
44
- B, nc, w, h = x.shape[0], 3, 224, 224 # 原始图像尺寸
45
  if masks is not None:
46
  x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
47
  x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
@@ -88,24 +88,32 @@ def forward_with_shapes(model, x, masks=None):
88
  return output, shapes
89
 
90
  # 主处理函数
91
- def process_image(image):
92
- model = create_model()
93
- img_tensor = preprocess_image(image)
 
 
 
94
 
95
  # 前向传播并获取形状
96
  output, shapes = forward_with_shapes(model, img_tensor)
97
 
98
  # 将形状列表转换为字符串
99
- shapes_text = "\n".join(shapes)
100
  return shapes_text
101
 
102
  # Gradio 界面
103
  demo = gr.Interface(
104
- fn=process_image,
105
- inputs=gr.Image(type="pil", label="Upload an Image"),
 
 
 
 
 
106
  outputs=gr.Textbox(label="Feature Map Shapes"),
107
  title="DinoVisionTransformer Feature Map Shapes",
108
- description="Upload an image to see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens).",
109
  )
110
 
111
- demo.launch()
 
6
  from functools import partial
7
  from dinov2 import DinoVisionTransformer
8
 
9
+
10
  # 定义模型(使用 vit_small 作为示例)
11
+ def create_model(height, width):
12
  model = DinoVisionTransformer(
13
+ img_size=(height, width), # 动态调整输入尺寸
14
  patch_size=16,
15
  in_chans=3,
16
  embed_dim=384,
 
23
  model.eval()
24
  return model
25
 
26
+ # 生成随机输入张量
27
+ def generate_input(batch_size, channels, height, width):
28
+ # 确保输入尺寸是 patch_size 的整数倍
29
+ patch_size = 16
30
+ height = (height // patch_size) * patch_size
31
+ width = (width // patch_size) * patch_size
32
+ img_tensor = torch.randn(batch_size, channels, height, width)
 
33
  return img_tensor
34
 
35
  # 修改 forward_features 以输出中间形状
 
41
  shapes.append(f"After patch_embed: {x.shape}")
42
 
43
  # 2. Prepare tokens with masks
44
+ B, nc, w, h = x.shape[0], 3, x.shape[-2], x.shape[-1] # 使用动态输入尺寸
45
  if masks is not None:
46
  x = torch.where(masks.unsqueeze(-1), model.mask_token.to(x.dtype).unsqueeze(0), x)
47
  x = torch.cat((model.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
 
88
  return output, shapes
89
 
90
  # 主处理函数
91
+ def process_tensor(batch_size, channels, height, width):
92
+ # 生成随机输入
93
+ img_tensor = generate_input(batch_size, channels, height, width)
94
+
95
+ # 创建模型,动态调整 img_size
96
+ model = create_model(height=img_tensor.shape[2], width=img_tensor.shape[3])
97
 
98
  # 前向传播并获取形状
99
  output, shapes = forward_with_shapes(model, img_tensor)
100
 
101
  # 将形状列表转换为字符串
102
+ shapes_text = f"Input shape: {img_tensor.shape}\n" + "\n".join(shapes)
103
  return shapes_text
104
 
105
  # Gradio 界面
106
  demo = gr.Interface(
107
+ fn=process_tensor,
108
+ inputs=[
109
+ gr.Slider(minimum=1, maximum=8, step=1, value=1, label="Batch Size"),
110
+ gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Channels"),
111
+ gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Height"),
112
+ gr.Slider(minimum=16, maximum=512, step=16, value=224, label="Width"),
113
+ ],
114
  outputs=gr.Textbox(label="Feature Map Shapes"),
115
  title="DinoVisionTransformer Feature Map Shapes",
116
+ description="Adjust the sliders to set the input tensor dimensions (B, C, H, W) and see the shapes of feature maps at each step of DinoVisionTransformer (vit_small, 4 register tokens). Height and Width will be adjusted to be multiples of 16.",
117
  )
118
 
119
+ demo.launch()