File size: 17,833 Bytes
864affd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
"""The new soundfile backend which will become default in 0.8.0 onward"""
import warnings
from typing import Optional, Tuple

import torch
from torchaudio._internal import module_utils as _mod_utils

from .common import AudioMetaData


# Stays False unless the ``import soundfile`` below completes without raising.
_IS_SOUNDFILE_AVAILABLE = False

# TODO: import soundfile only when it is used.
if not _mod_utils.is_module_available("soundfile"):
    # The package cannot be found at all: build the decorator from the
    # "not installed" message.
    _requires_soundfile = _mod_utils.fail_with_message(
        "requires soundfile, but it is not installed. Please install soundfile."
    )
else:
    try:
        import soundfile

        _requires_soundfile = _mod_utils.no_op
        _IS_SOUNDFILE_AVAILABLE = True
    except Exception:
        # The package is discoverable but importing it raised. The broad catch is
        # deliberate best-effort: whatever went wrong, this module must still be
        # importable, and the decorator carries the "failed to import" message.
        _requires_soundfile = _mod_utils.fail_with_message(
            "requires soundfile, but we failed to import it. Please check the installation of soundfile."
        )


# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristic; the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
# the default seems to be 8 bits but it can be compressed further to 4 bits.
# The dict is inspired by
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
# Subtypes missing from this table are reported as 0 bits (with a warning) by ``_get_bit_depth``.
_SUBTYPE_TO_BITS_PER_SAMPLE = {
    "PCM_S8": 8,  # Signed 8 bit data
    "PCM_16": 16,  # Signed 16 bit data
    "PCM_24": 24,  # Signed 24 bit data
    "PCM_32": 32,  # Signed 32 bit data
    "PCM_U8": 8,  # Unsigned 8 bit data (WAV and RAW only)
    "FLOAT": 32,  # 32 bit float data
    "DOUBLE": 64,  # 64 bit float data
    "ULAW": 8,  # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
    "ALAW": 8,  # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
    "IMA_ADPCM": 0,  # IMA ADPCM.
    "MS_ADPCM": 0,  # Microsoft ADPCM.
    "GSM610": 0,  # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
    "VOX_ADPCM": 0,  # OKI / Dialogix ADPCM
    "G721_32": 0,  # 32kbs G721 ADPCM encoding.
    "G723_24": 0,  # 24kbs G723 ADPCM encoding.
    "G723_40": 0,  # 40kbs G723 ADPCM encoding.
    "DWVW_12": 12,  # 12 bit Delta Width Variable Word encoding.
    "DWVW_16": 16,  # 16 bit Delta Width Variable Word encoding.
    "DWVW_24": 24,  # 24 bit Delta Width Variable Word encoding.
    "DWVW_N": 0,  # N bit Delta Width Variable Word encoding.
    "DPCM_8": 8,  # 8 bit differential PCM (XI only)
    "DPCM_16": 16,  # 16 bit differential PCM (XI only)
    "VORBIS": 0,  # Xiph Vorbis encoding. (lossy)
    "ALAC_16": 16,  # Apple Lossless Audio Codec (16 bit).
    "ALAC_20": 20,  # Apple Lossless Audio Codec (20 bit).
    "ALAC_24": 24,  # Apple Lossless Audio Codec (24 bit).
    "ALAC_32": 32,  # Apple Lossless Audio Codec (32 bit).
}


def _get_bit_depth(subtype):
    """Return the bits-per-sample recorded for a soundfile ``subtype``.

    Args:
        subtype: Subtype name as reported by soundfile (e.g. ``"PCM_16"``).

    Returns:
        int: The value from ``_SUBTYPE_TO_BITS_PER_SAMPLE``. A subtype that has no
        entry in that table yields ``0`` and emits a warning asking for a report.
    """
    bit_depth = _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype)
    if bit_depth is not None:
        return bit_depth

    # No table entry: report the gap, then fall back to the "unknown" value.
    warnings.warn(
        f"The {subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
        "attribute will be set to 0. If you are seeing this warning, please "
        "report by opening an issue on github (after checking for existing/closed ones). "
        "You may otherwise ignore this warning."
    )
    return 0


# Mapping from soundfile subtype to the encoding name returned by ``_get_encoding``
# (and therefore placed in the ``encoding`` field that ``info`` passes to AudioMetaData).
# Subtypes without an entry here are reported as "UNKNOWN" by ``_get_encoding``;
# FLAC files are handled there by container format, before this table is consulted.
_SUBTYPE_TO_ENCODING = {
    "PCM_S8": "PCM_S",  # signed integer PCM
    "PCM_16": "PCM_S",
    "PCM_24": "PCM_S",
    "PCM_32": "PCM_S",
    "PCM_U8": "PCM_U",  # unsigned integer PCM
    "FLOAT": "PCM_F",  # floating-point PCM
    "DOUBLE": "PCM_F",
    "ULAW": "ULAW",
    "ALAW": "ALAW",
    "VORBIS": "VORBIS",
}


def _get_encoding(format: str, subtype: str):
    """Return the encoding name for a soundfile ``format``/``subtype`` pair.

    Args:
        format: Container format name as reported by soundfile (e.g. ``"WAV"``).
        subtype: Subtype name as reported by soundfile (e.g. ``"PCM_16"``).

    Returns:
        str: ``"FLAC"`` whenever the container format is ``"FLAC"``; otherwise the
        ``_SUBTYPE_TO_ENCODING`` entry for ``subtype``, or ``"UNKNOWN"`` if there is none.
    """
    # The container format decides for FLAC; the subtype is not consulted at all.
    is_flac = format == "FLAC"
    return "FLAC" if is_flac else _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")


@_requires_soundfile
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
    """Get signal information of an audio file.

    Note:
        ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
        ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
        which has a restriction on type annotation due to TorchScript compiler compatibility.

    Args:
        filepath (path-like object or file-like object):
            Source of audio data.
        format (str or None, optional):
            Not used. PySoundFile does not accept format hint.

    Returns:
        AudioMetaData: meta data of the given audio, built from the sample rate, frame count
        and channel count reported by ``soundfile.info``, plus the bit depth and encoding
        derived from the reported format/subtype.
    """
    meta = soundfile.info(filepath)
    # Both derived fields come from the subtype (and, for the encoding, the container format).
    bit_depth = _get_bit_depth(meta.subtype)
    encoding_name = _get_encoding(meta.format, meta.subtype)
    return AudioMetaData(
        meta.samplerate,
        meta.frames,
        meta.channels,
        bits_per_sample=bit_depth,
        encoding=encoding_name,
    )


# Mapping from soundfile subtype to the dtype string that ``load`` passes to
# ``SoundFile.read`` when a WAV file is loaded with ``normalize=False``.
# A WAV subtype with no entry here (e.g. "PCM_24") makes ``load`` raise ValueError.
# NOTE(review): confirm that soundfile's ``read`` accepts every dtype listed here
# (the 8-bit entries in particular) — not verifiable from this file.
_SUBTYPE2DTYPE = {
    "PCM_S8": "int8",
    "PCM_U8": "uint8",
    "PCM_16": "int16",
    "PCM_32": "int32",
    "FLOAT": "float32",
    "DOUBLE": "float64",
}


@_requires_soundfile
def load(
    filepath: str,
    frame_offset: int = 0,
    num_frames: int = -1,
    normalize: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Load audio data from file.

    Note:
        The formats this function can handle depend on the soundfile installation.
        This function is tested on the following formats;

        * WAV

            * 32-bit floating-point
            * 32-bit signed integer
            * 16-bit signed integer
            * 8-bit unsigned integer

        * FLAC
        * OGG/VORBIS
        * SPHERE

    By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
    ``float32`` dtype, and the shape of `[channel, time]`.

    .. warning::

       ``normalize`` argument does not perform volume normalization.
       It only selects whether samples are read as ``float32`` or in the native
       sample type of the file.

       Native-type reading applies only when the input format is WAV and
       ``normalize=False``. The dtype is then looked up from the file's subtype in
       ``_SUBTYPE2DTYPE``; a WAV subtype with no entry in that table raises ``ValueError``.

       For every other format, the data is read as ``float32`` regardless of ``normalize``.

    Note:
        ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
        ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
        which has a restriction on type annotation due to TorchScript compiler compatibility.

    Args:
        filepath (path-like object or file-like object):
            Source of audio data.
        frame_offset (int, optional):
            Number of frames to skip before start reading data.
        num_frames (int, optional):
            Maximum number of frames to read. ``-1`` reads all the remaining samples,
            starting from ``frame_offset``.
            This function may return fewer frames if there are not enough
            frames in the given file.
        normalize (bool, optional):
            When ``True``, samples are read as ``float32``. Default: ``True``.
            Giving ``False`` for a WAV file reads samples in the dtype mapped from its subtype.
            This argument has no effect for formats other than WAV.
        channels_first (bool, optional):
            When True, the returned Tensor has dimension `[channel, time]`.
            Otherwise, the returned Tensor's dimension is `[time, channel]`.
        format (str or None, optional):
            Not used. PySoundFile does not accept format hint.

    Returns:
        (torch.Tensor, int): Resulting Tensor and sample rate.
            If ``channels_first=True``, it has `[channel, time]` else `[time, channel]`.

    Raises:
        ValueError: If a WAV file is loaded with ``normalize=False`` and its subtype
            has no entry in ``_SUBTYPE2DTYPE``.
    """
    with soundfile.SoundFile(filepath, "r") as audio:
        keep_native_type = audio.format == "WAV" and not normalize
        if keep_native_type:
            dtype = _SUBTYPE2DTYPE.get(audio.subtype)
            if dtype is None:
                raise ValueError(f"Unsupported subtype: {audio.subtype}")
        else:
            dtype = "float32"

        sample_rate = audio.samplerate
        # NOTE: ``_prepare_read`` is a private SoundFile method; it is given the offset
        # and requested count and returns the frame count that is passed to ``read``.
        frames_to_read = audio._prepare_read(frame_offset, None, num_frames)
        # ``always_2d=True`` keeps mono data two-dimensional so ``.t()`` below is valid.
        samples = audio.read(frames_to_read, dtype, always_2d=True)

    tensor = torch.from_numpy(samples)
    return (tensor.t() if channels_first else tensor), sample_rate


def _get_subtype_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int):
    """Resolve the soundfile subtype used when writing WAV.

    Args:
        dtype (torch.dtype): dtype of the Tensor being saved. Only consulted when
            neither ``encoding`` nor ``bits_per_sample`` is given.
        encoding (str or None): One of ``"PCM_S"``, ``"PCM_U"``, ``"PCM_F"``,
            ``"ULAW"``, ``"ALAW"``, or a falsy value to let ``bits_per_sample`` /
            ``dtype`` decide.
        bits_per_sample (int or None): Requested bit depth, or a falsy value for the default.

    Returns:
        str: A soundfile subtype name such as ``"PCM_16"`` or ``"FLOAT"``.

    Raises:
        ValueError: If the dtype, encoding, or encoding/bit-depth combination cannot be
            expressed as a WAV subtype.
    """
    if not encoding:
        if not bits_per_sample:
            # Nothing requested explicitly: follow the Tensor's dtype.
            subtype = {
                torch.uint8: "PCM_U8",
                torch.int16: "PCM_16",
                torch.int32: "PCM_32",
                torch.float32: "FLOAT",
                torch.float64: "DOUBLE",
            }.get(dtype)
            if not subtype:
                raise ValueError(f"Unsupported dtype for wav: {dtype}")
            return subtype
        if bits_per_sample == 8:
            # 8-bit PCM in WAV is unsigned.
            return "PCM_U8"
        if bits_per_sample not in (16, 24, 32):
            # e.g. 64 would otherwise yield "PCM_64", which is not an existing subtype name.
            raise ValueError(f"wav does not support {bits_per_sample}-bit PCM encoding.")
        return f"PCM_{bits_per_sample}"
    if encoding == "PCM_S":
        if not bits_per_sample:
            return "PCM_32"
        if bits_per_sample not in (16, 24, 32):
            # Covers 8 (unsigned only in WAV) as well as 64 (no such subtype).
            raise ValueError(f"wav does not support {bits_per_sample}-bit signed PCM encoding.")
        return f"PCM_{bits_per_sample}"
    if encoding == "PCM_U":
        if bits_per_sample in (None, 8):
            return "PCM_U8"
        raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
    if encoding == "PCM_F":
        if bits_per_sample in (None, 32):
            return "FLOAT"
        if bits_per_sample == 64:
            return "DOUBLE"
        raise ValueError("wav only supports 32/64-bit float PCM encoding.")
    if encoding == "ULAW":
        if bits_per_sample in (None, 8):
            return "ULAW"
        raise ValueError("wav only supports 8-bit mu-law encoding.")
    if encoding == "ALAW":
        if bits_per_sample in (None, 8):
            return "ALAW"
        raise ValueError("wav only supports 8-bit a-law encoding.")
    raise ValueError(f"wav does not support {encoding}.")


def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
    """Resolve the soundfile subtype used when writing SPHERE (``sph``).

    Args:
        encoding (str or None): ``"PCM_S"``, ``"ULAW"``, ``"ALAW"``, or ``None``
            (treated like ``"PCM_S"``).
        bits_per_sample (int or None): Requested bit depth, or a falsy value for the default.

    Returns:
        str: A soundfile subtype name such as ``"PCM_S8"``, ``"PCM_32"`` or ``"ULAW"``.

    Raises:
        ValueError: If the encoding, or the encoding/bit-depth combination, is not supported.
    """
    if encoding in (None, "PCM_S"):
        if not bits_per_sample:
            return "PCM_32"
        if bits_per_sample == 8:
            # The signed 8-bit subtype is spelled "PCM_S8"; "PCM_8" does not exist.
            return "PCM_S8"
        if bits_per_sample in (16, 24, 32):
            return f"PCM_{bits_per_sample}"
        # e.g. 64 would otherwise yield "PCM_64", which is not an existing subtype name.
        raise ValueError(f"sph does not support {bits_per_sample}-bit signed PCM encoding.")
    if encoding in ("PCM_U", "PCM_F"):
        raise ValueError(f"sph does not support {encoding} encoding.")
    if encoding == "ULAW":
        if bits_per_sample in (None, 8):
            return "ULAW"
        raise ValueError("sph only supports 8-bit for mu-law encoding.")
    if encoding == "ALAW":
        # Bit depth is intentionally not checked for a-law.
        return "ALAW"
    raise ValueError(f"sph does not support {encoding}.")


def _get_subtype(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int):
    """Pick the soundfile subtype for the requested output configuration.

    Args:
        dtype (torch.dtype): dtype of the Tensor being saved (forwarded to the WAV resolver).
        format (str): Lower-case format/extension name, e.g. ``"wav"`` or ``"flac"``.
        encoding (str or None): Requested encoding, if any.
        bits_per_sample (int or None): Requested bit depth, if any.

    Returns:
        str: The soundfile subtype name.

    Raises:
        ValueError: If the format is not supported, or the encoding/bit depth is not
            valid for that format.
    """
    # Formats that have their own resolver.
    if format == "wav":
        return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
    if format == "sph":
        return _get_subtype_for_sphere(encoding, bits_per_sample)

    # Formats whose subtype is fixed; encoding and bit depth are not inspected.
    if format == "mp3":
        return "MPEG_LAYER_III"
    if format in ("nis", "nist"):
        return "PCM_16"

    if format == "flac":
        if encoding:
            raise ValueError("flac does not support encoding.")
        if bits_per_sample:
            if bits_per_sample > 24:
                raise ValueError("flac does not support bits_per_sample > 24.")
            if bits_per_sample == 8:
                return "PCM_S8"
            return f"PCM_{bits_per_sample}"
        return "PCM_16"

    if format in ("ogg", "vorbis"):
        if bits_per_sample:
            raise ValueError("ogg/vorbis does not support bits_per_sample.")
        if encoding == "opus":
            return "OPUS"
        if encoding is None or encoding == "vorbis":
            return "VORBIS"
        raise ValueError(f"Unexpected encoding: {encoding}")

    raise ValueError(f"Unsupported format: {format}")


@_requires_soundfile
def save(
    filepath: str,
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    compression: Optional[float] = None,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
):
    """Save audio data to file.

    Note:
        The formats this function can handle depend on the soundfile installation.
        This function is tested on the following formats;

        * WAV

            * 32-bit floating-point
            * 32-bit signed integer
            * 16-bit signed integer
            * 8-bit unsigned integer

        * FLAC
        * OGG/VORBIS
        * SPHERE

    Note:
        ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
        ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
        which has a restriction on type annotation due to TorchScript compiler compatibility.

    Args:
        filepath (str or pathlib.Path): Path to audio file. A file-like object
            (anything with a ``write`` attribute) is also accepted.
        src (torch.Tensor): Audio data to save. Must be a 2D tensor.
        sample_rate (int): sampling rate
        channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
            otherwise `[time, channel]`.
        compression (float or None, optional): Not used.
            It is here only for interface compatibility with the "sox_io" backend.
            A warning is emitted if a value is given.
        format (str or None, optional): Override the audio format.
            When ``filepath`` argument is path-like object, audio format is
            inferred from file extension. If the file extension is missing or
            different, you can specify the correct format with this argument.

            When ``filepath`` argument is file-like object,
            this argument is required.

            Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
            ``"flac"`` and ``"sph"``.
        encoding (str or None, optional): Changes the encoding for supported formats.
            This argument is effective only for supported formats, such as
            ``"wav"``, ``"flac"`` and ``"sph"``. Valid values are;

                - ``"PCM_S"`` (signed integer Linear PCM)
                - ``"PCM_U"`` (unsigned integer Linear PCM)
                - ``"PCM_F"`` (floating point PCM)
                - ``"ULAW"`` (mu-law)
                - ``"ALAW"`` (a-law)

        bits_per_sample (int or None, optional): Changes the bit depth for the
            supported formats.
            When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
            you can change the bit depth.
            Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.

    Raises:
        ValueError: If ``src`` is not 2D, if ``bits_per_sample`` is not one of the
            accepted values, or if the format/encoding/bit depth combination is
            not supported.
        RuntimeError: If ``filepath`` is a file-like object and ``format`` is not given.

    Supported formats/encodings/bit depth/compression are:

    ``"wav"``
        - 32-bit floating-point PCM
        - 32-bit signed integer PCM
        - 24-bit signed integer PCM
        - 16-bit signed integer PCM
        - 8-bit unsigned integer PCM
        - 8-bit mu-law
        - 8-bit a-law

        Note:
            Default encoding/bit depth is determined by the dtype of
            the input Tensor.

    ``"flac"``
        - 8-bit
        - 16-bit (default)
        - 24-bit

    ``"ogg"``, ``"vorbis"``
        - Doesn't accept changing configuration.

    ``"sph"``
        - 8-bit signed integer PCM
        - 16-bit signed integer PCM
        - 24-bit signed integer PCM
        - 32-bit signed integer PCM (default)
        - 8-bit mu-law
        - 8-bit a-law
        - 16-bit a-law
        - 24-bit a-law
        - 32-bit a-law

    """
    if src.ndim != 2:
        raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
    if compression is not None:
        warnings.warn(
            '`save` function of "soundfile" backend does not support "compression" parameter. '
            "The argument is silently ignored."
        )

    if format is not None:
        # An explicit ``format`` overrides whatever the file name suggests, as documented.
        ext = format.lower()
    elif hasattr(filepath, "write"):
        # A file-like object has no name to infer the format from.
        raise RuntimeError("`format` is required when saving to file object.")
    else:
        ext = str(filepath).split(".")[-1].lower()

    if bits_per_sample not in (None, 8, 16, 24, 32, 64):
        raise ValueError("Invalid bits_per_sample.")
    if bits_per_sample == 24:
        warnings.warn(
            "Saving audio with 24 bits per sample might warp samples near -1. "
            "Using 16 bits per sample might be able to avoid this."
        )
    subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)

    # Translate the aliases accepted here into the container names soundfile knows.
    # sph is an extension used in TED-LIUM but soundfile does not recognize it as NIST format,
    # and "vorbis" names a subtype of the OGG container rather than a container itself.
    # This is applied even when ``format`` was given explicitly, so that e.g.
    # ``format="sph"`` is not forwarded verbatim.
    if ext in ("nis", "nist", "sph"):
        format = "NIST"
    elif ext == "vorbis":
        format = "OGG"

    if channels_first:
        # soundfile expects data laid out as [time, channel].
        src = src.t()

    soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)