Spaces:
gentr
/
Runtime error

gentr nielsr HF staff commited on
Commit
1423e1b
·
0 Parent(s):

Duplicate from nielsr/DINO

Browse files

Co-authored-by: Niels Rogge <[email protected]>

Files changed (4) hide show
  1. .gitattributes +27 -0
  2. README.md +34 -0
  3. app.py +78 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: DINO
3
+ emoji: 📉
4
+ colorFrom: pink
5
+ colorTo: pink
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: false
9
+ duplicated_from: nielsr/DINO
10
+ ---
11
+
12
+ # Configuration
13
+
14
+ `title`: _string_
15
+ Display title for the Space
16
+
17
+ `emoji`: _string_
18
+ Space emoji (emoji-only character allowed)
19
+
20
+ `colorFrom`: _string_
21
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
+
23
+ `colorTo`: _string_
24
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
+
26
+ `sdk`: _string_
27
+ Can be either `gradio` or `streamlit`
28
+
29
+ `app_file`: _string_
30
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
31
+ Path is relative to the root of the repository.
32
+
33
+ `pinned`: _boolean_
34
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system('pip install git+https://github.com/huggingface/transformers.git --upgrade')
3
+
4
+ import gradio as gr
5
+ from transformers import ViTFeatureExtractor, ViTModel
6
+ import torch
7
+ import matplotlib.pyplot as plt
8
+
9
+ torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
10
+
11
+ def get_attention_maps(pixel_values, attentions, nh):
12
+ threshold = 0.6
13
+ w_featmap = pixel_values.shape[-2] // model.config.patch_size
14
+ h_featmap = pixel_values.shape[-1] // model.config.patch_size
15
+
16
+ # we keep only a certain percentage of the mass
17
+ val, idx = torch.sort(attentions)
18
+ val /= torch.sum(val, dim=1, keepdim=True)
19
+ cumval = torch.cumsum(val, dim=1)
20
+ th_attn = cumval > (1 - threshold)
21
+ idx2 = torch.argsort(idx)
22
+ for head in range(nh):
23
+ th_attn[head] = th_attn[head][idx2[head]]
24
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
25
+ # interpolate
26
+ th_attn = torch.nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
27
+
28
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
29
+ attentions = torch.nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu()
30
+ attentions = attentions.detach().numpy()
31
+
32
+ # save attentions heatmaps and return list of filenames
33
+ output_dir = '.'
34
+ os.makedirs(output_dir, exist_ok=True)
35
+ attention_maps = []
36
+ for j in range(nh):
37
+ fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
38
+ # save the attention map
39
+ plt.imsave(fname=fname, arr=attentions[j], format='png')
40
+ # append file name
41
+ attention_maps.append(fname)
42
+
43
+ return attention_maps
44
+
45
+ feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", do_resize=False)
46
+ model = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)
47
+
48
+ def visualize_attention(image):
49
+ # normalize channels
50
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
51
+
52
+ # forward pass
53
+ outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
54
+
55
+ # get attentions of last layer
56
+ attentions = outputs.attentions[-1]
57
+ nh = attentions.shape[1] # number of heads
58
+
59
+ # we keep only the output patch attention
60
+ attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
61
+
62
+ attention_maps = get_attention_maps(pixel_values, attentions, nh)
63
+
64
+ return attention_maps
65
+
66
+ title = "Interactive demo: DINO"
67
+ description = "Demo for Facebook AI's DINO, a new method for self-supervised training of Vision Transformers. Using this method, they are capable of segmenting objects within an image without having ever been trained to do so. This can be observed by displaying the self-attention of the heads from the last layer for the [CLS] token query. This demo uses a ViT-S/8 trained with DINO. To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
68
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.14294'>Emerging Properties in Self-Supervised Vision Transformers</a> | <a href='https://github.com/facebookresearch/dino'>Github Repo</a></p>"
69
+ examples =[['cats.jpg']]
70
+
71
+ iface = gr.Interface(fn=visualize_attention,
72
+ inputs=gr.inputs.Image(shape=(480, 480), type="pil"),
73
+ outputs=[gr.outputs.Image(type='file', label=f'attention_head_{i}') for i in range(6)],
74
+ title=title,
75
+ description=description,
76
+ article=article,
77
+ examples=examples)
78
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ Pillow
4
+ matplotlib