import cv2
import numpy as np
import torch
from PIL import Image


def overlay_attn(original_image, mask):
    # Colormap and blending weight for the attention overlay
    colormap_attn, alpha_attn = cv2.COLORMAP_JET, 0.9

    # Resize the normalized mask to the original image size.
    # Image arrays are (height, width, ...), while cv2.resize expects (width, height).
    h, w = original_image.shape[0], original_image.shape[1]
    mask = cv2.resize(mask / mask.max(), (w, h))[..., np.newaxis]

    # Apply the colormap to the mask. applyColorMap returns BGR, so convert
    # to RGB before blending with an RGB image.
    cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
    cmap = cv2.cvtColor(cmap, cv2.COLOR_BGR2RGB)

    # Blend the heatmap with the original image
    alpha_blended = cv2.addWeighted(np.uint8(original_image), 1 - alpha_attn, cmap, alpha_attn, 0)

    # Return the overlay as a PIL image
    final_im = Image.fromarray(alpha_blended)
    return final_im


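# A minimal usage sketch for overlay_attn (assumption, not part of the original file):
# the input is an RGB numpy array, and the random 14x14 mask stands in for a real
# attention map; the file names are hypothetical.
#
#   img = np.array(Image.open("example.jpg").convert("RGB"))
#   mask = np.random.rand(14, 14)
#   overlay_attn(img, mask).save("overlay.jpg")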

class VITAttentionGradRollout:
    '''
        Attention rollout for a timm ViT transformer model
        (despite the name, no gradients are used).
        Adapted from https://github.com/samiraabnar/attention_flow
    '''
    def __init__(self, model, head_fusion='min', discard_ratio=0):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio

        # Register a forward hook on every transformer block to capture its attention map
        self.attentions = {}
        for idx, module in enumerate(model.blocks.children()):
            module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))

    def get_attention(self, name):
        # Returns a hook that recomputes the attention matrix from the block's
        # qkv projection (timm's Attention module does not expose it directly)
        # and stores it under `name`.
        def hook(module, input, output):
            with torch.no_grad():
                x = input[0]
                B, N, C = x.shape
                qkv = (
                    module.qkv(x)
                    .detach()
                    .reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k, _ = qkv[0], qkv[1], qkv[2]
                attn = (q @ k.transpose(-2, -1)) * module.scale
                attn = attn.softmax(dim=-1)
                self.attentions[name] = attn
        return hook

    def get_attn_mask(self, k=0):
        # Roll attention out from block k through the last block;
        # k=0 uses every block (standard attention rollout).
        result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)

        with torch.no_grad():
            for idx in range(k, len(self.attentions)):
                attention = self.attentions[f'attn{idx}']

                # Fuse the attention heads into a single map
                if self.head_fusion == "mean":
                    attention_heads_fused = attention.mean(dim=1)
                elif self.head_fusion == "max":
                    attention_heads_fused = attention.max(dim=1)[0]
                elif self.head_fusion == "min":
                    attention_heads_fused = attention.min(dim=1)[0]
                else:
                    raise ValueError("Attention head fusion type not supported")

                # Drop the lowest attentions, but don't drop the class token
                flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
                indices = indices[indices != 0]
                flat[0, indices] = 0

                # Add the identity to account for residual connections, then re-normalize
                I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
                a = (attention_heads_fused + 1.0 * I) / 2
                a = a / a.sum(dim=-1).unsqueeze(-1)

                result = torch.matmul(a, result)

        # Look at the total attention between the class token and the image patches
        mask = result[0, 0, 1:]
        # For a 224x224 image with 16x16 patches, this reshapes 196 patches into a 14x14 grid
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).detach().cpu().numpy()
        mask = mask / np.max(mask)
        return mask