"""
Created on Thu Oct 26 11:23:47 2017

@author: Utku Ozbulak - github.com/utkuozbulak
"""
import torch
from torch.nn import ReLU

# Sentinel default for ``text_for_pred``: distinguishes "not given" (plain
# ``model(image)`` call) from an explicit ``None`` that must still be passed
# through to a model whose forward takes a text argument.
_NO_TEXT = object()


class GuidedBackprop():
    """
    Produces gradients generated with guided back propagation from the given image.

    Implemented with module hooks: every hooked ReLU records its forward
    output and, during backward, lets only positive gradients through at
    positions where that output was positive; a backward hook on the first
    layer stores the gradient flowing into it in ``self.gradients``.

    Args:
        model: network to explain; its module layout must match ``arch``.
        arch: architecture name. If it contains ``'alexnet'``, ReLUs and the
            first layer are taken from the direct children of
            ``model.features``; if it contains ``'resnet'``, ReLUs are searched
            in all submodules and the first layer is the first direct child of
            ``model``. Any other name raises ``ValueError``.
    """

    def __init__(self, model, arch):
        self.model = model
        self.arch = arch
        # Gradient captured by the first-layer backward hook (None until a
        # backward pass reaches that layer).
        self.gradients = None
        # Stack of ReLU forward outputs: pushed in forward order, popped in
        # reverse order by the ReLU backward hooks.
        self.forward_relu_outputs = []
        # NOTE(review): this puts the model in *training* mode, although the
        # original comment said "evaluation mode". Presumably intentional
        # (some layers/backends only support backward in training mode) —
        # confirm. Dropout/BatchNorm layers therefore run in training mode.
        self.model.train()
        self.update_relus()
        self.hook_layers()

    def _unsupported_arch(self):
        """Build the error raised when ``self.arch`` matches no known layout."""
        return ValueError(
            "Unsupported arch {!r}: expected a name containing 'alexnet' "
            "or 'resnet'".format(self.arch))

    def hook_layers(self):
        """
        Register a backward hook on the first layer that stores the gradient
        flowing into it in ``self.gradients``.

        Raises:
            ValueError: if ``self.arch`` contains neither 'alexnet' nor 'resnet'.
        """
        def hook_function(module, grad_in, grad_out):
            # For a single-op layer (e.g. a conv), grad_in[0] is the gradient
            # w.r.t. the layer's input.
            self.gradients = grad_in[0]

        # Register hook to the first layer
        if 'alexnet' in self.arch:
            first_layer = list(self.model.features._modules.values())[0]
        elif 'resnet' in self.arch:
            first_layer = list(self.model._modules.values())[0]
        else:
            # Previously this fell through to an UnboundLocalError.
            raise self._unsupported_arch()
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch
        # and is only reliable when the hooked module is a single autograd op;
        # verify carefully (e.g. with in-place ReLUs) before migrating to
        # register_full_backward_hook.
        first_layer.register_backward_hook(hook_function)

    def update_relus(self):
        """
        Hook every relevant ReLU so that it
            1- stores its output in the forward pass
            2- imputes zero for gradient values that are less than zero,
               and for positions where the forward output was not positive.

        Raises:
            ValueError: if ``self.arch`` contains neither 'alexnet' nor 'resnet'
                (previously no ReLU was hooked, silently yielding plain
                gradients).
        """
        def relu_backward_hook_function(module, grad_in, grad_out):
            """
            If there is a negative gradient, change it to zero
            """
            # Last forward output: backward visits ReLUs in reverse forward
            # order, so the stack is consumed LIFO.
            # NOTE(review): a ReLU whose output does not contribute to the
            # backpropagated output never pops its entry, which would
            # misalign this stack within one pass.
            corresponding_forward_output = self.forward_relu_outputs.pop()
            # 0/1 mask of positions where the ReLU fired. Built out of place
            # so the stored activation (a tensor in the autograd graph) is not
            # mutated.
            positive_mask = (corresponding_forward_output > 0).to(
                corresponding_forward_output.dtype)
            modified_grad_out = positive_mask * torch.clamp(grad_in[0], min=0.0)
            return (modified_grad_out,)

        def relu_forward_hook_function(module, ten_in, ten_out):
            """
            Store results of forward pass
            """
            self.forward_relu_outputs.append(ten_out)

        # Pick where to look for ReLUs, then hook them up.
        if 'alexnet' in self.arch:
            candidate_modules = self.model.features._modules.values()
        elif 'resnet' in self.arch:
            candidate_modules = self.model.modules()
        else:
            raise self._unsupported_arch()
        for module in candidate_modules:
            if isinstance(module, ReLU):
                module.register_backward_hook(relu_backward_hook_function)
                module.register_forward_hook(relu_forward_hook_function)

    def generate_gradients(self, input_image, one_hot_output_guided,
                           text_for_pred=_NO_TEXT):
        """
        Run one forward/backward pass and return the guided gradient of the
        first sample with respect to the first layer's input.

        Args:
            input_image: model input tensor. Presumably must require grad so
                the first layer produces an input gradient — confirm.
            one_hot_output_guided: gradient fed to ``model_output.backward``
                (same shape as the model output), or an int class index, which
                builds a one-hot gradient for a (batch, num_classes) output.
            text_for_pred: if given (including ``None``), the model is called
                as ``model(input_image, text_for_pred, is_train=False)``;
                if omitted, as ``model(input_image)``.

        Returns:
            numpy array: the captured gradient with its leading (batch)
            dimension removed, i.e. the first sample only.

        Raises:
            ValueError: if an int target is used with a non 2-D model output.
            RuntimeError: if no gradient reached the first layer.
        """
        # Start from a clean state so stale values from a previous call can
        # neither be returned nor misalign the ReLU stack.
        self.gradients = None
        self.forward_relu_outputs.clear()
        try:
            # Forward pass
            if text_for_pred is _NO_TEXT:
                model_output = self.model(input_image)
            else:
                model_output = self.model(input_image, text_for_pred, is_train=False)
            if isinstance(one_hot_output_guided, int):
                if model_output.dim() != 2:
                    raise ValueError(
                        "An integer target needs a (batch, num_classes) model "
                        "output; got shape {}".format(tuple(model_output.shape)))
                gradient = torch.zeros_like(model_output)
                gradient[:, one_hot_output_guided] = 1
            else:
                gradient = one_hot_output_guided
            # Zero gradients
            self.model.zero_grad()
            # Backward pass
            model_output.backward(gradient=gradient)
        finally:
            # Release activations left over by ReLUs whose hooks did not fire.
            self.forward_relu_outputs.clear()
        if self.gradients is None:
            raise RuntimeError(
                "No gradient reached the first layer; make sure input_image "
                "requires grad and feeds the hooked first layer.")
        # Convert the captured tensor to a numpy array; [0] drops the leading
        # batch dimension (keeps the first sample only).
        gradients_as_arr = self.gradients.detach().to('cpu').numpy()[0]
        print("gradients_as_arr shape: ", gradients_as_arr.shape)
        return gradients_as_arr


if __name__ == '__main__':
    # NOTE(review): get_example_params, save_gradient_images,
    # convert_to_grayscale and get_positive_negative_saliency are neither
    # defined nor imported in this file (presumably they come from the
    # upstream project's helper module) — they must be provided for this
    # demo to run.
    target_example = 0  # Snake
    (original_image, prep_img, target_class, file_name_to_export, pretrained_model) =\
        get_example_params(target_example)
    # Guided backprop
    # NOTE(review): 'alexnet' assumes get_example_params returns an
    # AlexNet-style model — confirm.
    GBP = GuidedBackprop(pretrained_model, 'alexnet')
    # Get gradients (an int class index builds the one-hot target internally)
    guided_grads = GBP.generate_gradients(prep_img, target_class)
    # Save colored gradients
    save_gradient_images(guided_grads, file_name_to_export + '_Guided_BP_color')
    # Convert to grayscale
    grayscale_guided_grads = convert_to_grayscale(guided_grads)
    # Save grayscale gradients
    save_gradient_images(grayscale_guided_grads, file_name_to_export + '_Guided_BP_gray')
    # Positive and negative saliency maps
    pos_sal, neg_sal = get_positive_negative_saliency(guided_grads)
    save_gradient_images(pos_sal, file_name_to_export + '_pos_sal')
    save_gradient_images(neg_sal, file_name_to_export + '_neg_sal')
    print('Guided backprop completed')