import cv2
import numpy as np
import torch
import tqdm

from pytorch_grad_cam.base_cam import BaseCAM


class AblationLayer(torch.nn.Module):
    """Wraps a target layer and "ablates" (suppresses) selected activation
    channels in its output, one channel index per element in the batch."""

    def __init__(self, layer, reshape_transform, indices):
        super(AblationLayer, self).__init__()
        self.layer = layer
        self.reshape_transform = reshape_transform
        self.indices = indices

    def forward(self, x):
        return self.__call__(x)

    def __call__(self, x):
        output = self.layer(x)

        if self.reshape_transform is not None:
            output = output.transpose(1, 2)

        for i in range(output.size(0)):
            # If the lowest activation in the output is 0 (e.g. the layer ends
            # with a ReLU), setting the channel to 0 already removes it.
            # Otherwise, push the ablated channel far below the minimum.
            if torch.min(output) == 0:
                output[i, self.indices[i], :] = 0
            else:
                ABLATION_VALUE = 1e5
                output[i, self.indices[i], :] = torch.min(
                    output) - ABLATION_VALUE

        if self.reshape_transform is not None:
            output = output.transpose(2, 1)

        return output


def replace_layer_recursive(model, old_layer, new_layer):
    """Recursively search `model` for `old_layer` and swap it in place with
    `new_layer`. Returns True if the layer was found and replaced."""
    for name, layer in model._modules.items():
        if layer == old_layer:
            model._modules[name] = new_layer
            return True
        elif replace_layer_recursive(layer, old_layer, new_layer):
            return True
    return False
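
# A minimal usage sketch (hypothetical model, not part of this module): the
# helper swaps a layer anywhere inside a nested module tree, for example
#
#     model = torch.nn.Sequential(
#         torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()))
#     relu = model[0][1]
#     replace_layer_recursive(model, relu, torch.nn.Identity())  # -> True
#
# AblationCAM below uses it to temporarily swap each target layer with an
# AblationLayer, and to restore the original layer afterwards.
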
class AblationCAM(BaseCAM):
    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None):
        super(AblationCAM, self).__init__(model, target_layers, use_cuda,
                                          reshape_transform)

        if len(target_layers) > 1:
            print(
                "Warning. You are using Ablation CAM with more than one layer. "
                "This is supported only if all layers have the same output shape")

    def set_ablation_layers(self):
        # Swap every target layer in the model with an AblationLayer wrapper.
        self.ablation_layers = []
        for target_layer in self.target_layers:
            ablation_layer = AblationLayer(target_layer,
                                           self.reshape_transform, indices=[])
            self.ablation_layers.append(ablation_layer)
            replace_layer_recursive(self.model, target_layer, ablation_layer)

    def unset_ablation_layers(self):
        # Restore the original target layers in the model.
        for ablation_layer, target_layer in zip(
                self.ablation_layers, self.target_layers):
            replace_layer_recursive(self.model, ablation_layer, target_layer)

    def set_ablation_layer_batch_indices(self, indices):
        for ablation_layer in self.ablation_layers:
            ablation_layer.indices = indices

    def trim_ablation_layer_batch_indices(self, keep):
        for ablation_layer in self.ablation_layers:
            ablation_layer.indices = ablation_layer.indices[:keep]

    def get_cam_weights(self,
                        input_tensor,
                        target_category,
                        activations,
                        grads):
        # Score of the unmodified model for each image / target category.
        with torch.no_grad():
            outputs = self.model(input_tensor).cpu().numpy()
            original_scores = []
            for i in range(input_tensor.size(0)):
                original_scores.append(outputs[i, target_category[i]])
            original_scores = np.float32(original_scores)

        self.set_ablation_layers()

        if hasattr(self, "batch_size"):
            BATCH_SIZE = self.batch_size
        else:
            BATCH_SIZE = 32

        number_of_channels = activations.shape[1]
        weights = []

        with torch.no_grad():
            # Ablate the activation channels in groups of BATCH_SIZE and
            # record the score of the target category for each ablation.
            for tensor, category in zip(input_tensor, target_category):
                batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1)
                for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)):
                    self.set_ablation_layer_batch_indices(
                        list(range(i, i + BATCH_SIZE)))

                    # The last group can be smaller than BATCH_SIZE.
                    if i + BATCH_SIZE > number_of_channels:
                        keep = number_of_channels - i
                        batch_tensor = batch_tensor[:keep]
                        self.trim_ablation_layer_batch_indices(keep)
                    score = self.model(batch_tensor)[:, category].cpu().numpy()
                    weights.extend(score)

        weights = np.float32(weights)
        weights = weights.reshape(activations.shape[:2])
        original_scores = original_scores[:, None]
        # The weight of a channel is the relative drop in the target score
        # when that channel is ablated.
        weights = (original_scores - weights) / original_scores

        self.unset_ablation_layers()
        return weights
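

# ---------------------------------------------------------------------------
# A minimal usage sketch, illustrative only: it assumes a torchvision
# classifier and that BaseCAM exposes the __call__(input_tensor,
# target_category) interface used elsewhere in this package; adapt the model,
# target layer and preprocessing to your own setup.
if __name__ == "__main__":
    from torchvision.models import resnet50

    model = resnet50(pretrained=True).eval()

    # Any (N, C, H, W) float tensor works here; a random image is used only
    # to keep the sketch self-contained.
    input_tensor = torch.rand(1, 3, 224, 224)

    cam = AblationCAM(model, target_layers=[model.layer4[-1]])
    # Optional: how many ablated channels are evaluated per forward pass.
    cam.batch_size = 32

    # target_category=None lets the CAM use the highest scoring class.
    grayscale_cam = cam(input_tensor=input_tensor, target_category=None)
    print(grayscale_cam.shape)  # one heatmap per input image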