#!/usr/bin/env python3

import warnings
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor


def _divide_and_aggregate_metrics(
    inputs: Tuple[Tensor, ...],
    n_perturb_samples: int,
    metric_func: Callable,
    agg_func: Callable = torch.add,
    max_examples_per_batch: Optional[int] = None,
) -> Tensor:
    r"""
    Slices a large number of perturbation samples (`n_perturb_samples` per
    input example) into smaller pieces, computes the metric for each piece,
    and aggregates the results across all `n_perturb_samples` for each
    example. The function returns the overall aggregated metric per sample.
    The size of each slice is determined by `max_examples_per_batch`.

    Args:
        inputs (tuple): The original inputs, formatted as a tuple, that are
                passed to the metric function and for which the attributions
                are computed.
        n_perturb_samples (int): The number of perturbation samples used per
                input example.
        metric_func (callable): A function that takes the number of
                perturbation samples in the current sub-batch and returns an
                overall metric for each example.
        agg_func (callable, optional): A function used to aggregate the
                metrics returned by `metric_func` across sub-batches.
        max_examples_per_batch (int, optional): The maximum number of examples
                allowed per sub-batch.

    Returns:
        metric (Tensor): A metric score estimated by `metric_func` per
                input example.
    """
    bsz = inputs[0].size(0)
    if max_examples_per_batch is not None and (
        max_examples_per_batch // bsz < 1
        or max_examples_per_batch // bsz > n_perturb_samples
    ):
        warnings.warn(
            (
                "`max_examples_per_batch` must be at least equal to the input "
                "batch size and at most equal to "
                "`input batch size` * `n_perturb_samples`. "
                "`max_examples_per_batch` is: {} and the input batch size is: {}. "
                "This is necessary because each sub-batch that is used to "
                "compute the metrics must contain at least one instance of the "
                "original example and must not exceed the number of expanded "
                "n_perturb_samples."
            ).format(max_examples_per_batch, bsz)
        )

    # Number of perturbation samples evaluated per sub-batch, clamped to the
    # range [1, n_perturb_samples].
    max_inps_per_batch = (
        n_perturb_samples
        if max_examples_per_batch is None
        else min(max(max_examples_per_batch // bsz, 1), n_perturb_samples)
    )

    # Evaluate the first sub-batch, then keep aggregating until all
    # n_perturb_samples have been consumed; the last sub-batch may be smaller.
    current_n_steps = max_inps_per_batch
    metrics_sum = metric_func(max_inps_per_batch)
    while current_n_steps < n_perturb_samples:
        current_n_steps += max_inps_per_batch
        metric = metric_func(
            max_inps_per_batch
            if current_n_steps <= n_perturb_samples
            else max_inps_per_batch - (current_n_steps - n_perturb_samples)
        )
        current_n_steps = min(current_n_steps, n_perturb_samples)
        metrics_sum = agg_func(metrics_sum, metric)
    return metrics_sum
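

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library's public API):
# `dummy_metric_func` below is a hypothetical stand-in for the per-sub-batch
# metric computation a caller would normally provide. It shows the expected
# contract: `metric_func` receives the number of perturbation samples to
# evaluate in the current sub-batch and returns one score per example in the
# original batch, which `agg_func` then accumulates across sub-batches.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, n_perturb_samples = 4, 10
    example_inputs = (torch.rand(batch_size, 3),)

    def dummy_metric_func(current_n_perturb_samples: int) -> Tensor:
        # A real implementation would expand `example_inputs` by
        # `current_n_perturb_samples`, perturb them, run the model, and
        # reduce to one score per original example. Here we return a constant
        # contribution so the aggregation is easy to verify by hand.
        return torch.full((batch_size,), float(current_n_perturb_samples))

    aggregated = _divide_and_aggregate_metrics(
        example_inputs,
        n_perturb_samples,
        dummy_metric_func,
        agg_func=torch.add,
        max_examples_per_batch=12,  # 12 // 4 = 3 perturbation samples per sub-batch
    )
    # Sub-batches of 3, 3, 3, and 1 perturbation samples sum to 10 per example.
    print(aggregated)  # tensor([10., 10., 10., 10.])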