Spaces:
Running
Running
import os | |
import re | |
import sys | |
from typing import BinaryIO, Optional, Tuple, Union | |
import torch | |
import torchaudio | |
from .backend import Backend | |
from .common import AudioMetaData | |
# Source/destination type accepted by the helpers in this file: a binary
# file-like object, a ``str``, or an ``os.PathLike``.
InputType = Union[BinaryIO, str, os.PathLike]
def info_audio(
    src: InputType,
    format: Optional[str],
    buffer_size: int = 4096,
) -> AudioMetaData:
    """Collect metadata for the default audio stream of ``src``.

    Args:
        src: Source handed to ``torchaudio.io.StreamReader``.
        format: Format override handed to the reader; may be ``None``.
        buffer_size: Buffer size handed to the reader.

    Returns:
        An ``AudioMetaData`` built from the stream's sample rate, frame
        count, channel count, bits per sample and upper-cased codec name.
    """
    reader = torchaudio.io.StreamReader(src, format, None, buffer_size)
    stream_info = reader.get_src_stream_info(reader.default_audio_stream)

    frame_count = stream_info.num_frames
    if frame_count == 0:
        # The stream reports zero frames, so decode it and use the size of
        # dimension 1 of the decoded waveform as the frame count instead.
        frame_count = _load_audio(reader).size(1)

    return AudioMetaData(
        int(stream_info.sample_rate),
        frame_count,
        stream_info.num_channels,
        stream_info.bits_per_sample,
        stream_info.codec.upper(),
    )
def _get_load_filter( | |
frame_offset: int = 0, | |
num_frames: int = -1, | |
convert: bool = True, | |
) -> Optional[str]: | |
if frame_offset < 0: | |
raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset)) | |
if num_frames == 0 or num_frames < -1: | |
raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames)) | |
# All default values -> no filter | |
if frame_offset == 0 and num_frames == -1 and not convert: | |
return None | |
# Only convert | |
aformat = "aformat=sample_fmts=fltp" | |
if frame_offset == 0 and num_frames == -1 and convert: | |
return aformat | |
# At least one of frame_offset or num_frames has non-default value | |
if num_frames > 0: | |
atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames) | |
else: | |
atrim = "atrim=start_sample={}".format(frame_offset) | |
if not convert: | |
return atrim | |
return "{},{}".format(atrim, aformat) | |
def _load_audio(
    s: "torchaudio.io.StreamReader",
    filter: Optional[str] = None,
    channels_first: bool = True,
) -> torch.Tensor:
    """Decode every packet from ``s`` and return the first popped chunk.

    Args:
        s: Stream reader to decode from; an audio output stream is added to it.
        filter: Optional filter description passed as ``filter_desc``.
        channels_first: When true, the decoded data is returned transposed.

    Returns:
        The decoded waveform tensor (transposed when ``channels_first``).

    Raises:
        RuntimeError: If the first popped chunk is ``None``.
    """
    # NOTE(review): the two -1 arguments presumably request a single,
    # unbounded chunk -- confirm against the StreamReader documentation.
    s.add_audio_stream(-1, -1, filter_desc=filter)
    s.process_all_packets()

    decoded = s.pop_chunks()[0]
    if decoded is None:
        raise RuntimeError("Failed to decode audio.")

    frames = decoded._elem
    if channels_first:
        return frames.T
    return frames
def load_audio(
    src: InputType,
    frame_offset: int = 0,
    num_frames: int = -1,
    convert: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
    buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
    """Decode audio from ``src`` and return it with its sample rate.

    Args:
        src: Source handed to ``torchaudio.io.StreamReader``.
        frame_offset: First sample to keep (see ``_get_load_filter``).
        num_frames: Number of samples to keep, ``-1`` for all.
        convert: Whether the ``fltp`` sample-format conversion is applied.
        channels_first: Whether the returned tensor is transposed by
            ``_load_audio``.
        format: Format override handed to the reader.
        buffer_size: Buffer size handed to the reader.

    Returns:
        A ``(waveform, sample_rate)`` tuple.
    """
    is_file_like = hasattr(src, "read")
    if is_file_like and format == "vorbis":
        # File-like sources tagged "vorbis" are opened as "ogg".
        # NOTE(review): presumably because FFmpeg has no "vorbis" format
        # name -- confirm.
        format = "ogg"

    reader = torchaudio.io.StreamReader(src, format, None, buffer_size)
    stream_info = reader.get_src_stream_info(reader.default_audio_stream)
    sample_rate = int(stream_info.sample_rate)

    # Argument validation happens here, i.e. after the reader was opened.
    filter_desc = _get_load_filter(frame_offset, num_frames, convert)
    return _load_audio(reader, filter_desc, channels_first), sample_rate
def _get_sample_format(dtype: torch.dtype) -> str: | |
dtype_to_format = { | |
torch.uint8: "u8", | |
torch.int16: "s16", | |
torch.int32: "s32", | |
torch.int64: "s64", | |
torch.float32: "flt", | |
torch.float64: "dbl", | |
} | |
format = dtype_to_format.get(dtype) | |
if format is None: | |
raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.") | |
return format | |
def _native_endianness() -> str: | |
if sys.byteorder == "little": | |
return "le" | |
else: | |
return "be" | |
def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str: | |
if bits_per_sample not in {None, 8, 16, 24, 32, 64}: | |
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.") | |
endianness = _native_endianness() | |
if not encoding: | |
if not bits_per_sample: | |
# default to PCM S16 | |
return f"pcm_s16{endianness}" | |
if bits_per_sample == 8: | |
return "pcm_u8" | |
return f"pcm_s{bits_per_sample}{endianness}" | |
if encoding == "PCM_S": | |
if not bits_per_sample: | |
bits_per_sample = 16 | |
if bits_per_sample == 8: | |
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.") | |
return f"pcm_s{bits_per_sample}{endianness}" | |
if encoding == "PCM_U": | |
if bits_per_sample in (None, 8): | |
return "pcm_u8" | |
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.") | |
if encoding == "PCM_F": | |
if not bits_per_sample: | |
bits_per_sample = 32 | |
if bits_per_sample in (32, 64): | |
return f"pcm_f{bits_per_sample}{endianness}" | |
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.") | |
if encoding == "ULAW": | |
if bits_per_sample in (None, 8): | |
return "pcm_mulaw" | |
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.") | |
if encoding == "ALAW": | |
if bits_per_sample in (None, 8): | |
return "pcm_alaw" | |
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.") | |
raise ValueError(f"WAV encoding {encoding} is not supported.") | |
def _get_flac_sample_fmt(bps): | |
if bps is None or bps == 16: | |
return "s16" | |
if bps == 24: | |
return "s32" | |
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).") | |
def _parse_save_args( | |
ext: Optional[str], | |
format: Optional[str], | |
encoding: Optional[str], | |
bps: Optional[int], | |
): | |
# torchaudio's save function accepts the followings, which do not 1to1 map | |
# to FFmpeg. | |
# | |
# - format: audio format | |
# - bits_per_sample: encoder sample format | |
# - encoding: such as PCM_U8. | |
# | |
# In FFmpeg, format is specified with the following three (and more) | |
# | |
# - muxer: could be audio format or container format. | |
# the one we passed to the constructor of StreamWriter | |
# - encoder: the audio encoder used to encode audio | |
# - encoder sample format: the format used by encoder to encode audio. | |
# | |
# If encoder sample format is different from source sample format, StreamWriter | |
# will insert a filter automatically. | |
# | |
def _type(spec): | |
# either format is exactly the specified one | |
# or extension matches to the spec AND there is no format override. | |
return format == spec or (format is None and ext == spec) | |
if _type("wav") or _type("amb"): | |
# wav is special because it supports different encoding through encoders | |
# each encoder only supports one encoder format | |
# | |
# amb format is a special case originated from libsox. | |
# It is basically a WAV format, with slight modification. | |
# https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795 | |
# It is a format so that decoders will recognize it as ambisonic. | |
# https://www.ambisonia.com/Members/mleese/file-format-for-b-format/ | |
# FFmpeg does not recognize amb because it is basically a WAV format. | |
muxer = "wav" | |
encoder = _get_encoder_for_wav(encoding, bps) | |
sample_fmt = None | |
elif _type("vorbis"): | |
# FFpmeg does not recognize vorbis extension, while libsox used to do. | |
# For the sake of bakward compatibility, (and the simplicity), | |
# we support the case where users want to do save("foo.vorbis") | |
muxer = "ogg" | |
encoder = "vorbis" | |
sample_fmt = None | |
else: | |
muxer = format | |
encoder = None | |
sample_fmt = None | |
if _type("flac"): | |
sample_fmt = _get_flac_sample_fmt(bps) | |
if _type("ogg"): | |
sample_fmt = _get_flac_sample_fmt(bps) | |
return muxer, encoder, sample_fmt | |
def save_audio(
    uri: InputType,
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
    buffer_size: int = 4096,
    compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
    """Encode ``src`` and write it to ``uri`` through a ``StreamWriter``.

    Args:
        uri: Destination path or an object with a ``write`` attribute.
        src: Audio tensor to write; transposed first when ``channels_first``.
        sample_rate: Sample rate passed to the writer.
        channels_first: Whether ``src`` must be transposed before writing.
        format: Explicit format; required when ``uri`` is a file object.
        encoding: Encoding passed to ``_parse_save_args``.
        bits_per_sample: Bit depth passed to ``_parse_save_args``.
        buffer_size: Buffer size passed to the writer.
        compression: Codec configuration passed as ``codec_config``.

    Raises:
        RuntimeError: If ``uri`` is a file object and ``format`` is ``None``.
    """
    ext = None
    if not hasattr(uri, "write"):
        uri = os.path.normpath(uri)
        # Everything after the last "." of the normalised path, lower-cased,
        # is treated as the extension.
        dotted_parts = str(uri).split(".")[1:]
        if dotted_parts:
            ext = dotted_parts[-1].lower()
    elif format is None:
        # A file object carries no extension to infer the format from.
        raise RuntimeError("'format' is required when saving to file object.")

    muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)

    samples = src.T if channels_first else src

    writer = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
    writer.add_audio_stream(
        sample_rate,
        # After the optional transpose the last dimension is the channel count.
        num_channels=samples.size(-1),
        format=_get_sample_format(samples.dtype),
        encoder=encoder,
        encoder_format=enc_fmt,
        codec_config=compression,
    )
    with writer.open():
        writer.write_audio_chunk(0, samples)
def _map_encoding(encoding: str) -> str: | |
for dst in ["PCM_S", "PCM_U", "PCM_F"]: | |
if dst in encoding: | |
return dst | |
if encoding == "PCM_MULAW": | |
return "ULAW" | |
elif encoding == "PCM_ALAW": | |
return "ALAW" | |
return encoding | |
def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> str: | |
if m := re.search(r"PCM_\w(\d+)\w*", encoding): | |
return int(m.group(1)) | |
elif encoding in ["PCM_ALAW", "PCM_MULAW"]: | |
return 8 | |
return bits_per_sample | |
class FFmpegBackend(Backend):
    """Backend that delegates to the FFmpeg-based helpers in this file.

    None of the methods take ``self``; they are declared as static methods so
    that they can be called on the class as well as on an instance.
    """

    @staticmethod
    def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
        """Return metadata for ``uri`` with normalised encoding fields.

        Args:
            uri: Source passed to ``info_audio``.
            format: Format override passed to ``info_audio``.
            buffer_size: Buffer size passed to ``info_audio``.

        Returns:
            The ``AudioMetaData`` from ``info_audio`` with ``bits_per_sample``
            and ``encoding`` rewritten by ``_get_bits_per_sample`` and
            ``_map_encoding``.
        """
        metadata = info_audio(uri, format, buffer_size)
        # bits_per_sample must be derived first: it parses the raw codec name
        # that the next line replaces.
        metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
        metadata.encoding = _map_encoding(metadata.encoding)
        return metadata

    @staticmethod
    def load(
        uri: InputType,
        frame_offset: int = 0,
        num_frames: int = -1,
        normalize: bool = True,
        channels_first: bool = True,
        format: Optional[str] = None,
        buffer_size: int = 4096,
    ) -> Tuple[torch.Tensor, int]:
        """Decode audio from ``uri``.

        Args:
            uri: Source passed to ``load_audio``.
            frame_offset: First sample to keep.
            num_frames: Number of samples to keep, ``-1`` for all.
            normalize: Forwarded as ``load_audio``'s ``convert`` argument.
            channels_first: Forwarded to ``load_audio``.
            format: Format override forwarded to ``load_audio``.
            buffer_size: Buffer size forwarded to ``load_audio``.

        Returns:
            A ``(waveform, sample_rate)`` tuple.
        """
        # buffer_size was previously accepted but never forwarded, so callers
        # always got load_audio's default.
        return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size)

    @staticmethod
    def save(
        uri: InputType,
        src: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        format: Optional[str] = None,
        encoding: Optional[str] = None,
        bits_per_sample: Optional[int] = None,
        buffer_size: int = 4096,
        compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
    ) -> None:
        """Encode ``src`` and write it to ``uri`` via ``save_audio``.

        Args:
            uri: Destination passed to ``save_audio``.
            src: Audio tensor to write.
            sample_rate: Sample rate of ``src``.
            channels_first: Forwarded to ``save_audio``.
            format: Forwarded to ``save_audio``.
            encoding: Forwarded to ``save_audio``.
            bits_per_sample: Forwarded to ``save_audio``.
            buffer_size: Forwarded to ``save_audio``.
            compression: ``None`` or a ``torchaudio.io.CodecConfig``. The
                annotation also lists ``float``/``int``, but this backend
                rejects them.

        Raises:
            ValueError: If ``compression`` is neither ``None`` nor a
                ``torchaudio.io.CodecConfig``.
        """
        if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
            # The two literals are concatenated into a single message; a comma
            # between them would turn the exception args into a tuple.
            raise ValueError(
                "FFmpeg backend expects non-`None` value for argument `compression` to be of "
                f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}"
            )
        save_audio(
            uri,
            src,
            sample_rate,
            channels_first,
            format,
            encoding,
            bits_per_sample,
            buffer_size,
            compression,
        )

    @staticmethod
    def can_decode(uri: InputType, format: Optional[str]) -> bool:
        """Report whether this backend can decode the input; always ``True``."""
        return True

    @staticmethod
    def can_encode(uri: InputType, format: Optional[str]) -> bool:
        """Report whether this backend can encode the output; always ``True``."""
        return True