# File size: 3,631 Bytes
# 104a2dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from timm import create_model
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F

# Pretrained CaiT model (cait_xxs24_224 variant), switched to inference mode.
# `.eval()` returns the module itself, so calling it separately is equivalent.
cait_model = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True)
cait_model.eval()

# Preprocessing pipeline derived from the model's own pretrained data config.
_data_cfg = timm.data.resolve_data_config(cait_model.pretrained_cfg)
transform = timm.data.create_transform(**_data_cfg)

# Side length (pixels) of one image patch; used to convert image size into the
# attention-grid size. Assumed to match the model's patch embedding.
patch_size = 16


def create_attn_extractor(model, block_id=0):
    """Create a model that produces the softmax attention scores.

    Traces ``model`` with torchvision's feature-extraction utility so that its
    only output is the ``attn.softmax`` node of class-attention block
    ``block_id`` (an index into ``model.blocks_token_only``).

    Note: this may mutate ``model`` by switching the selected block's attention
    to its non-fused code path (see the comment below).

    Args:
        model: The CaiT model to trace. This argument is now actually used; the
            global model is no longer hard-coded.
        block_id: Index of the class-attention block. It may be a float such as
            ``1.0`` (e.g. coming from a UI slider) and is coerced to ``int``.

    Returns:
        A module whose forward pass returns a dict keyed by
        ``f"blocks_token_only.{block_id}.attn.softmax"``.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    block_id = int(block_id)

    # NOTE(review): newer timm versions may compute attention with a fused
    # kernel (flag `fused_attn`). That path has no explicit softmax node to
    # extract, so force the explicit path. This is a no-op when the attribute
    # does not exist.
    attn = model.blocks_token_only[block_id].attn
    if getattr(attn, "fused_attn", False):
        attn.fused_attn = False

    feature_extractor = create_feature_extractor(
        model,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        # Keep PatchEmbed opaque to the FX tracer rather than tracing into it.
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )
    return feature_extractor


def get_cls_attention_map(
    image, attn_score_dict=None, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepare CLS-token attention maps so that they can be visualized.

    Args:
        image: Batched image tensor. Only its spatial size is read
            (``image.shape[2]`` is height, ``image.shape[3]`` is width).
        attn_score_dict: Mapping produced by the attention extractor. The value
            under ``block_key`` is indexed as ``[batch, head, query, key]``.
            Only the first batch element is used.
        block_key: Key of the attention tensor inside ``attn_score_dict``.

    Returns:
        Tensor of shape ``(num_heads, h_featmap * patch_size,
        w_featmap * patch_size)``: per-head CLS attention, upsampled to image
        resolution.

    Raises:
        ValueError: If ``attn_score_dict`` is not provided.
    """
    # The previous default (`out`) referred to a name that does not exist at
    # module scope, so the dict must be passed explicitly.
    if attn_score_dict is None:
        raise ValueError("attn_score_dict must be provided")

    h_featmap = image.shape[2] // patch_size
    w_featmap = image.shape[3] // patch_size

    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # Attention of the CLS query (index 0) over every patch token. The `1:`
    # slice drops the CLS key itself.
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    # Patch tokens are assumed to be flattened row-major (as timm's PatchEmbed
    # does), so the grid is (rows = height, cols = width).
    attentions = attentions.reshape(nh, h_featmap, w_featmap)

    # Upsample each head's patch grid back to image resolution
    # (e.g. 224 = 14 x 16).
    attentions = F.resize(
        attentions,
        size=(h_featmap * patch_size, w_featmap * patch_size),
        interpolation=F.InterpolationMode.BICUBIC,
    )

    return attentions


def generate_plot(processed_map):
    """Generate a class attention map plot.

    Draws one panel per attention head in a single row. The number of columns
    follows ``processed_map.shape[0]`` instead of being fixed at 4.

    Args:
        processed_map: Tensor of per-head maps, indexed as
            ``processed_map[head]``. Each entry must support ``.numpy()``.

    Returns:
        The matplotlib ``Figure`` holding the panels.
    """
    num_heads = processed_map.shape[0]

    # squeeze=False keeps `axes` 2-D even for a single column, so indexing is
    # uniform. max(..., 1) avoids an invalid zero-column grid.
    fig, axes = plt.subplots(
        nrows=1, ncols=max(num_heads, 1), figsize=(13, 13), squeeze=False
    )

    for head_idx in range(num_heads):
        ax = axes[0][head_idx]
        ax.imshow(processed_map[head_idx].numpy())
        ax.title.set_text(f"Attention head: {head_idx}")
        ax.axis("off")

    fig.tight_layout()
    return fig


def generate_class_attn_map(image, block_id=0):
    """Collate the above utilities to generate a class attention map.

    Args:
        image: Input PIL image.
        block_id: Index of the class-attention block to visualize. Floats such
            as ``1.0`` (which a UI slider may deliver) are accepted.

    Returns:
        The matplotlib figure produced by ``generate_plot``.
    """
    # A float block_id would produce the node key "blocks_token_only.1.0...."
    # and fail; normalize it to an int before building any names.
    block_id = int(block_id)

    # Drop alpha/palette channels. The timm transform presumably expects
    # 3-channel RGB input.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0)
    feature_extractor = create_attn_extractor(cait_model, block_id)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
    return generate_plot(processed_cls_attn_map)


title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."

# Highest valid block index, derived from the model rather than hard-coded.
_max_block_id = len(cait_model.blocks_token_only) - 1

iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        # `gr.inputs.*` was removed in Gradio 4; `gr.Image` works in 3.x and 4.x.
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(
            0,
            _max_block_id,
            value=0,
            step=1,
            label="Block ID",
            info="Transformer Block ID",
        ),
    ],
    # `.style()` and `Plot(type=...)` are no longer supported in Gradio 4.
    outputs=[gr.Plot()],
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch()