Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Audio IO methods are defined in this module (info, read, write).
We rely on the av library for faster reads when possible, otherwise on torchaudio.
""" | |
from dataclasses import dataclass | |
from pathlib import Path | |
import logging | |
import typing as tp | |
import numpy as np | |
import soundfile | |
import torch | |
from torch.nn import functional as F | |
import torchaudio as ta | |
import av | |
from .audio_utils import f32_pcm, i16_pcm, normalize_audio | |
_av_initialized = False


def _init_av():
    """Raise the 'libav.mp3' logger level to ERROR, once per process.

    The module-level `_av_initialized` flag records that the setup already
    happened, so every call after the first one is a no-op.
    """
    global _av_initialized
    if not _av_initialized:
        logging.getLogger('libav.mp3').setLevel(logging.ERROR)
        _av_initialized = True
@dataclass(frozen=True)
class AudioFileInfo:
    """Basic metadata describing an audio file.

    Declared as a dataclass so that it can be built positionally as
    `AudioFileInfo(sample_rate, duration, channels)`, which is how the info
    helpers in this module construct it. Instances are immutable.
    """
    # Sample rate as reported by the backend that probed the file.
    sample_rate: int
    # Duration of the audio as reported by the backend (float).
    duration: float
    # Number of audio channels.
    channels: int
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Probe an audio file with PyAV and describe its first audio stream.

    Args:
        filepath (str or Path): Path to the audio file to inspect.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count of the first audio stream.
    """
    _init_av()
    with av.open(str(filepath)) as container:
        audio_stream = container.streams.audio[0]
        return AudioFileInfo(
            audio_stream.codec_context.sample_rate,
            # The stream duration is expressed in `time_base` units.
            float(audio_stream.duration * audio_stream.time_base),
            audio_stream.channels,
        )
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Describe an audio file from the metadata returned by `soundfile.info`.

    Args:
        filepath (str or Path): Path to the audio file to inspect.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count reported by soundfile.
    """
    meta = soundfile.info(filepath)
    sample_rate, duration, channels = meta.samplerate, meta.duration, meta.channels
    return AudioFileInfo(sample_rate, duration, channels)
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Return basic metadata for an audio file, choosing the backend from its extension.

    Files whose suffix is '.flac' or '.ogg' are inspected with soundfile; every
    other suffix goes through PyAV.

    Args:
        filepath (str or Path): Path to the audio file to inspect.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count of the file.
    """
    # torchaudio no longer returns useful duration information for some formats like mp3s.
    path = Path(filepath)
    # ffmpeg has some weird issue with flac.
    # TODO: Validate .ogg can be safely read with av_info
    use_soundfile = path.suffix in ('.flac', '.ogg')
    probe = _soundfile_info if use_soundfile else _av_info
    return probe(path)
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """Read audio from a file through FFmpeg, using the PyAV bindings.

    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
    """
    _init_av()
    with av.open(str(filepath)) as container:
        audio_stream = container.streams.audio[0]
        sr = audio_stream.codec_context.sample_rate
        # Number of samples requested; -1 means "no limit".
        if duration >= 0:
            wanted = int(sr * duration)
        else:
            wanted = -1
        first_sample = int(sr * seek_time)
        # we need a small negative offset otherwise we get some edge artifact
        # from the mp3 decoder.
        seek_from = max(0, (seek_time - 0.1))
        container.seek(int(seek_from / audio_stream.time_base), stream=audio_stream)
        chunks = []
        collected = 0
        for frame in container.decode(streams=audio_stream.index):
            # Sample index at which this decoded frame starts.
            frame_start = int(frame.rate * frame.pts * frame.time_base)
            # Samples located before the requested start are dropped.
            skip = max(0, first_sample - frame_start)
            samples = torch.from_numpy(frame.to_ndarray())
            if samples.shape[0] != audio_stream.channels:
                # First dim is not the channel dim: reshape so that it becomes one.
                samples = samples.view(-1, audio_stream.channels).t()
            samples = samples[:, skip:]
            chunks.append(samples)
            collected += samples.shape[1]
            if wanted > 0 and collected >= wanted:
                break
        assert chunks
        # If the above assert fails, it is likely because we seeked past the end of file point,
        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
        # This will need proper debugging, in due time.
        wav = torch.cat(chunks, dim=1)
        assert wav.shape[0] == audio_stream.channels
        if wanted > 0:
            wav = wav[:, :wanted]
        return f32_pcm(wav), sr
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read an audio file with the backend best suited to its format.

    '.flac' and '.ogg' files are read with soundfile. Whole-file reads of '.wav'
    and '.mp3' go through torchaudio when it lists the format as readable.
    Everything else is decoded with `_av_read`.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    extension = fp.suffix
    if extension in ('.flac', '.ogg'):  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = _soundfile_info(filepath)
        if duration <= 0:
            frames = -1
        else:
            frames = int(duration * info.sample_rate)
        frame_offset = int(seek_time * info.sample_rate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        wav = torch.from_numpy(wav).t().contiguous()
        if wav.dim() == 1:
            # Single-dimension result: add a leading dimension.
            wav = wav.unsqueeze(0)
    elif (
        extension in ('.wav', '.mp3')
        and extension[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0
        and seek_time == 0
    ):
        # Torchaudio is faster if we load an entire file at once.
        wav, sr = ta.load(fp)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        missing = int(duration * sr) - wav.shape[-1]
        wav = F.pad(wav, (0, missing))
    return wav, sr
def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Save audio to disk and return the path it was written to.

    Args:
        stem_name (str or Path): Filename without extension which will be added automatically.
        wav (torch.Tensor): Floating point audio with one or two dimensions. A
            one-dimensional input gets a leading dimension added.
        sample_rate (int): Sample rate, forwarded to the normalization and to the writer.
        format (str): Either "wav" or "mp3".
        mp3_rate (int): kbps when using mp3s.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): If False, no extension is appended to `stem_name`.
    Returns:
        Path: Path of the saved audio.
    Raises:
        ValueError: If `wav` has more than two dimensions.
        RuntimeError: If `format` is neither "wav" nor "mp3".
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    ndim = wav.dim()
    if ndim > 2:
        raise ValueError("Input wav should be at most 2 dimension.")
    if ndim == 1:
        wav = wav.unsqueeze(0)
    assert wav.isfinite().all()
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
                          sample_rate=sample_rate, stem_name=str(stem_name))
    save_options: dict
    if format == 'mp3':
        suffix = '.mp3'
        save_options = {"compression": mp3_rate}
    elif format == 'wav':
        # wav files are stored as 16 bits signed PCM.
        wav = i16_pcm(wav)
        suffix = '.wav'
        save_options = {"encoding": "PCM_S", "bits_per_sample": 16}
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
    path = Path(str(stem_name) + (suffix if add_suffix else ''))
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        ta.save(path, wav, sample_rate, **save_options)
    except Exception:
        # we do not want to leave half written files around.
        if path.exists():
            path.unlink()
        raise
    return path