Bill Psomas committed
Commit 7578ae0 · Parent: 02a709c

higher resolution choice

Files changed (1): app.py +10 -15
app.py CHANGED
@@ -15,13 +15,7 @@ import vision_transformer as vits
 arch = "vit_small"
 mode = "simpool"
 gamma = None
-
-patch_size = 16
-input_size = 224
-
 patch_size = 16
-input_size = 448
-
 num_classes = 0
 checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
 checkpoint_key = "teacher"
@@ -51,14 +45,13 @@ msg = model.load_state_dict(state_dict, strict=True)
 
 model.eval()
 
-# Define transformations
-data_transforms = transforms.Compose([
-    transforms.Resize((input_size, input_size), interpolation=3),
-    transforms.ToTensor(),
-    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-])
-
-def get_attention_map(img):
+def get_attention_map(img, resolution):
+    input_size = resolution * 14
+    data_transforms = transforms.Compose([
+        transforms.Resize((input_size, input_size), interpolation=3),
+        transforms.ToTensor(),
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+    ])
     x = data_transforms(img)
     attn = model.get_simpool_attention(x[None, :, :, :])
     attn = attn.reshape(1, 1, input_size//patch_size, input_size//patch_size)
@@ -73,7 +66,9 @@ def get_attention_map(img):
 
 attention_interface = gr.Interface(
     fn=get_attention_map,
-    inputs=[gr.Image(type="pil", label="Input Image")],
+    inputs=[gr.Image(type="pil", label="Input Image"),
+            gr.Dropdown(choices=["16", "32", "64", "128"],
+                        label="Attention Map Resolution", value="32", type="index")],
     outputs=gr.Image(type="pil", label="SimPool Attention Map", width=width_display, height=height_display),
     examples=example_list,
     title="Explore the Attention Maps of SimPool🔍",