File size: 1,548 Bytes
aeda668
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import BinaryIO, Union

import torch
import torchaudio


def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Load an audio source and bring it to the requested sample rate.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Target sample rate; the audio is resampled only when
        its native rate differs from this value.
      split_stereo: When true, return the first two channels separately
        instead of a single down-mixed signal.

    Returns:
      A Torch Tensor holding the average over all channels (expected to be
      float32 with torchaudio's default loading settings).

      If `split_stereo` is enabled, the function returns a 2-tuple with the
      separated left and right channels (channel 0 and channel 1).
    """
    # torchaudio.load yields a (channels, time) tensor plus the native rate.
    samples, native_rate = torchaudio.load(input_file)

    needs_resampling = native_rate != sampling_rate
    if needs_resampling:
        samples = torchaudio.functional.resample(
            samples, orig_freq=native_rate, new_freq=sampling_rate
        )

    if not split_stereo:
        # Down-mix by averaging across the channel dimension.
        return samples.mean(0)

    left, right = samples[0], samples[1]
    return left, right


def pad_or_trim(array, length: int, *, axis: int = -1):
    """Pad or trim a tensor so that it has exactly `length` items along `axis`.

    Args:
      array: Input Torch Tensor with at least one dimension.
      length: Desired size along `axis`.
      axis: Dimension to adjust; negative values count from the last
        dimension.

    Returns:
      The leading `length` items along `axis` when the input is longer, the
      input padded at the end of `axis` when it is shorter (using the default
      fill of `torch.nn.functional.pad`, i.e. zeros), or the input itself
      when the size already matches.
    """
    # Normalise negative axes so the index arithmetic below is valid.
    axis = axis % array.ndim
    current = array.shape[axis]

    if current > length:
        # Use one full slice per dimension and narrow only the target axis.
        # The index must be a tuple: repeating Ellipsis (as a list of
        # placeholders) is ambiguous for tensors with 3+ dimensions.
        index = [slice(None)] * array.ndim
        index[axis] = slice(length)
        return array[tuple(index)]

    if current < length:
        # torch.nn.functional.pad expects (left, right) pairs ordered from the
        # LAST dimension to the first. Writing the amount at position 2 * axis
        # and reversing the flat list places it in the "right" slot of `axis`.
        pad_widths = [0] * (2 * array.ndim)
        pad_widths[2 * axis] = length - current
        array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))

    return array