Spaces:
Build error
Build error
File size: 6,733 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
#!/usr/bin/env python3
import inspect
from collections import namedtuple
from typing import (
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)
import torch
from captum._utils.common import _run_forward, safe_div
from captum.insights.attr_vis.config import (
ATTRIBUTION_METHOD_CONFIG,
ATTRIBUTION_NAMES_TO_METHODS,
)
from captum.insights.attr_vis.features import BaseFeature
from torch import Tensor
from torch.nn import Module
# One ranked prediction: ``score`` taken from the model output, ``index`` the
# predicted class index, and ``label`` the class name looked up for that index.
OutputScore = namedtuple("OutputScore", ["score", "index", "label"])
class AttributionCalculation:
    """Compute predictions, baselines and attributions for a set of models.

    Inputs are transformed per feature (``feature.input_transforms``) and a
    baseline is derived per feature (``feature.baseline_transforms``); both
    are cached keyed on the tuple of raw inputs so repeated requests for the
    same inputs skip the transform step.

    Args:
        models: Models whose predictions/attributions are computed.
        classes: Class names, indexed by predicted class index.
        features: One feature per model input (matched positionally with the
            inputs), supplying ``input_transforms`` and
            ``baseline_transforms``.
        score_func: Optional callable applied to the raw model outputs before
            they are ranked.
        use_label_for_attr: When True, a non-empty ``label`` passed to
            ``calculate_attribution`` is used as the attribution target;
            otherwise the target is ``None``.
    """

    def __init__(
        self,
        models: Sequence[Module],
        classes: Sequence[str],
        features: List[BaseFeature],
        score_func: Optional[Callable] = None,
        use_label_for_attr: bool = True,
    ) -> None:
        self.models = models
        self.classes = classes
        self.features = features
        self.score_func = score_func
        self.use_label_for_attr = use_label_for_attr
        # tuple(inputs) -> list with one tuple of per-feature baselines
        self.baseline_cache: dict = {}
        # tuple(inputs) -> list of per-feature transformed input tensors
        self.transformed_input_cache: dict = {}

    def calculate_predicted_scores(
        self, inputs, additional_forward_args, model
    ) -> Tuple[
        List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
    ]:
        """Run ``model`` on the transformed inputs and rank its outputs.

        Args:
            inputs: Sequence of input tensors, one per feature. Each is
                assumed to carry a leading batch dimension of size 1 (see
                ``_transform``) -- TODO confirm for batched callers.
            additional_forward_args: Extra arguments forwarded to the model.
            model: Model to evaluate.

        Returns:
            ``(predicted_scores, baselines_group, transformed_inputs)``:
            up to the top 4 outputs as ``OutputScore`` entries, a list holding
            one tuple of per-feature baselines, and the tuple of transformed
            inputs that was fed to the model.
        """
        cache_key = tuple(inputs)
        # NOTE(review): cache entries are never evicted, and torch tensors
        # hash by object identity, so a hit requires the very same tensor
        # objects -- confirm this is the intended reuse pattern.
        if cache_key in self.baseline_cache:
            baselines_group = self.baseline_cache[cache_key]
            transformed_inputs = self.transformed_input_cache[cache_key]
        else:
            num_baselines = 1  # TODO support multiple baselines
            baselines: List[List[Optional[Tensor]]] = [
                [None] * len(self.features) for _ in range(num_baselines)
            ]
            transformed_inputs = list(inputs)
            for feature_i, feature in enumerate(self.features):
                transformed = self._transform(
                    feature.input_transforms, transformed_inputs[feature_i], True
                )
                transformed_inputs[feature_i] = transformed
                for baseline_i in range(num_baselines):
                    if baseline_i >= len(feature.baseline_transforms):
                        # No transform configured for this baseline: use zeros.
                        baselines[baseline_i][feature_i] = torch.zeros_like(
                            transformed
                        )
                    else:
                        # Baseline transforms run on the already-transformed input.
                        baselines[baseline_i][feature_i] = self._transform(
                            [feature.baseline_transforms[baseline_i]],
                            transformed,
                            True,
                        )
            baselines_group = [tuple(b) for b in baselines]
            self.baseline_cache[cache_key] = baselines_group
            self.transformed_input_cache[cache_key] = transformed_inputs

        outputs = _run_forward(
            model,
            tuple(transformed_inputs),
            additional_forward_args=additional_forward_args,
        )
        if self.score_func is not None:
            outputs = self.score_func(outputs)

        if outputs.nelement() == 1:
            # Single output: the prediction is the score rounded to an int.
            # NOTE(review): _get_labels_from_scores returns [] for fewer than
            # two indices, so this prediction is currently discarded -- confirm
            # whether single-output models should report a label.
            scores = outputs
            predicted = scores.round().to(torch.int)
        else:
            scores, predicted = outputs.topk(min(4, outputs.shape[-1]))

        scores = scores.cpu().squeeze(0)
        predicted = predicted.cpu().squeeze(0)
        predicted_scores = self._get_labels_from_scores(scores, predicted)
        return predicted_scores, baselines_group, tuple(transformed_inputs)

    def calculate_attribution(
        self,
        baselines: Optional[Sequence[Tuple[Tensor, ...]]],
        data: Tuple[Tensor, ...],
        additional_forward_args: Optional[Tuple[Tensor, ...]],
        label: Optional[Tensor],
        attribution_method_name: str,
        attribution_arguments: Dict,
        model: Module,
    ) -> Tuple[Tensor, ...]:
        """Compute attributions of ``model`` on ``data`` with the named method.

        Args:
            baselines: Baseline groups; only the first one is used.
            data: Transformed model inputs to attribute.
            additional_forward_args: Extra arguments forwarded to the model.
            label: Attribution target; ignored (``None`` is used) when
                ``use_label_for_attr`` is False or the tensor is empty.
            attribution_method_name: Key into ``ATTRIBUTION_NAMES_TO_METHODS``.
            attribution_arguments: Extra keyword arguments for ``attribute``.
                Values are post-processed per ``ATTRIBUTION_METHOD_CONFIG``;
                the caller's dict itself is left unmodified.
            model: Model to attribute.

        Returns:
            Whatever the attribution method's ``attribute`` returns.
        """
        attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
        attribution_method = attribution_cls(model)

        # Work on a shallow copy so post-processing and the "baselines" entry
        # below do not leak into the caller's dict. Callers may reuse one dict
        # across calls, which would otherwise re-apply post_process to
        # already-converted values and carry a stale "baselines" key.
        attribution_arguments = dict(attribution_arguments)

        if attribution_method_name in ATTRIBUTION_METHOD_CONFIG:
            param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
            post_process = param_config.post_process
            if post_process:
                for key, value in attribution_arguments.items():
                    if key in post_process:
                        attribution_arguments[key] = post_process[key](value)

        # TODO support multiple baselines
        baseline = baselines[0] if baselines else None
        if not self.use_label_for_attr or label is None or label.nelement() == 0:
            label = None

        # Only pass baselines to methods whose ``attribute`` accepts them.
        if "baselines" in inspect.signature(attribution_method.attribute).parameters:
            attribution_arguments["baselines"] = baseline

        # Call the function underneath ``attribute``'s decorator via
        # ``__wrapped__``, passing the instance explicitly as ``self``.
        attr = attribution_method.attribute.__wrapped__(
            attribution_method,  # self
            data,
            additional_forward_args=additional_forward_args,
            target=label,
            **attribution_arguments,
        )
        return attr

    def calculate_net_contrib(
        self, attrs_per_input_feature: Tuple[Tensor, ...]
    ) -> List[float]:
        """Return each input's net attribution, L1-normalised across inputs.

        Each attribution tensor is summed to a scalar, and the resulting
        vector is divided by its L1 norm so that ``sum(abs(x_i)) == 1``.
        ``safe_div`` is used because the norm is 0 exactly when every net
        contribution is 0.
        """
        net_contrib = torch.stack(
            [attrib.flatten().sum() for attrib in attrs_per_input_feature]
        )
        norm = torch.norm(net_contrib, p=1)
        net_contrib = safe_div(net_contrib, norm)
        return net_contrib.tolist()

    def _transform(
        self, transforms: Iterable[Callable], inputs: Tensor, batch: bool = False
    ) -> Tensor:
        """Apply ``transforms`` to ``inputs`` in order.

        When ``batch`` is True, dimension 0 is squeezed before the transforms
        run and restored afterwards, so the transforms see a single sample.
        """
        transformed_inputs = inputs
        # TODO support batch size > 1
        if batch:
            transformed_inputs = inputs.squeeze(0)
        for t in transforms:
            transformed_inputs = t(transformed_inputs)
        if batch:
            transformed_inputs = transformed_inputs.unsqueeze(0)
        return transformed_inputs

    def _get_labels_from_scores(
        self, scores: Tensor, indices: Tensor
    ) -> List[OutputScore]:
        """Pair each score with its class index and class name.

        Returns an empty list whenever fewer than two indices are given.
        """
        if indices.nelement() < 2:
            return []
        return [
            OutputScore(score, index, self.classes[int(index)])
            for score, index in zip(scores, indices)
        ]
|