File size: 2,027 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch


def average_over_durations(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    r"""Average frame-level values inside consecutive, duration-defined segments.

    ``durs`` splits the last axis of ``values`` into back-to-back segments:
    segment ``i`` starts where segment ``i - 1`` ended and spans ``durs[:, i]``
    frames. For every segment the mean of its **non-zero** frames is returned;
    frames equal to ``0.0`` contribute neither to the sum nor to the count.
    A segment with no non-zero frames (including a zero-length segment)
    yields ``0``.

    Args:
    values (torch.Tensor): A 3D tensor whose last axis is the frame axis. The
                           second axis is kept as-is, so any number of channels
                           is averaged independently with the same durations.
    durs (torch.Tensor): A 2D tensor of per-segment lengths in frames. The
                         cumulative sum is truncated to integers with
                         ``.long()``. Each row's total must not exceed the
                         number of frames in ``values``, because the cumulative
                         sums are used as gather indices into a prefix-sum
                         table of length ``T_de + 1``.

    Returns:
    avg (torch.Tensor): A float32 tensor with one averaged value per batch
                        item, channel and segment.

    Note:
    The computation is fully vectorized: segment sums and counts are obtained
    as differences of zero-prefixed cumulative sums, without a Python loop.

    Shapes:
        - values: :math:`[B, C, T_de]`
        - durs: :math:`[B, T_en]`
        - avg: :math:`[B, C, T_en]`
    """
    pad = torch.nn.functional.pad

    # Frame index one past the end of each segment, and the index where each
    # segment begins (the previous segment's end, with 0 for the first one).
    seg_ends = durs.cumsum(dim=1).long()
    seg_starts = pad(seg_ends[:, :-1], (1, 0))

    # Zero-prefixed running totals: entry k covers frames [0, k), so the total
    # over frames [a, b) is prefix[b] - prefix[a].
    prefix_sum = pad(values.cumsum(dim=2), (1, 0))
    prefix_count = pad((values != 0.0).cumsum(dim=2), (1, 0))

    # Share the same segment boundaries across every channel of `values`.
    batch, n_segments = seg_ends.shape
    n_channels = values.shape[1]
    start_idx = seg_starts.unsqueeze(1).expand(batch, n_channels, n_segments)
    end_idx = seg_ends.unsqueeze(1).expand(batch, n_channels, n_segments)

    seg_sum = (prefix_sum.gather(2, end_idx) - prefix_sum.gather(2, start_idx)).float()
    seg_count = (prefix_count.gather(2, end_idx) - prefix_count.gather(2, start_idx)).float()

    # Where the count is zero, return that zero instead of the 0/0 result.
    return torch.where(seg_count == 0.0, seg_count, seg_sum / seg_count)