import math
import os
import random
from io import IOBase
from typing import Optional

import torchaudio
from torch.nn.functional import pad


def get_torchaudio_info(file, backend: Optional[str] = None):
    """Return torchaudio metadata for ``file["audio"]`` (path or file-like object)."""
    if not backend:
        backends = torchaudio.list_audio_backends()
        backend = "soundfile" if "soundfile" in backends else backends[0]

    info = torchaudio.info(file["audio"], backend=backend)

    # rewind file-like objects so they can be read again later
    if isinstance(file["audio"], IOBase):
        file["audio"].seek(0)

    return info


class Audio:
    @staticmethod
    def power_normalize(waveform):
        """Normalize waveform to unit average power (per channel)."""
        return waveform / (waveform.square().mean(dim=-1, keepdim=True).sqrt() + 1e-8)

    @staticmethod
    def validate_file(file):
        """Normalize `file` into a dict carrying either an "audio" or a "waveform" key, plus a "uri"."""
        if isinstance(file, (str, os.PathLike)):
            file = {"audio": str(file), "uri": os.path.splitext(os.path.basename(file))[0]}
        elif isinstance(file, IOBase):
            return {"audio": file, "uri": "stream"}
        elif not isinstance(file, dict):
            raise ValueError("file must be a path, a file-like object, or a dict")

        if "waveform" in file:
            waveform = file["waveform"]
            # expect a (channel, time) tensor
            if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
                raise ValueError("'waveform' must be a (channel, time) tensor")
            sample_rate = file.get("sample_rate", None)
            if sample_rate is None:
                raise ValueError("a 'sample_rate' key is mandatory with 'waveform'")
            file.setdefault("uri", "waveform")

        elif "audio" in file:
            if isinstance(file["audio"], IOBase):
                return file
            path = os.path.abspath(file["audio"])
            file.setdefault("uri", os.path.splitext(os.path.basename(path))[0])

        else:
            raise ValueError("file must provide either 'audio' or 'waveform'")

        return file

    def __init__(self, sample_rate: Optional[int] = None, mono: Optional[str] = None, backend: Optional[str] = None):
        super().__init__()
        self.sample_rate = sample_rate
        self.mono = mono

        if not backend:
            backends = torchaudio.list_audio_backends()
            backend = "soundfile" if "soundfile" in backends else backends[0]
        self.backend = backend

    def downmix_and_resample(self, waveform, sample_rate):
        """Apply the channel (mono) policy, then resample to `self.sample_rate` if needed."""
        num_channels = waveform.shape[0]
        if num_channels > 1:
            if self.mono == "random":
                channel = random.randint(0, num_channels - 1)
                waveform = waveform[channel : channel + 1]
            elif self.mono == "downmix":
                waveform = waveform.mean(dim=0, keepdim=True)

        if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
            sample_rate = self.sample_rate

        return waveform, sample_rate

    def get_duration(self, file) -> float:
        """Return the duration of `file` in seconds."""
        file = self.validate_file(file)

        if "waveform" in file:
            frames = len(file["waveform"].T)
            sample_rate = file["sample_rate"]
        else:
            info = file["torchaudio.info"] if "torchaudio.info" in file else get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate

        return frames / sample_rate

    def get_num_samples(self, duration: float, sample_rate: Optional[int] = None) -> int:
        """Convert a duration in seconds into a number of samples."""
        sample_rate = sample_rate or self.sample_rate
        if sample_rate is None:
            raise ValueError("sample_rate must be provided")
        return math.floor(duration * sample_rate)

    def __call__(self, file):
        """Load the whole file and return a (waveform, sample_rate) tuple."""
        file = self.validate_file(file)

        if "waveform" in file:
            waveform = file["waveform"]
            sample_rate = file["sample_rate"]
        elif "audio" in file:
            waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
            # rewind file-like objects so they can be read again later
            if isinstance(file["audio"], IOBase):
                file["audio"].seek(0)

        channel = file.get("channel", None)
        if channel is not None:
            waveform = waveform[channel : channel + 1]

        return self.downmix_and_resample(waveform, sample_rate)

    def crop(self, file, segment, duration: Optional[float] = None, mode: str = "raise"):
        """Load the `segment` excerpt of `file` and return a (waveform, sample_rate) tuple."""
        file = self.validate_file(file)

        if "waveform" in file:
            waveform = file["waveform"]
            frames = waveform.shape[1]
            sample_rate = file["sample_rate"]
        elif "torchaudio.info" in file:
            info = file["torchaudio.info"]
            frames = info.num_frames
            sample_rate = info.sample_rate
        else:
            info = get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate

        channel = file.get("channel", None)

        start_frame = math.floor(segment.start * sample_rate)
        if duration:
            num_frames = math.floor(duration * sample_rate)
            end_frame = start_frame + num_frames
        else:
            end_frame = math.floor(segment.end * sample_rate)
            num_frames = end_frame - start_frame

        if mode == "raise":
            if num_frames > frames:
                raise ValueError("requested duration is longer than the file")
            if end_frame > frames + math.ceil(0.001 * sample_rate):
                raise ValueError("requested segment ends after the end of the file")
            else:
                # tolerate a ~1 ms rounding overshoot: clamp to the end of the file
                end_frame = min(end_frame, frames)
                start_frame = end_frame - num_frames
            if start_frame < 0:
                raise ValueError("requested segment starts before the beginning of the file")

        elif mode == "pad":
            pad_start = -min(0, start_frame)
            pad_end = max(end_frame, frames) - frames
            start_frame = max(0, start_frame)
            end_frame = min(end_frame, frames)
            num_frames = end_frame - start_frame

        if "waveform" in file:
            data = file["waveform"][:, start_frame:end_frame]
        else:
            try:
                data, _ = torchaudio.load(file["audio"], frame_offset=start_frame, num_frames=num_frames, backend=self.backend)
                # rewind file-like objects so they can be read again later
                if isinstance(file["audio"], IOBase):
                    file["audio"].seek(0)
            except RuntimeError:
                # seek-and-read is not supported for file-like objects: re-raise
                if isinstance(file["audio"], IOBase):
                    raise
                # seek-and-read failed on a regular file: fall back to loading it whole
                waveform, sample_rate = self.__call__(file)
                data = waveform[:, start_frame:end_frame]
                # cache the decoded waveform for subsequent crops
                file["waveform"] = waveform
                file["sample_rate"] = sample_rate

        if channel is not None:
            data = data[channel : channel + 1, :]

        if mode == "pad":
            data = pad(data, (pad_start, pad_end))

        return self.downmix_and_resample(data, sample_rate)
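# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the Audio helper above might be used, assuming a
# hypothetical "example.wav" file on disk. crop() only needs an object with
# .start/.end attributes in seconds, so a namedtuple stands in for a real
# Segment class here.
if __name__ == "__main__":
    from collections import namedtuple

    Segment = namedtuple("Segment", ["start", "end"])  # hypothetical stand-in

    audio = Audio(sample_rate=16000, mono="downmix")

    # full file, downmixed to mono and resampled to 16 kHz
    waveform, sample_rate = audio("example.wav")

    # 0.5 s to 2.0 s excerpt, zero-padded if the file is shorter
    chunk, _ = audio.crop("example.wav", Segment(0.5, 2.0), mode="pad")

    print(waveform.shape, chunk.shape, sample_rate)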