Duplicate from nielsr/DINO
Browse filesCo-authored-by: Niels Rogge <[email protected]>
- .gitattributes +27 -0
- README.md +34 -0
- app.py +78 -0
- requirements.txt +4 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: DINO
|
3 |
+
emoji: 📉
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
9 |
+
duplicated_from: nielsr/DINO
|
10 |
+
---
|
11 |
+
|
12 |
+
# Configuration
|
13 |
+
|
14 |
+
`title`: _string_
|
15 |
+
Display title for the Space
|
16 |
+
|
17 |
+
`emoji`: _string_
|
18 |
+
Space emoji (emoji-only character allowed)
|
19 |
+
|
20 |
+
`colorFrom`: _string_
|
21 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
22 |
+
|
23 |
+
`colorTo`: _string_
|
24 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
25 |
+
|
26 |
+
`sdk`: _string_
|
27 |
+
Can be either `gradio` or `streamlit`
|
28 |
+
|
29 |
+
`app_file`: _string_
|
30 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
31 |
+
Path is relative to the root of the repository.
|
32 |
+
|
33 |
+
`pinned`: _boolean_
|
34 |
+
Whether the Space stays on top of your list.
|
app.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.system('pip install git+https://github.com/huggingface/transformers.git --upgrade')
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
from transformers import ViTFeatureExtractor, ViTModel
|
6 |
+
import torch
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
|
10 |
+
|
11 |
+
def get_attention_maps(pixel_values, attentions, nh):
|
12 |
+
threshold = 0.6
|
13 |
+
w_featmap = pixel_values.shape[-2] // model.config.patch_size
|
14 |
+
h_featmap = pixel_values.shape[-1] // model.config.patch_size
|
15 |
+
|
16 |
+
# we keep only a certain percentage of the mass
|
17 |
+
val, idx = torch.sort(attentions)
|
18 |
+
val /= torch.sum(val, dim=1, keepdim=True)
|
19 |
+
cumval = torch.cumsum(val, dim=1)
|
20 |
+
th_attn = cumval > (1 - threshold)
|
21 |
+
idx2 = torch.argsort(idx)
|
22 |
+
for head in range(nh):
|
23 |
+
th_attn[head] = th_attn[head][idx2[head]]
|
24 |
+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
|
25 |
+
# interpolate
|
26 |
+
th_attn = torch.nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
|
27 |
+
|
28 |
+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
|
29 |
+
attentions = torch.nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu()
|
30 |
+
attentions = attentions.detach().numpy()
|
31 |
+
|
32 |
+
# save attentions heatmaps and return list of filenames
|
33 |
+
output_dir = '.'
|
34 |
+
os.makedirs(output_dir, exist_ok=True)
|
35 |
+
attention_maps = []
|
36 |
+
for j in range(nh):
|
37 |
+
fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
|
38 |
+
# save the attention map
|
39 |
+
plt.imsave(fname=fname, arr=attentions[j], format='png')
|
40 |
+
# append file name
|
41 |
+
attention_maps.append(fname)
|
42 |
+
|
43 |
+
return attention_maps
|
44 |
+
|
45 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", do_resize=False)
|
46 |
+
model = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)
|
47 |
+
|
48 |
+
def visualize_attention(image):
|
49 |
+
# normalize channels
|
50 |
+
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
51 |
+
|
52 |
+
# forward pass
|
53 |
+
outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
|
54 |
+
|
55 |
+
# get attentions of last layer
|
56 |
+
attentions = outputs.attentions[-1]
|
57 |
+
nh = attentions.shape[1] # number of heads
|
58 |
+
|
59 |
+
# we keep only the output patch attention
|
60 |
+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
|
61 |
+
|
62 |
+
attention_maps = get_attention_maps(pixel_values, attentions, nh)
|
63 |
+
|
64 |
+
return attention_maps
|
65 |
+
|
66 |
+
title = "Interactive demo: DINO"
|
67 |
+
description = "Demo for Facebook AI's DINO, a new method for self-supervised training of Vision Transformers. Using this method, they are capable of segmenting objects within an image without having ever been trained to do so. This can be observed by displaying the self-attention of the heads from the last layer for the [CLS] token query. This demo uses a ViT-S/8 trained with DINO. To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
|
68 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.14294'>Emerging Properties in Self-Supervised Vision Transformers</a> | <a href='https://github.com/facebookresearch/dino'>Github Repo</a></p>"
|
69 |
+
examples =[['cats.jpg']]
|
70 |
+
|
71 |
+
iface = gr.Interface(fn=visualize_attention,
|
72 |
+
inputs=gr.inputs.Image(shape=(480, 480), type="pil"),
|
73 |
+
outputs=[gr.outputs.Image(type='file', label=f'attention_head_{i}') for i in range(6)],
|
74 |
+
title=title,
|
75 |
+
description=description,
|
76 |
+
article=article,
|
77 |
+
examples=examples)
|
78 |
+
iface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
torch
|
3 |
+
Pillow
|
4 |
+
matplotlib
|