File size: 3,261 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python3

import warnings
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor


def _divide_and_aggregate_metrics(
    inputs: Tuple[Tensor, ...],
    n_perturb_samples: int,
    metric_func: Callable,
    agg_func: Callable = torch.add,
    max_examples_per_batch: Optional[int] = None,
) -> Tensor:
    r"""
    Evaluate `metric_func` over `n_perturb_samples` perturbations per input
    example in bounded-size pieces and combine the partial results.

    `metric_func` is called one or more times, each time with the number of
    perturbation samples it should process in that call. The per-call counts sum
    up to `n_perturb_samples`; the partial results are folded left-to-right with
    `agg_func`. The size of each piece is derived from `max_examples_per_batch`
    and the batch size taken from the first dimension of `inputs[0]`.

    Args:

        inputs (tuple): The original inputs formatted in a tuple. Only the size
                        of the first dimension of the first tensor is read here,
                        it is used as the input batch size.
        n_perturb_samples (int): The total number of perturbation samples per
                        input example that have to be processed.
        metric_func (callable): Called as `metric_func(n)` where `n` is the
                        number of perturbation samples per example to process in
                        this call. Its return value is the partial metric.
        agg_func (callable, optional): Called as `agg_func(accumulated, partial)`
                        to merge the partial metric of a sub-batch into the
                        running result.
                        Default: `torch.add`
        max_examples_per_batch (int, optional): The maximum number of allowed
                        examples per batch. If `None`, or if the input batch is
                        empty, all `n_perturb_samples` are processed in a single
                        `metric_func` call. Values that allow fewer than one or
                        more than `n_perturb_samples` samples per example trigger
                        a warning and are clamped into that range.
                        Default: None

    Returns:

        metric (tensor): The result of the first `metric_func` call, aggregated
                        with the results of all subsequent calls via `agg_func`.
    """
    bsz = inputs[0].size(0)

    if max_examples_per_batch is None or bsz == 0:
        # Either no limit was requested, or the batch is empty: in the latter
        # case the expanded batch holds zero examples, so any limit is satisfied
        # and there is nothing to divide the limit by.
        max_inps_per_batch = n_perturb_samples
    else:
        inps_per_batch = max_examples_per_batch // bsz
        if inps_per_batch < 1 or inps_per_batch > n_perturb_samples:
            warnings.warn(
                (
                    "`max_examples_per_batch` must be at least equal to the"
                    " input batch size and at most to "
                    "`input batch size` * `n_perturb_samples`. "
                    "`max_examples_per_batch` is: {} and the input batch size is: {}. "
                    "This is necessary because we require that each sub-batch that "
                    "is used to compute the metrics, contains at least an instance "
                    "of the original example and doesn't exceed the number of "
                    "expanded n_perturb_samples."
                ).format(max_examples_per_batch, bsz)
            )
        # Clamp to [1, n_perturb_samples] so that an out-of-range limit still
        # yields a usable chunk size.
        max_inps_per_batch = min(max(inps_per_batch, 1), n_perturb_samples)

    # The first chunk seeds the accumulator, so `agg_func` needs no identity
    # element.
    metrics_sum = metric_func(max_inps_per_batch)
    n_done = max_inps_per_batch

    while n_done < n_perturb_samples:
        # The last chunk may be smaller than the others.
        n_next = min(max_inps_per_batch, n_perturb_samples - n_done)
        metrics_sum = agg_func(metrics_sum, metric_func(n_next))
        n_done += n_next

    return metrics_sum