Spaces:
Running
Running
import torch | |
def average_over_durations(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    r"""Average frame-level values over per-token durations, ignoring zero-valued frames.

    Token ``j`` of batch item ``b`` covers the frame span
    ``[cumsum(durs[b])[j] - durs[b, j], cumsum(durs[b])[j])`` of ``values``. For every
    channel, the function returns the mean of the *non-zero* values inside that span.
    Zero-valued frames are excluded from both the sum and the count (presumably so that,
    e.g., unvoiced zero-pitch frames do not bias a pitch average -- confirm with callers).
    Spans containing no non-zero value, including zero-length spans, yield ``0``.

    Args:
        values (torch.Tensor): Frame-level values of shape ``[B, C, T_de]``. ``B`` is the
            batch size, ``C`` the number of channels (typically 1) and ``T_de`` the
            number of frames.
        durs (torch.Tensor): Per-token durations in frames, shape ``[B, T_en]``, where
            ``T_en`` is the number of tokens. Cumulative durations are cast with
            ``.long()`` (fractional parts are truncated). ``durs[b].sum()`` must not
            exceed ``T_de``, otherwise ``torch.gather`` raises an index error.

    Returns:
        torch.Tensor: float32 tensor of shape ``[B, C, T_en]`` holding the per-token
        averages of the non-zero values.

    Note:
        The computation is fully vectorized with prefix sums and ``torch.gather``.
        It is safe to backpropagate through ``values`` even when some spans contain no
        non-zero frames.

    Shapes:
        - values: :math:`[B, C, T_de]`
        - durs: :math:`[B, T_en]`
        - avg: :math:`[B, C, T_en]`
    """
    # Token boundaries as frame indices: token j spans [starts[:, j], ends[:, j]).
    durs_cums_ends = torch.cumsum(durs, dim=1).long()
    durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))

    # Prefix sums with a leading 0, so that the sum over a span is prefix[end] - prefix[start].
    values_nonzero_cums = torch.nn.functional.pad(torch.cumsum(values != 0.0, dim=2), (1, 0))
    values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0))

    bs, n_tokens = durs_cums_ends.size()
    n_formants = values.size(1)
    # Every channel shares the same token boundaries.
    dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, n_tokens)
    dce = durs_cums_ends[:, None, :].expand(bs, n_formants, n_tokens)

    values_sums = (torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)).float()
    values_nelems = (
        torch.gather(values_nonzero_cums, 2, dce) - torch.gather(values_nonzero_cums, 2, dcs)
    ).float()

    # torch.where evaluates both branches. A raw 0/0 in the unselected branch produces NaN in
    # the backward pass, which then flows into ``values.grad``. Clamping the denominator
    # to >= 1 leaves the selected (nelems > 0) results unchanged and keeps gradients finite.
    safe_nelems = values_nelems.clamp(min=1.0)
    avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / safe_nelems)
    return avg