#!/usr/bin/env python3
"""Prediction scoring and attribution helpers built on captum.insights.attr_vis."""
import inspect
from collections import namedtuple
from typing import (
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import torch
from captum._utils.common import _run_forward, safe_div
from captum.insights.attr_vis.config import (
    ATTRIBUTION_METHOD_CONFIG,
    ATTRIBUTION_NAMES_TO_METHODS,
)
from captum.insights.attr_vis.features import BaseFeature
from torch import Tensor
from torch.nn import Module

# One predicted output: its score, its index in the model output, and the
# class label looked up for that index.
OutputScore = namedtuple("OutputScore", "score index label")


class AttributionCalculation:
    """Compute predicted scores and feature attributions for a set of features.

    Transformed inputs and baselines are memoised per input tuple so repeated
    calls with the same inputs do not re-run the feature transforms.
    """

    def __init__(
        self,
        models: Sequence[Module],
        classes: Sequence[str],
        features: List[BaseFeature],
        score_func: Optional[Callable] = None,
        use_label_for_attr: bool = True,
    ) -> None:
        """Store the configuration and create the (empty) caches.

        Args:
            models: Models associated with this calculation. Only stored here;
                the methods of this class receive the model to use explicitly.
            classes: Class labels, indexed by the predicted output index.
            features: Feature descriptions; each provides ``input_transforms``
                and ``baseline_transforms``. Feature ``i`` is applied to
                input ``i``.
            score_func: Optional callable applied to the raw model output
                before the top predictions are selected.
            use_label_for_attr: When False, attribution is always computed
                with ``target=None`` regardless of the label passed in.
        """
        self.models = models
        self.classes = classes
        self.features = features
        self.score_func = score_func
        self.use_label_for_attr = use_label_for_attr
        # Both caches are keyed by ``tuple(inputs)`` and are never evicted.
        self.baseline_cache: dict = {}
        self.transformed_input_cache: dict = {}

    def calculate_predicted_scores(
        self, inputs, additional_forward_args, model
    ) -> Tuple[
        List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
    ]:
        """Run ``model`` on the transformed inputs and return its top predictions.

        Args:
            inputs: Iterable of per-feature inputs, in the same order as
                ``self.features``. Its elements must be hashable because
                ``tuple(inputs)`` is used as the cache key.
                NOTE(review): presumably tensors with a leading batch
                dimension of size 1 - confirm against callers.
            additional_forward_args: Extra arguments forwarded to the model.
            model: The model to evaluate.

        Returns:
            A 3-tuple of
            ``(predicted_scores, baselines_group, transformed_inputs)``:
            the labelled top predictions (at most 4; empty when fewer than two
            indices are predicted), one baseline tuple per baseline set, and
            the transformed inputs as a tuple.
        """
        # Check if inputs have cached baselines and transformed inputs
        cache_key = tuple(inputs)
        if cache_key in self.baseline_cache:
            baselines_group = self.baseline_cache[cache_key]
            transformed_inputs = self.transformed_input_cache[cache_key]
        else:
            # Initialize baselines
            baseline_transforms_len = 1  # todo support multiple baselines
            baselines: List[List[Optional[Tensor]]] = [
                [None] * len(self.features) for _ in range(baseline_transforms_len)
            ]
            transformed_inputs = list(inputs)

            for feature_i, feature in enumerate(self.features):
                transformed_inputs[feature_i] = self._transform(
                    feature.input_transforms, transformed_inputs[feature_i], True
                )
                for baseline_i in range(baseline_transforms_len):
                    if baseline_i > len(feature.baseline_transforms) - 1:
                        # No baseline transform configured at this position:
                        # fall back to an all-zero baseline.
                        baselines[baseline_i][feature_i] = torch.zeros_like(
                            transformed_inputs[feature_i]
                        )
                    else:
                        # The baseline is derived from the already
                        # transformed input, not from the raw input.
                        baselines[baseline_i][feature_i] = self._transform(
                            [feature.baseline_transforms[baseline_i]],
                            transformed_inputs[feature_i],
                            True,
                        )

            baselines = cast(List[List[Optional[Tensor]]], baselines)
            baselines_group = [tuple(b) for b in baselines]
            self.baseline_cache[cache_key] = baselines_group
            self.transformed_input_cache[cache_key] = transformed_inputs

        outputs = _run_forward(
            model,
            tuple(transformed_inputs),
            additional_forward_args=additional_forward_args,
        )

        if self.score_func is not None:
            outputs = self.score_func(outputs)

        if outputs.nelement() == 1:
            # Single-value output: the prediction is the rounded score.
            scores = outputs
            predicted = scores.round().to(torch.int)
        else:
            scores, predicted = outputs.topk(min(4, outputs.shape[-1]))

        scores = scores.cpu().squeeze(0)
        predicted = predicted.cpu().squeeze(0)

        predicted_scores = self._get_labels_from_scores(scores, predicted)

        return predicted_scores, baselines_group, tuple(transformed_inputs)

    def calculate_attribution(
        self,
        baselines: Optional[Sequence[Tuple[Tensor, ...]]],
        data: Tuple[Tensor, ...],
        additional_forward_args: Optional[Tuple[Tensor, ...]],
        label: Optional[Union[Tensor]],
        attribution_method_name: str,
        attribution_arguments: Dict,
        model: Module,
    ) -> Tuple[Tensor, ...]:
        """Compute attributions for ``data`` with the named attribution method.

        Args:
            baselines: Baseline sets; only the first one is used. ``None`` or
                an empty sequence means no baseline.
            data: Inputs to attribute.
            additional_forward_args: Extra arguments forwarded to the model.
            label: Attribution target. Replaced by ``None`` when
                ``use_label_for_attr`` is False or the tensor is empty.
            attribution_method_name: Key into ``ATTRIBUTION_NAMES_TO_METHODS``.
            attribution_arguments: Keyword arguments for the method's
                ``attribute`` call. This mapping is not modified.
            model: Model the attribution method is constructed with.

        Returns:
            Whatever the method's ``attribute`` returns for ``data``.

        Raises:
            KeyError: If ``attribution_method_name`` is not a known method.
        """
        attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
        attribution_method = attribution_cls(model)

        # Work on a shallow copy so neither the post-processed values nor the
        # injected "baselines" entry leak back into the caller's dictionary.
        call_arguments = dict(attribution_arguments)

        if attribution_method_name in ATTRIBUTION_METHOD_CONFIG:
            param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
            if param_config.post_process:
                for name, value in attribution_arguments.items():
                    if name in param_config.post_process:
                        call_arguments[name] = param_config.post_process[name](
                            value
                        )

        # TODO support multiple baselines
        baseline = baselines[0] if baselines else None

        if not self.use_label_for_attr or label is None or label.nelement() == 0:
            label = None

        # Only pass baselines to methods whose ``attribute`` accepts them.
        if "baselines" in inspect.signature(attribution_method.attribute).parameters:
            call_arguments["baselines"] = baseline

        # Invoke the function stored under ``__wrapped__`` (the undecorated
        # ``attribute``), which requires passing the instance explicitly.
        attr = attribution_method.attribute.__wrapped__(
            attribution_method,  # self
            data,
            additional_forward_args=additional_forward_args,
            target=label,
            **call_arguments,
        )

        return attr

    def calculate_net_contrib(
        self, attrs_per_input_feature: Tuple[Tensor, ...]
    ) -> List[float]:
        """Return each feature's net attribution, L1-normalised.

        Args:
            attrs_per_input_feature: One attribution tensor per input feature.

        Returns:
            One float per feature: the sum of that feature's attributions
            divided by the L1 norm of all sums. Empty input yields ``[]``.
        """
        # torch.stack rejects an empty sequence; no features means no
        # contributions to report.
        if not attrs_per_input_feature:
            return []

        # get the net contribution per feature (input)
        net_contrib = torch.stack(
            [attrib.flatten().sum() for attrib in attrs_per_input_feature]
        )

        # normalise the contribution, s.t. sum(abs(x_i)) = 1
        norm = torch.norm(net_contrib, p=1)
        # if norm is 0, all net_contrib elements are 0
        net_contrib = safe_div(net_contrib, norm)

        return net_contrib.tolist()

    def _transform(
        self, transforms: Iterable[Callable], inputs: Tensor, batch: bool = False
    ) -> Tensor:
        """Apply ``transforms`` to ``inputs`` in order and return the result.

        Args:
            transforms: Callables applied one after another.
            inputs: Tensor to transform.
            batch: When True, dimension 0 is squeezed before the transforms
                run and re-added afterwards, so the transforms operate on an
                unbatched tensor.
        """
        transformed_inputs = inputs
        # TODO support batch size > 1
        if batch:
            transformed_inputs = inputs.squeeze(0)

        for transform in transforms:
            transformed_inputs = transform(transformed_inputs)

        if batch:
            transformed_inputs = transformed_inputs.unsqueeze(0)

        return transformed_inputs

    def _get_labels_from_scores(
        self, scores: Tensor, indices: Tensor
    ) -> List[OutputScore]:
        """Pair each predicted index with its score and class label.

        Args:
            scores: Scores aligned position-by-position with ``indices``.
            indices: Predicted output indices, used to look up ``self.classes``.

        Returns:
            One ``OutputScore`` per index, or an empty list when ``indices``
            holds fewer than two elements.
        """
        pred_scores: List[OutputScore] = []
        if indices.nelement() < 2:
            return pred_scores

        for position, index in enumerate(indices):
            pred_scores.append(
                OutputScore(scores[position], index, self.classes[int(index)])
            )
        return pred_scores