Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
import warnings | |
from typing import Callable, Tuple | |
import torch | |
from torch import Tensor | |
def _divide_and_aggregate_metrics(
    inputs: Tuple[Tensor, ...],
    n_perturb_samples: int,
    metric_func: Callable,
    agg_func: Callable = torch.add,
    max_examples_per_batch: int = None,
) -> Tensor:
    r"""
    Split the `n_perturb_samples` perturbations requested per input example
    into smaller chunks, call `metric_func` once per chunk and fold the
    per-chunk results together with `agg_func`. The chunk size is derived
    from `max_examples_per_batch`.

    Args:

        inputs (tuple): The original inputs formatted in a tuple. Only the
                first dimension of `inputs[0]` is read here; it is taken as
                the input batch size.
        n_perturb_samples (int): The total number of perturbation samples
                to process per input example.
        metric_func (callable): Called with a single integer, the number of
                perturbation samples to process in the current chunk, and
                expected to return the metric for that chunk.
        agg_func (callable, optional): Binary function used to combine the
                running result with the metric of the next chunk. It is
                called as `agg_func(accumulated, metric)`.
                Default: `torch.add`
        max_examples_per_batch (int, optional): Upper bound on
                `input batch size * perturbation samples per chunk`.
                `None` means no limit, i.e. all `n_perturb_samples` are
                processed in a single `metric_func` call. Values that allow
                less than one or more than `n_perturb_samples` perturbation
                samples per chunk trigger a warning and are clamped into
                that range.
                Default: None

    Returns:

        metric (tensor): The result of the first `metric_func` call
                combined, via `agg_func`, with the results of all
                subsequent calls.
    """
    bsz = inputs[0].size(0)

    if max_examples_per_batch is None or bsz == 0:
        # No limit was requested, or the batch is empty so no limit can be
        # exceeded (this also avoids a division by zero below): process all
        # perturbation samples in a single call.
        max_inps_per_batch = n_perturb_samples
    else:
        inps_per_batch = max_examples_per_batch // bsz
        if inps_per_batch < 1 or inps_per_batch > n_perturb_samples:
            warnings.warn(
                (
                    "`max_examples_per_batch` must be at least equal to the"
                    " input batch size and at most to "
                    "`input batch size` * `n_perturb_samples`. "
                    "`max_examples_per_batch` is: {} and the input batch size "
                    "is: {}. "
                    "This is necessary because we require that each sub-batch "
                    "that is used to compute the metrics, contains at least an "
                    "instance of the original example and doesn't exceed the "
                    "number of expanded n_perturb_samples."
                ).format(max_examples_per_batch, bsz)
            )
        # Clamp into [1, n_perturb_samples] so that an out-of-range limit
        # still yields a usable chunk size.
        max_inps_per_batch = min(max(inps_per_batch, 1), n_perturb_samples)

    # The first chunk seeds the accumulator; it is always evaluated, even when
    # `n_perturb_samples` is not positive.
    metrics_sum = metric_func(max_inps_per_batch)
    remaining = n_perturb_samples - max_inps_per_batch

    while remaining > 0:
        # The last chunk may be smaller than the regular chunk size.
        n_current = min(max_inps_per_batch, remaining)
        metrics_sum = agg_func(metrics_sum, metric_func(n_current))
        remaining -= n_current

    return metrics_sum