Spaces:
Build error
Build error
""" | |
Created on Thu Oct 26 11:23:47 2017 | |
@author: Utku Ozbulak - github.com/utkuozbulak | |
""" | |
import torch | |
from torch.nn import ReLU | |
class GuidedBackprop:
    """
    Produces gradients generated with guided back propagation from the given image.

    Hooks are installed on construction:
      * every ReLU records its forward output and, on the backward pass, only
        lets positive gradients through at positions where that output was
        positive;
      * the first layer records the gradient flowing into it, which is what
        ``generate_gradients`` returns.

    Args:
        model: network to explain. It is called as
            ``model(input_image, text_for_pred, is_train=False)``.
        arch: architecture name. It must contain ``'alexnet'`` (ReLUs and the
            first layer are looked up among the direct children of
            ``model.features``) or ``'resnet'`` (ReLUs are found anywhere in
            ``model.modules()`` and the first layer is the model's first direct
            submodule).

    Raises:
        ValueError: if ``arch`` matches neither supported family.
    """

    def __init__(self, model, arch):
        self.model = model
        self.arch = arch
        self.gradients = None
        self.forward_relu_outputs = []
        # Handles of every hook this instance installs, so they can be removed.
        self.hook_handles = []
        # NOTE(review): this puts the model in *training* mode, not evaluation
        # mode. It may be deliberate (e.g. cuDNN RNN backward refuses to run in
        # eval mode), but it also makes BatchNorm/Dropout behave as in
        # training -- confirm this is intended.
        self.model.train()
        self.update_relus()
        self.hook_layers()

    def _unsupported_arch(self):
        """Build the error raised when ``self.arch`` is not a supported family."""
        return ValueError(
            f"Unsupported arch {self.arch!r}: expected a name containing "
            "'alexnet' or 'resnet'"
        )

    def hook_layers(self):
        """Register a hook on the first layer that stores the gradient reaching it."""
        def hook_function(module, grad_in, grad_out):
            # With the legacy (non-full) backward hook, grad_in holds the
            # gradients w.r.t. the inputs of the module's last autograd node;
            # for a convolutional first layer grad_in[0] is the input gradient.
            self.gradients = grad_in[0]

        if 'alexnet' in self.arch:
            first_layer = next(iter(self.model.features._modules.values()))
        elif 'resnet' in self.arch:
            first_layer = next(iter(self.model._modules.values()))
        else:
            raise self._unsupported_arch()
        # register_backward_hook is deprecated in favour of
        # register_full_backward_hook, but the full variant may fail on models
        # that use in-place ops such as ReLU(inplace=True), so it is kept here.
        self.hook_handles.append(first_layer.register_backward_hook(hook_function))

    def update_relus(self):
        """
        Hook every ReLU so that it
        1- stores its output in the forward pass
        2- on the backward pass, imputes zero for negative gradient values and
           for positions where its forward output was not positive
        """
        def relu_backward_hook_function(module, grad_in, grad_out):
            """
            If there is a negative gradient, change it to zero.
            """
            # Assumes ReLUs are visited in reverse forward order during the
            # backward pass, so the matching forward output is the last stored.
            corresponding_forward_output = self.forward_relu_outputs.pop()
            # Build the (output > 0) mask without mutating the stored tensor,
            # which is still part of the autograd graph.
            positive_mask = (corresponding_forward_output > 0).to(
                corresponding_forward_output.dtype
            )
            modified_grad_out = positive_mask * torch.clamp(grad_in[0], min=0.0)
            return (modified_grad_out,)

        def relu_forward_hook_function(module, ten_in, ten_out):
            """
            Store results of forward pass.
            """
            self.forward_relu_outputs.append(ten_out)

        # Select the modules to scan for ReLUs, then hook each one.
        if 'alexnet' in self.arch:
            candidates = self.model.features._modules.values()
        elif 'resnet' in self.arch:
            candidates = self.model.modules()
        else:
            raise self._unsupported_arch()
        for module in candidates:
            if isinstance(module, ReLU):
                self.hook_handles.append(
                    module.register_backward_hook(relu_backward_hook_function)
                )
                self.hook_handles.append(
                    module.register_forward_hook(relu_forward_hook_function)
                )

    def remove_hooks(self):
        """Detach every hook this instance installed on the model."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

    def generate_gradients(self, input_image, one_hot_output_guided, text_for_pred):
        """
        Run a forward pass and a guided backward pass, and return the input gradient.

        Args:
            input_image: batch passed as the model's first argument. It
                presumably needs ``requires_grad=True`` for the first-layer
                hook to receive a gradient (TODO confirm for this model).
            one_hot_output_guided: gradient seeded into
                ``model_output.backward``; presumably a one-hot tensor shaped
                like the model output.
            text_for_pred: forwarded unchanged as the model's second argument.

        Returns:
            numpy array with the first-layer gradient of the first batch item.

        Raises:
            RuntimeError: if no gradient reached the first-layer hook.
        """
        # Start clean so outputs or gradients left over from an earlier
        # (possibly failed) call cannot leak into this one.
        self.forward_relu_outputs.clear()
        self.gradients = None
        try:
            # Forward pass
            model_output = self.model(input_image, text_for_pred, is_train=False)
            # Zero gradients
            self.model.zero_grad()
            # Backward pass
            model_output.backward(gradient=one_hot_output_guided)
        finally:
            # Drop any forward outputs whose backward hook never fired.
            self.forward_relu_outputs.clear()
        if self.gradients is None:
            raise RuntimeError(
                "No gradient reached the first layer (does input_image require "
                "grad, and are the hooks still installed?)"
            )
        # Convert to a numpy array; [0] drops the batch dimension,
        # e.g. (1, 3, 224, 224) -> (3, 224, 224).
        gradients_as_arr = self.gradients.detach().cpu().numpy()[0]
        print("gradients_as_arr shape: ", gradients_as_arr.shape)
        return gradients_as_arr
if __name__ == '__main__':
    # Demo script for running guided backprop on an example image and saving
    # the resulting gradient/saliency images.
    # NOTE(review): this block does not run as-is against the class above:
    #   * get_example_params, save_gradient_images, convert_to_grayscale and
    #     get_positive_negative_saliency are neither defined nor imported in
    #     this file (presumably they live in a helper module such as
    #     misc_functions -- confirm and import them);
    #   * GuidedBackprop(...) is called without its required `arch` argument;
    #   * generate_gradients(...) is called without `text_for_pred`, and its
    #     second argument receives `target_class`, while the method passes that
    #     argument to model_output.backward(gradient=...) -- a gradient tensor
    #     is expected there.
    target_example = 0  # Snake
    (original_image, prep_img, target_class, file_name_to_export, pretrained_model) =\
        get_example_params(target_example)
    # Guided backprop
    GBP = GuidedBackprop(pretrained_model)
    # Get gradients
    guided_grads = GBP.generate_gradients(prep_img, target_class)
    # Save colored gradients
    save_gradient_images(guided_grads, file_name_to_export + '_Guided_BP_color')
    # Convert to grayscale
    grayscale_guided_grads = convert_to_grayscale(guided_grads)
    # Save grayscale gradients
    save_gradient_images(grayscale_guided_grads, file_name_to_export + '_Guided_BP_gray')
    # Positive and negative saliency maps
    pos_sal, neg_sal = get_positive_negative_saliency(guided_grads)
    save_gradient_images(pos_sal, file_name_to_export + '_pos_sal')
    save_gradient_images(neg_sal, file_name_to_export + '_neg_sal')
    print('Guided backprop completed')