#!/usr/bin/env python3
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

from captum._utils.common import _format_tensor_into_tuples
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.stat import Stat
from captum.attr._utils.summarizer import Summarizer
from captum.log import log_usage
from torch import Tensor


class ClassSummarizer(Summarizer):
    r"""
    Used to keep track of summaries for associated classes. The
    classes/labels can be of any type that are supported by `dict`.

    This also keeps track of an aggregate of all class summaries.
    """

    @log_usage()
    def __init__(self, stats: List[Stat]) -> None:
        # Call the undecorated parent constructor directly, since
        # `Summarizer.__init__` is itself wrapped by `@log_usage()`.
        Summarizer.__init__.__wrapped__(self, stats)
        # Lazily create one `Summarizer` per class label on first access.
        self.summaries: Dict[Any, Summarizer] = defaultdict(
            lambda: Summarizer(stats=stats)
        )

    def update(  # type: ignore
        self,
        x: TensorOrTupleOfTensorsGeneric,
        labels: TargetType = None,
    ):
        r"""
        Updates the stats of the summarizer, optionally associated to classes.

        This accepts either a single tensor to summarise or a tuple of tensors.

        Args:
            x (Tensor or Tuple[Tensor, ...]):
                The input tensor to be summarised. The first
                dimension of this input must be associated to
                the batch size of the inputs.
            labels (int, tuple, tensor or list, optional):
                The associated labels for `x`. If Any, we
                assume `labels` represents the label for all inputs in `x`.

                If this is None we simply aggregate the total summary.
        """
        if labels is None:
            # No labels provided: update only the aggregate summary.
            super().update(x)
            return

        # Normalize the input into a tuple of tensors.
        x = _format_tensor_into_tuples(x)

        num_labels = 1

        labels_typed: Union[List[Any], Tensor]
        if isinstance(labels, (list, Tensor)):
            labels_typed = labels
            num_labels = len(labels)  # len() == labels.size(0) for a Tensor
        else:
            # A single scalar label applies to every example in the batch.
            labels_typed = [labels]

        # mypy does not realize the scalar label was wrapped in a list above
        if len(labels_typed) > 1:
            for x_i in x:
                assert x_i.size(0) == num_labels, (
                    "batch size does not equal the number of labels; "
                    "please ensure the length of labels is either 1 "
                    "or equal to the `batch_size`, i.e. the number "
                    "of examples in the input(s)"
                )

        batch_size = x[0].size(0)

        for i in range(batch_size):
            tensors_to_summarize = tuple(tensor[i] for tensor in x)
            # Clone so the per-class summary and the aggregate summary never
            # share (and thus never mutate) the same underlying storage.
            tensors_to_summarize_copy = tuple(tensor[i].clone() for tensor in x)
            label = labels_typed[0] if len(labels_typed) == 1 else labels_typed[i]

            # Update the summary for this example's class...
            self.summaries[label].update(tensors_to_summarize)
            # ...and the aggregate summary across all classes.
            super().update(tensors_to_summarize_copy)

    @property
    def class_summaries(
        self,
    ) -> Dict[
        Any, Union[None, Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]]
    ]:
        r"""
        Returns:
             The summaries for each class.
        """
        return {key: value.summary for key, value in self.summaries.items()}
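

# A minimal usage sketch (illustrative only, not part of the library):
# summarize random attribution-like tensors per class using the `Mean`
# and `StdDev` stats that ship alongside `Stat` in
# captum.attr._utils.stat.
if __name__ == "__main__":
    import torch

    from captum.attr._utils.stat import Mean, StdDev

    summarizer = ClassSummarizer(stats=[Mean(), StdDev()])

    # A batch of 4 examples with one label per example.
    batch = torch.randn(4, 3)
    summarizer.update(batch, labels=[0, 1, 0, 1])

    # Per-class statistics, keyed by label.
    for label, summary in summarizer.class_summaries.items():
        print(label, summary)

    # Aggregate statistics over every example seen so far.
    print(summarizer.summary)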