osanseviero commited on
Commit
60d58ef
·
1 Parent(s): c362c8b

Add gradio app

Browse files
Files changed (2) hide show
  1. app.py +78 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system('pip install git+https://github.com/huggingface/transformers.git --upgrade')
3
+
4
+ import gradio as gr
5
+ from transformers import ViTFeatureExtractor, ViTModel
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision
9
+ import matplotlib.pyplot as plt
10
+
11
+ torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
12
+
13
+ def get_attention_maps(pixel_values, attentions, nh):
14
+ threshold = 0.6
15
+ w_featmap = pixel_values.shape[-2] // model.config.patch_size
16
+ h_featmap = pixel_values.shape[-1] // model.config.patch_size
17
+
18
+ # we keep only a certain percentage of the mass
19
+ val, idx = torch.sort(attentions)
20
+ val /= torch.sum(val, dim=1, keepdim=True)
21
+ cumval = torch.cumsum(val, dim=1)
22
+ th_attn = cumval > (1 - threshold)
23
+ idx2 = torch.argsort(idx)
24
+ for head in range(nh):
25
+ th_attn[head] = th_attn[head][idx2[head]]
26
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
27
+
28
+ # interpolate
29
+ th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
30
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
31
+ attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu()
32
+ attentions = attentions.detach().numpy()
33
+
34
+ # save attentions heatmaps and return list of filenames
35
+ output_dir = '.'
36
+ os.makedirs(output_dir, exist_ok=True)
37
+ attention_maps = []
38
+ for j in range(nh):
39
+ fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
40
+
41
+ # save the attention map
42
+ plt.imsave(fname=fname, arr=attentions[j], format='png')
43
+
44
+ # append file name
45
+ attention_maps.append(fname)
46
+
47
+ return attention_maps
48
+
49
+
50
+ def visualize_attention(image):
51
+ # normalize channels
52
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
53
+ # forward pass
54
+ outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
55
+ # get attentions of last layer
56
+ attentions = outputs.attentions[-1]
57
+ nh = attentions.shape[1] # number of heads
58
+ # we keep only the output patch attention
59
+ attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
60
+ attention_maps = get_attention_maps(pixel_values, attentions, nh)
61
+
62
+ return attention_maps
63
+
64
+ feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", do_resize=False)
65
+ model = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)
66
+
67
+ title = "Interactive demo: DINO"
68
+ description = "Demo for Facebook AI's DINO, a new method for self-supervised training of Vision Transformers. Using this method, they are capable of segmenting objects within an image without having ever been trained to do so. This can be observed by displaying the self-attention of the heads from the last layer for the [CLS] token query. This demo uses a ViT-S/8 trained with DINO. To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
69
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.14294'>Emerging Properties in Self-Supervised Vision Transformers</a> | <a href='https://github.com/facebookresearch/dino'>Github Repo</a></p>"
70
+ examples =[['cats.jpg']]
71
+ iface = gr.Interface(fn=visualize_attention,
72
+ inputs=gr.inputs.Image(shape=(480, 480), type="pil"),
73
+ outputs=[gr.outputs.Image(type='file', label=f'attention_head_{i}') for i in range(6)],
74
+ title=title,
75
+ description=description,
76
+ article=article,
77
+ examples=examples)
78
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ matplotlib