# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

import torch
import torch.nn as nn
import torch.nn.functional as F


# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
        model: nn.Module,
        input: torch.Tensor,
        target: torch.Tensor,
        layer: nn.Module
) -> torch.Tensor:
    """Compute a Grad-CAM saliency map for ``input`` at ``layer``.

    A forward pass of ``model(input)`` is followed by
    ``output.backward(target)``. The gradient recorded at the output of
    ``layer`` is averaged over dimensions 2 and 3 to obtain per-channel
    weights, the activations are combined with those weights over the
    channel dimension, negative values are clamped to zero, and the map
    is bicubically resized to ``input.shape[2:]``.

    Args:
        model: Network to explain. Its parameters temporarily have
            ``requires_grad`` switched off; the original flags are
            restored before returning, even if an exception is raised.
        input: Batch fed to ``model``. Its trailing dimensions
            (``input.shape[2:]``) define the size of the returned map.
            Any existing ``input.grad`` is zeroed in place.
        target: Tensor passed as the ``gradient`` argument of
            ``output.backward``.
        layer: Sub-module of ``model`` whose output is inspected. The
            output is reduced over dimensions 2 and 3, so it is assumed
            to be laid out as (N, C, H, W).

    Returns:
        The Grad-CAM map with a single channel (dimension 1), resized
        to the spatial size of ``input``.

    Raises:
        TypeError: If ``layer`` is not an ``nn.Module``.
        RuntimeError: If ``layer`` was not executed during the forward
            pass, or no gradient reached its output.
    """
    # Validate before touching any state so a bad argument cannot leave
    # the model with its gradients disabled.
    if not isinstance(layer, nn.Module):
        raise TypeError(
            "layer must be an nn.Module, got {}".format(type(layer).__name__)
        )

    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradient settings.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    try:
        # Attach a hook to the model at the desired layer.
        with Hook(layer) as hook:
            # Do a forward and backward pass.
            output = model(input)
            output.backward(target)

            if hook.activation is None:
                raise RuntimeError(
                    "layer was not executed during the forward pass of model"
                )
            if hook.gradient is None:
                raise RuntimeError(
                    "no gradient reached the output of layer"
                )

            grad = hook.gradient.float()
            act = hook.activation.float()

            # Global average pool gradient across spatial dimension
            # to obtain importance weights.
            alpha = grad.mean(dim=(2, 3), keepdim=True)
            # Weighted combination of activation maps over channel
            # dimension.
            gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
            # We only want neurons with positive influence so we
            # clamp any negative ones.
            gradcam = torch.clamp(gradcam, min=0)

        # Resize gradcam to input resolution.
        gradcam = F.interpolate(
            gradcam,
            input.shape[2:],
            mode='bicubic',
            align_corners=False)
    finally:
        # Restore gradient settings, including when the forward or
        # backward pass failed, so the caller's model is never left frozen.
        for name, param in model.named_parameters():
            param.requires_grad_(requires_grad[name])

    return gradcam


class Hook:
    """Attaches to a module and records its activations and gradients.

    Intended to be used as a context manager: the forward hook is
    registered on construction and removed when the ``with`` block exits.
    """

    def __init__(self, module: nn.Module):
        # Output of the most recent forward call of ``module``;
        # stays ``None`` until the hook has fired.
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        """Forward-hook callback: keep ``output`` and make it retain its grad."""
        self.data = output
        # Parameters may have requires_grad disabled, so force the output
        # to take part in autograd and keep its gradient after backward().
        output.requires_grad_(True)
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always detach the hook so it does not leak onto later forwards.
        self.hook.remove()

    @property
    def activation(self) -> torch.Tensor:
        """Recorded output of the hooked module (``None`` if never run)."""
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        """Gradient stored on the recorded output after a backward pass."""
        return self.data.grad