|
import cv2
import numpy as np
import torch
from PIL import Image


def overlay_attn(original_image, mask):
    '''Overlay an attention mask on an RGB image and return the blend as a PIL image.'''
    colormap_attn = cv2.COLORMAP_VIRIDIS

    # Normalize the mask and resize it to the image resolution
    # (cv2.resize expects (width, height)).
    h, w = original_image.shape[0], original_image.shape[1]
    mask = cv2.resize(mask / mask.max(), (w, h))

    # Map attention values to a heatmap; OpenCV colormaps are BGR, so convert to RGB
    # (assuming the input image is RGB) before blending.
    cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
    cmap = cv2.cvtColor(cmap, cv2.COLOR_BGR2RGB)

    # Alpha-blend the heatmap over the original image.
    alpha_blended = cv2.addWeighted(np.uint8(original_image), 0.4, cmap, 0.6, 0)

    final_im = Image.fromarray(alpha_blended)
    return final_im

|
class VITAttentionGradRollout:
    '''
    Attention rollout for a ViT. Expects a timm ViT transformer model.
    Adapted from https://github.com/samiraabnar/attention_flow
    '''

    def __init__(self, model, head_fusion='min', discard_ratio=0):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio

        # Register a forward hook on every transformer block to capture its attention map.
        self.attentions = {}
        for idx, module in enumerate(model.blocks.children()):
            module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))
|
    def get_attention(self, name):
        '''Return a forward hook that stores the attention map of the block under `name`.'''
        def hook(module, input, output):
            with torch.no_grad():
                # Recompute the attention weights from the block's qkv projection,
                # since timm's Attention module does not expose them directly.
                x = input[0]
                B, N, C = x.shape
                qkv = (
                    module.qkv(x)
                    .detach()
                    .reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k = qkv[0], qkv[1]
                attn = (q @ k.transpose(-2, -1)) * module.scale
                attn = attn.softmax(dim=-1)
                self.attentions[name] = attn
        return hook
|
    def get_attn_mask(self, k=0):
        '''
        Roll attention out from block `k` through the last block and return the
        CLS-token attention over the image patches as a 2D numpy mask.
        '''
        result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)

        with torch.no_grad():
            # `k` is the first block included in the rollout; by default all blocks are used.
            for idx in range(k, len(self.attentions)):
                attention = self.attentions[f'attn{idx}']
                # Fuse the attention heads into a single map per layer.
                if self.head_fusion == "mean":
                    attention_heads_fused = attention.mean(dim=1)
                elif self.head_fusion == "max":
                    attention_heads_fused = attention.max(dim=1)[0]
                elif self.head_fusion == "min":
                    attention_heads_fused = attention.min(dim=1)[0]
                else:
                    raise ValueError("Attention head fusion type not supported")

                # Drop the lowest-attention entries, but never the CLS token (index 0).
                flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
                indices = indices[indices != 0]
                flat[0, indices] = 0

                # Add the identity to model the residual connection, then renormalize rows.
                I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
                a = (attention_heads_fused + 1.0 * I) / 2
                a = a / a.sum(dim=-1).unsqueeze(-1)

                result = torch.matmul(a, result)

        # Attention flowing from the CLS token to every patch token, reshaped to the patch grid.
        mask = result[0, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).detach().cpu().numpy()
        mask = mask / np.max(mask)
        return mask
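

# Example usage: a minimal sketch, not part of the module above. The model name
# 'vit_base_patch16_224', the random 224x224 test image, and the output filename are
# illustrative assumptions; any timm ViT whose blocks expose an `attn.qkv` projection
# should work the same way.
if __name__ == "__main__":
    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    model.eval()

    # Constructing the rollout attaches forward hooks that record each block's attention.
    rollout = VITAttentionGradRollout(model, head_fusion='min', discard_ratio=0)

    # A random RGB image stands in for real input here; normally you would load and
    # normalize an image with the model's preprocessing transforms.
    image = np.uint8(np.random.rand(224, 224, 3) * 255)
    tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0

    with torch.no_grad():
        model(tensor)  # the forward pass populates rollout.attentions via the hooks

    mask = rollout.get_attn_mask()
    overlay = overlay_attn(image, mask)
    overlay.save("attention_overlay.png")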
|
|