#!/usr/bin/env python3
import inspect
from collections import namedtuple
from typing import (
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import torch
from captum._utils.common import _run_forward, safe_div
from captum.insights.attr_vis.config import (
    ATTRIBUTION_METHOD_CONFIG,
    ATTRIBUTION_NAMES_TO_METHODS,
)
from captum.insights.attr_vis.features import BaseFeature
from torch import Tensor
from torch.nn import Module
# A single ranked prediction: raw score, class index, human-readable label.
OutputScore = namedtuple("OutputScore", "score index label")


class AttributionCalculation:
    """Runs model forward passes and attribution methods for Captum Insights.

    Baselines and transformed inputs are cached per input batch so that
    repeated attribution requests for the same batch do not redo the
    (potentially expensive) feature transforms.
    """

    def __init__(
        self,
        models: Sequence[Module],
        classes: Sequence[str],
        features: List[BaseFeature],
        score_func: Optional[Callable] = None,
        use_label_for_attr: bool = True,
    ) -> None:
        """
        Args:
            models: Models for which predictions/attributions are computed.
            classes: Human-readable class names, indexed by class id.
            features: Per-input feature descriptors supplying input and
                baseline transforms.
            score_func: Optional callable applied to raw model outputs
                before ranking (e.g. a softmax).
            use_label_for_attr: If True, attribute with respect to the
                provided label; otherwise the attribution target is None.
        """
        self.models = models
        self.classes = classes
        self.features = features
        self.score_func = score_func
        self.use_label_for_attr = use_label_for_attr
        # Both caches are keyed by tuple(inputs); tensors hash by identity,
        # so hits occur only when the exact same tensor objects are reused.
        self.baseline_cache: Dict[Tuple, List[Tuple[Tensor, ...]]] = {}
        self.transformed_input_cache: Dict[Tuple, List[Tensor]] = {}

    def calculate_predicted_scores(
        self, inputs, additional_forward_args, model
    ) -> Tuple[
        List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
    ]:
        """Return top predictions, baselines, and transformed inputs.

        Applies each feature's input transforms, builds (and caches) the
        baselines, runs the model forward, and returns up to the top-4
        predicted classes as ``OutputScore`` tuples.
        """
        # Reuse cached baselines and transformed inputs for a repeated batch.
        hashable_inputs = tuple(inputs)
        if hashable_inputs in self.baseline_cache:
            baselines_group = self.baseline_cache[hashable_inputs]
            transformed_inputs = self.transformed_input_cache[hashable_inputs]
        else:
            # Initialize baselines
            baseline_transforms_len = 1  # TODO support multiple baselines
            baselines: List[List[Optional[Tensor]]] = [
                [None] * len(self.features) for _ in range(baseline_transforms_len)
            ]
            transformed_inputs = list(inputs)
            for feature_i, feature in enumerate(self.features):
                transformed_inputs[feature_i] = self._transform(
                    feature.input_transforms, transformed_inputs[feature_i], True
                )
                for baseline_i in range(baseline_transforms_len):
                    if baseline_i > len(feature.baseline_transforms) - 1:
                        # No transform supplied for this slot: fall back to
                        # an all-zeros baseline shaped like the input.
                        baselines[baseline_i][feature_i] = torch.zeros_like(
                            transformed_inputs[feature_i]
                        )
                    else:
                        baselines[baseline_i][feature_i] = self._transform(
                            [feature.baseline_transforms[baseline_i]],
                            transformed_inputs[feature_i],
                            True,
                        )

            baselines_group = [tuple(b) for b in baselines]
            self.baseline_cache[hashable_inputs] = baselines_group
            self.transformed_input_cache[hashable_inputs] = transformed_inputs

        outputs = _run_forward(
            model,
            tuple(transformed_inputs),
            additional_forward_args=additional_forward_args,
        )

        if self.score_func is not None:
            outputs = self.score_func(outputs)

        if outputs.nelement() == 1:
            # Single scalar output: round it to obtain a pseudo class index.
            scores = outputs
            predicted = scores.round().to(torch.int)
        else:
            scores, predicted = outputs.topk(min(4, outputs.shape[-1]))

        scores = scores.cpu().squeeze(0)
        predicted = predicted.cpu().squeeze(0)

        predicted_scores = self._get_labels_from_scores(scores, predicted)

        return predicted_scores, baselines_group, tuple(transformed_inputs)

    def calculate_attribution(
        self,
        baselines: Optional[Sequence[Tuple[Tensor, ...]]],
        data: Tuple[Tensor, ...],
        additional_forward_args: Optional[Tuple[Tensor, ...]],
        label: Optional[Tensor],
        attribution_method_name: str,
        attribution_arguments: Dict,
        model: Module,
    ) -> Tuple[Tensor, ...]:
        """Run the named attribution method on ``data`` and return its result.

        Args:
            baselines: Candidate baselines; only the first is used for now.
            data: Transformed model inputs.
            additional_forward_args: Extra args forwarded to the model.
            label: Target for attribution; ignored when
                ``use_label_for_attr`` is False or the label is empty.
            attribution_method_name: Key into the configured method registry.
            attribution_arguments: User-supplied method kwargs (mutated in
                place by configured post-processing).
            model: The model to attribute.
        """
        attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
        attribution_method = attribution_cls(model)
        # Apply any configured post-processing to the user-supplied
        # arguments (e.g. converting values into the expected types).
        if attribution_method_name in ATTRIBUTION_METHOD_CONFIG:
            param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
            if param_config.post_process:
                for k, v in attribution_arguments.items():
                    if k in param_config.post_process:
                        attribution_arguments[k] = param_config.post_process[k](v)

        # TODO support multiple baselines
        baseline = baselines[0] if baselines and len(baselines) > 0 else None
        label = (
            None
            if not self.use_label_for_attr or label is None or label.nelement() == 0
            else label
        )
        # Only pass a baseline if this method's attribute() accepts one.
        if "baselines" in inspect.signature(attribution_method.attribute).parameters:
            attribution_arguments["baselines"] = baseline
        # Call through __wrapped__ to skip the decorator applied to
        # attribute() — presumably usage logging; confirm against captum.
        attr = attribution_method.attribute.__wrapped__(
            attribution_method,  # self
            data,
            additional_forward_args=additional_forward_args,
            target=label,
            **attribution_arguments,
        )

        return attr

    def calculate_net_contrib(
        self, attrs_per_input_feature: Tuple[Tensor, ...]
    ) -> List[float]:
        """Return each input feature's net contribution, L1-normalized.

        Sums every feature's attributions into one scalar and scales the
        vector so that sum(abs(x_i)) = 1 (all zeros if the norm is 0).
        """
        # get the net contribution per feature (input)
        net_contrib = torch.stack(
            [attrib.flatten().sum() for attrib in attrs_per_input_feature]
        )

        # normalise the contribution, s.t. sum(abs(x_i)) = 1
        norm = torch.norm(net_contrib, p=1)
        # if norm is 0, all net_contrib elements are 0
        net_contrib = safe_div(net_contrib, norm)

        return net_contrib.tolist()

    def _transform(
        self, transforms: Iterable[Callable], inputs: Tensor, batch: bool = False
    ) -> Tensor:
        """Apply ``transforms`` in order to ``inputs``.

        When ``batch`` is True the leading batch dimension is squeezed off
        before transforming and restored afterwards.
        """
        transformed_inputs = inputs
        # TODO support batch size > 1
        if batch:
            transformed_inputs = inputs.squeeze(0)

        for t in transforms:
            transformed_inputs = t(transformed_inputs)

        if batch:
            transformed_inputs = transformed_inputs.unsqueeze(0)

        return transformed_inputs

    def _get_labels_from_scores(
        self, scores: Tensor, indices: Tensor
    ) -> List[OutputScore]:
        """Pair each predicted index with its score and class label.

        Returns an empty list when fewer than two predictions are present
        (a scalar output carries no class ranking).
        """
        pred_scores: List[OutputScore] = []

        if indices.nelement() < 2:
            return pred_scores

        for score, index in zip(scores, indices):
            pred_scores.append(
                OutputScore(score, index, self.classes[int(index)])
            )
        return pred_scores