|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import matplotlib.cm |
|
from PIL import Image |
|
|
|
|
|
class Hook:
    """Attach a forward hook to a module and record its output and gradient.

    The hook stores the most recent output of ``module`` and asks autograd to
    retain that output's gradient, so that after a backward pass both the
    activation and its gradient can be read back.

    Use it as a context manager so the hook is removed on exit::

        with Hook(layer) as hook:
            model(x).backward(grad)
            act, grad = hook.activation, hook.gradient
    """

    def __init__(self, module: nn.Module):
        # Most recent output of ``module``; stays None until a forward pass
        # has gone through the hooked module.
        self.data = None
        # Handle returned by PyTorch; needed to detach the hook later.
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        """Forward-hook callback: remember ``output`` and retain its gradient."""
        self.data = output
        # Outputs that are not part of a graph yet (e.g. all parameters
        # frozen) must be flagged before autograd will track them.
        output.requires_grad_(True)
        # Non-leaf tensors drop ``.grad`` after backward unless asked to
        # keep it.
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always unregister the hook; exceptions are not suppressed.
        self.hook.remove()

    @property
    def activation(self) -> "torch.Tensor | None":
        """Last recorded output of the module, or None if none was recorded."""
        return self.data

    @property
    def gradient(self) -> "torch.Tensor | None":
        """Gradient of the last recorded output.

        Returns None when no forward pass has been recorded yet, or when no
        backward pass has populated the gradient.
        """
        if self.data is None:
            return None
        return self.data.grad
|
|
|
|
|
|
|
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    """Compute a Grad-CAM saliency map for ``layer`` of ``model``.

    Runs one forward pass of ``model(input)`` and one backward pass seeded
    with ``target``, then combines the hooked layer's activations with the
    spatially averaged gradients.

    Args:
        model: Network to explain. Its parameters are temporarily set to
            ``requires_grad=False`` and restored before returning, including
            when an exception is raised.
        input: Model input. Its trailing dimensions (``input.shape[2:]``)
            are used as the output resolution. An existing ``input.grad`` is
            zeroed in place.
        target: Passed to ``output.backward`` as the gradient of the model
            output.
        layer: Sub-module whose output is hooked. Its output is reduced over
            dims 2 and 3 (spatial) and dim 1 (channels).

    Returns:
        The non-negative Grad-CAM map with a single channel, resized to
        ``input.shape[2:]`` with bicubic interpolation.
    """
    # Validate before touching any state so a bad argument has no side effects.
    assert isinstance(layer, nn.Module)

    # Clear any stale gradient on the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Snapshot the gradient settings, then freeze every parameter so the
    # backward pass does not accumulate parameter gradients.
    requires_grad = {
        name: param.requires_grad for name, param in model.named_parameters()
    }
    for _, param in model.named_parameters():
        param.requires_grad_(False)

    try:
        with Hook(layer) as hook:
            # One forward and one backward pass while the hook is attached.
            output = model(input)
            output.backward(target)

            grad = hook.gradient.float()
            act = hook.activation.float()

        # Average the gradient over the spatial dims to get one importance
        # weight per channel.
        alpha = grad.mean(dim=(2, 3), keepdim=True)

        # Weighted sum of the activation maps over the channel dim.
        gradcam = torch.sum(act * alpha, dim=1, keepdim=True)

        # Keep only positive influence.
        gradcam = torch.clamp(gradcam, min=0)

        # Resize the map to the input resolution.
        gradcam = F.interpolate(
            gradcam,
            input.shape[2:],
            mode='bicubic',
            align_corners=False)
    finally:
        # Restore the gradient settings even if the forward or backward pass
        # failed; otherwise the caller's model would stay frozen.
        for name, param in model.named_parameters():
            param.requires_grad_(requires_grad[name])

    return gradcam
|
|
|
|
|
|
|
def getAttMap(img, attn_map, *, size=(256, 256), alpha=0.4):
    """Overlay an attention map on an image as a jet-coloured heatmap.

    Args:
        img: PIL image to overlay on. It is converted to RGB so it can be
            blended with the RGB heatmap whatever its original mode.
        attn_map: 2-D array of attention values. It is shifted so its
            minimum is 0 and, if its peak is then positive, scaled so its
            maximum is 1. The caller's array is not modified.
        size: ``(width, height)`` both images are resized to before blending.
        alpha: Blend weight of the heatmap, as passed to ``Image.blend``
            (0 gives the image only, 1 the heatmap only).

    Returns:
        A PIL image of ``size`` blending ``img`` with the heatmap.
    """
    # Min-max normalise; skip the division for a constant map to avoid
    # dividing by zero.
    attn_map = attn_map - attn_map.min()
    peak = attn_map.max()
    if peak > 0:
        attn_map = attn_map / peak

    # The colormap yields RGBA floats in [0, 1]; keep RGB as 8-bit values.
    heatmap = matplotlib.cm.jet(attn_map)
    heatmap = (heatmap * 255).astype(np.uint8)[:, :, :3]
    img_heatmap = Image.fromarray(heatmap).resize(size)

    # Image.blend requires both images to have the same mode and size.
    base = img.convert("RGB").resize(size)
    return Image.blend(base, img_heatmap, alpha)
|
|