# strexp/modules_trba/guided_backprop.py
"""
Created on Thu Oct 26 11:23:47 2017
@author: Utku Ozbulak - github.com/utkuozbulak
"""
import torch
from torch.nn import ReLU
class GuidedBackprop():
    """
    Produces gradients generated with guided backpropagation from the given image
    """
    def __init__(self, model, arch):
        self.model = model
        self.arch = arch
        self.gradients = None
        self.forward_relu_outputs = []
        # The model is put in training mode rather than eval() (as the original
        # CNN implementation did); cuDNN only supports backward through RNN
        # layers in training mode, which likely matters for TRBA-style models
        # with a BiLSTM stage
        self.model.train()
        self.update_relus()
        self.hook_layers()
    def hook_layers(self):
        def hook_function(module, grad_in, grad_out):
            # grad_in[0] is the gradient with respect to the layer's input
            self.gradients = grad_in[0]
        # Register the hook on the first layer so the gradient at the input
        # image is captured
        if 'alexnet' in self.arch:
            first_layer = list(self.model.features._modules.items())[0][1]
        elif 'resnet' in self.arch:
            first_layer = list(self.model._modules.items())[0][1]
        else:
            raise ValueError('Unsupported arch for hooking: %s' % self.arch)
        first_layer.register_backward_hook(hook_function)
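        # NOTE: Module.register_backward_hook is deprecated in recent PyTorch
        # releases in favor of register_full_backward_hook; the older API is
        # kept here to match the original implementation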
    def update_relus(self):
        """
        Updates the ReLU activation functions so that they
        1- store their output on the forward pass
        2- zero out gradients whose value, or whose corresponding forward
           activation, is negative on the backward pass
        """
        def relu_backward_hook_function(module, grad_in, grad_out):
            """
            If there is a negative gradient, change it to zero
            """
            # Turn the last stored forward output into a 0/1 mask (ReLU
            # outputs are non-negative, so setting positives to 1 leaves
            # zeros untouched)
            corresponding_forward_output = self.forward_relu_outputs[-1]
            corresponding_forward_output[corresponding_forward_output > 0] = 1
            # Keep only positive incoming gradients, then apply the mask
            modified_grad_out = corresponding_forward_output * torch.clamp(grad_in[0], min=0.0)
            del self.forward_relu_outputs[-1]  # Remove last forward output
            return (modified_grad_out,)
def relu_forward_hook_function(module, ten_in, ten_out):
"""
Store results of forward pass
"""
self.forward_relu_outputs.append(ten_out)
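        # Together, the two hooks above implement the guided-backprop rule:
        # gradient flows only where BOTH the forward activation and the
        # incoming gradient are positive. For example, a forward output of
        # [0.5, 0.0, 2.0] with incoming gradient [1.0, 1.0, -1.0] yields a
        # modified gradient of [1.0, 0.0, 0.0].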
# Loop through layers, hook up ReLUs
if 'alexnet' in self.arch:
for pos, module in self.model.features._modules.items():
if isinstance(module, ReLU):
module.register_backward_hook(relu_backward_hook_function)
module.register_forward_hook(relu_forward_hook_function)
elif 'resnet' in self.arch:
for module in self.model.modules():
if isinstance(module, ReLU):
module.register_backward_hook(relu_backward_hook_function)
module.register_forward_hook(relu_forward_hook_function)
    def generate_gradients(self, input_image, one_hot_output_guided, text_for_pred):
        # Forward pass. TRBA-style models take a decoder input and an
        # is_train flag; pass text_for_pred=None for a plain classifier
        # with a single-argument forward()
        if text_for_pred is None:
            model_output = self.model(input_image)
        else:
            model_output = self.model(input_image, text_for_pred, is_train=False)
        # Zero any previously accumulated gradients
        self.model.zero_grad()
        # Backward pass, seeded with the one-hot target
        model_output.backward(gradient=one_hot_output_guided)
        # Convert the gradient tensor captured by the first-layer hook to a
        # numpy array; [0] drops the batch dimension,
        # e.g. (1, 3, 224, 224) -> (3, 224, 224)
        gradients_as_arr = self.gradients.data.to('cpu').numpy()[0]
        print("gradients_as_arr shape: ", gradients_as_arr.shape)
        return gradients_as_arr
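
# A minimal sketch (not part of the original file) of how the gradient seed
# for generate_gradients() can be built: zeros everywhere, with a 1 at the
# target class. The (1, num_classes) shape assumes a plain classifier head;
# TRBA-style models emit (batch, seq_len, num_classes), in which case a
# sequence position would have to be indexed as well.
def make_one_hot_seed(num_classes, target_class):
    seed = torch.zeros(1, num_classes)
    seed[0][target_class] = 1
    return seed
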
if __name__ == '__main__':
    # Demo inherited from the original CNN repository
    # (github.com/utkuozbulak/pytorch-cnn-visualizations). The helpers below
    # come from its misc_functions module and are assumed to be importable
    # alongside this file.
    from misc_functions import (get_example_params, convert_to_grayscale,
                                save_gradient_images,
                                get_positive_negative_saliency)
    target_example = 0  # Snake
    (original_image, prep_img, target_class, file_name_to_export, pretrained_model) =\
        get_example_params(target_example)
    # Guided backprop; get_example_params() returns an AlexNet-style classifier
    GBP = GuidedBackprop(pretrained_model, arch='alexnet')
    # Get gradients, seeding the backward pass with a one-hot vector at the
    # target class (1000 = ImageNet class count of the demo classifier);
    # text_for_pred=None selects the plain single-argument forward
    one_hot_output = make_one_hot_seed(1000, target_class)
    guided_grads = GBP.generate_gradients(prep_img, one_hot_output, text_for_pred=None)
    # Save colored gradients
    save_gradient_images(guided_grads, file_name_to_export + '_Guided_BP_color')
    # Convert to grayscale
    grayscale_guided_grads = convert_to_grayscale(guided_grads)
    # Save grayscale gradients
    save_gradient_images(grayscale_guided_grads, file_name_to_export + '_Guided_BP_gray')
    # Positive and negative saliency maps
    pos_sal, neg_sal = get_positive_negative_saliency(guided_grads)
    save_gradient_images(pos_sal, file_name_to_export + '_pos_sal')
    save_gradient_images(neg_sal, file_name_to_export + '_neg_sal')
    print('Guided backprop completed')