import logging
import os
from functools import wraps
from types import SimpleNamespace

import torch
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from torch.nn import DataParallel as DP
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP

logger = logging.getLogger(__name__)

# Nesting depth of active ``MainProcessContext`` blocks. While it is >= 1,
# ``ddp_barrier`` returns immediately instead of synchronizing.
MAIN_PROC_ONLY = 0


def fetch(filename, source):
    """Return the absolute path of ``filename`` inside the directory ``source``.

    Only local path joining is done here; nothing is downloaded or copied.
    """
    return os.path.abspath(os.path.join(source, filename))


def run_on_main(
    func,
    args=None,
    kwargs=None,
    post_func=None,
    post_args=None,
    post_kwargs=None,
    run_post_on_main=False,
):
    """Run ``func`` on the main process only, then synchronize all processes.

    Args:
        func: callable executed only where ``if_main_process()`` is True.
        args: positional arguments for ``func`` (default: none).
        kwargs: keyword arguments for ``func`` (default: none).
        post_func: optional callable run after the first barrier.
        post_args: positional arguments for ``post_func``.
        post_kwargs: keyword arguments for ``post_func``.
        run_post_on_main: if True, ``post_func`` runs on every process with no
            further barrier; otherwise it runs only on non-main processes and
            is followed by another ``ddp_barrier()``.
    """
    args = [] if args is None else args
    kwargs = {} if kwargs is None else kwargs
    post_args = [] if post_args is None else post_args
    post_kwargs = {} if post_kwargs is None else post_kwargs

    main_process_only(func)(*args, **kwargs)
    ddp_barrier()

    if post_func is None:
        return
    if run_post_on_main:
        post_func(*post_args, **post_kwargs)
    else:
        # Mirror image of the main-only step: every other process runs it,
        # then all processes wait for each other.
        if not if_main_process():
            post_func(*post_args, **post_kwargs)
        ddp_barrier()


def is_distributed_initialized():
    """Return True when ``torch.distributed`` is available and initialized."""
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def if_main_process():
    """Return True on rank 0, or always when distributed is not initialized."""
    if is_distributed_initialized():
        return torch.distributed.get_rank() == 0
    return True


class MainProcessContext:
    """Context manager that marks a region as main-process-only.

    Entering increments the module-level ``MAIN_PROC_ONLY`` counter and
    exiting decrements it; ``ddp_barrier`` is a no-op while the counter is
    positive, since code inside the region is not run by every process and a
    barrier there could never be reached by all ranks.
    """

    def __enter__(self):
        global MAIN_PROC_ONLY
        MAIN_PROC_ONLY += 1
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        global MAIN_PROC_ONLY
        MAIN_PROC_ONLY -= 1


def main_process_only(function):
    """Decorate ``function`` so it only executes on the main process.

    The context is entered on every process; non-main processes skip the
    call and get ``None`` back.
    """

    @wraps(function)
    def main_proc_wrapped_func(*args, **kwargs):
        with MainProcessContext():
            return function(*args, **kwargs) if if_main_process() else None

    return main_proc_wrapped_func


def ddp_barrier():
    """Synchronize all processes, unless inside a main-only region or not distributed."""
    if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized():
        return
    if torch.distributed.get_backend() == torch.distributed.Backend.NCCL:
        # NCCL barriers are issued on the current CUDA device.
        torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
    else:
        torch.distributed.barrier()


class Resample(torch.nn.Module):
    """Resample batched waveforms from ``orig_freq`` to ``new_freq``.

    Wraps ``torchaudio.transforms.Resample`` (which resamples along the last
    dimension) so that it accepts ``(batch, time)`` or
    ``(batch, time, channels)`` tensors and returns the same layout.

    Args:
        orig_freq: sampling rate of the input.
        new_freq: target sampling rate.
        *args, **kwargs: forwarded to ``torchaudio.transforms.Resample`` after
            ``orig_freq`` and ``new_freq``.
    """

    def __init__(self, orig_freq=16000, new_freq=16000, *args, **kwargs):
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        # Pass the frequencies positionally so extra positional ``*args`` line
        # up after them; mixing them as keywords with ``*args`` would make the
        # first extra positional argument collide with ``orig_freq``.
        self.resampler = torchaudio.transforms.Resample(
            orig_freq, new_freq, *args, **kwargs
        )

    def forward(self, waveforms):
        """Resample ``waveforms`` along the time axis.

        Returns ``waveforms`` unchanged when ``orig_freq == new_freq``.

        Raises:
            ValueError: if resampling is needed and ``waveforms`` is neither
                2-D nor 3-D.
        """
        if self.orig_freq == self.new_freq:
            return waveforms

        if waveforms.dim() == 2:
            # (batch, time) -> (batch, 1, time) so time is the last dim.
            reshaped = waveforms.unsqueeze(1)
            restore = lambda out: out.squeeze(1)  # noqa: E731
        elif waveforms.dim() == 3:
            # (batch, time, channels) -> (batch, channels, time) and back.
            reshaped = waveforms.transpose(1, 2)
            restore = lambda out: out.transpose(1, 2)  # noqa: E731
        else:
            raise ValueError(
                "Resample expects a 2-D or 3-D tensor, "
                f"got {waveforms.dim()} dimension(s)"
            )

        # Keep the resampling kernel on the same device as the input.
        self.resampler.to(waveforms.device)
        return restore(self.resampler(reshaped))


class AudioNormalizer:
    """Bring single audio signals to a common sampling rate and channel layout.

    Args:
        sample_rate: target sampling rate for all audio.
        mix: ``"avg-to-mono"`` averages the channel axis into one channel;
            ``"keep"`` leaves channels untouched.

    Raises:
        ValueError: if ``mix`` is not a supported strategy.
    """

    _MIX_STRATEGIES = ("avg-to-mono", "keep")

    def __init__(self, sample_rate=16000, mix="avg-to-mono"):
        self.sample_rate = sample_rate
        if mix not in self._MIX_STRATEGIES:
            raise ValueError(
                f"Unexpected mixing configuration {mix!r}; "
                f"expected one of {self._MIX_STRATEGIES}"
            )
        self.mix = mix
        # One Resample module per input sampling rate seen so far.
        self._cached_resamplers = {}

    def __call__(self, audio, sample_rate):
        """Resample ``audio`` from ``sample_rate`` to ``self.sample_rate`` and mix it.

        ``audio`` is one un-batched signal, ``(time,)`` or ``(time, channels)``;
        a batch axis is added for ``Resample`` and removed afterwards.
        """
        if sample_rate not in self._cached_resamplers:
            self._cached_resamplers[sample_rate] = Resample(sample_rate, self.sample_rate)
        resampler = self._cached_resamplers[sample_rate]
        resampled = resampler(audio.unsqueeze(0)).squeeze(0)
        return self._mix(resampled)

    def _mix(self, audio):
        """Apply the configured channel-mixing strategy to ``audio``."""
        if self.mix == "avg-to-mono":
            # A 1-D signal is already mono; otherwise average over dim 1 (channels).
            return audio if audio.dim() == 1 else torch.mean(audio, 1)
        if self.mix == "keep":
            return audio
        # Only reachable if ``self.mix`` was changed after construction.
        raise ValueError(f"Unexpected mixing configuration {self.mix!r}")


class Pretrained(torch.nn.Module):
    """Base class for inference wrappers around pretrained modules.

    Args:
        modules: mapping of name -> ``torch.nn.Module``; stored in
            ``self.mods`` and moved to ``self.device``.
        hparams: mapping of hyperparameters; exposed as ``self.hparams``
            (a ``SimpleNamespace``) when given. May also supply run options
            and an ``"audio_normalizer"``.
        run_opts: mapping of run options; takes precedence over ``hparams``
            for the keys in ``_RUN_OPT_DEFAULTS``.
        freeze_params: if True, modules are put in eval mode and every
            parameter gets ``requires_grad = False``.

    Raises:
        ValueError: if ``HPARAMS_NEEDED`` is non-empty and ``hparams`` is
            missing or lacks one of the required keys.
    """

    # NOTE(review): MODULES_NEEDED is declared but not validated in __init__.
    HPARAMS_NEEDED, MODULES_NEEDED = [], []

    # Recognized run options and their defaults.
    # Lookup order: run_opts, then hparams, then this default.
    _RUN_OPT_DEFAULTS = {
        "device": "cpu",
        "data_parallel_count": -1,
        "data_parallel_backend": False,
        "distributed_launch": False,
        "distributed_backend": "nccl",
        "jit": False,
        "jit_module_keys": None,
        "compile": False,
        "compile_module_keys": None,
        "compile_mode": "reduce-overhead",
        "compile_using_fullgraph": False,
        "compile_using_dynamic_shape_tracing": False,
    }

    def __init__(self, modules=None, hparams=None, run_opts=None, freeze_params=True):
        super().__init__()
        for arg, default in self._RUN_OPT_DEFAULTS.items():
            if run_opts is not None and arg in run_opts:
                setattr(self, arg, run_opts[arg])
            elif hparams is not None and arg in hparams:
                setattr(self, arg, hparams[arg])
            else:
                setattr(self, arg, default)

        self.mods = torch.nn.ModuleDict(modules)
        for module in self.mods.values():
            if module is not None:
                module.to(self.device)

        if self.HPARAMS_NEEDED and hparams is None:
            raise ValueError("Need to provide hparams dict.")
        if hparams is not None:
            for hp in self.HPARAMS_NEEDED:
                if hp not in hparams:
                    raise ValueError(f"Need hparams['{hp}']")
            self.hparams = SimpleNamespace(**hparams)

        self._prepare_modules(freeze_params)

        # ``hparams`` may legitimately be None here; fall back to a default
        # normalizer instead of calling ``.get`` on None.
        if hparams is not None and "audio_normalizer" in hparams:
            self.audio_normalizer = hparams["audio_normalizer"]
        else:
            self.audio_normalizer = AudioNormalizer()

    def _prepare_modules(self, freeze_params):
        """Compile / wrap modules, then optionally freeze them for inference."""
        self._compile()
        self._wrap_distributed()
        if freeze_params:
            self.mods.eval()
            for p in self.mods.parameters():
                p.requires_grad = False

    def _compile(self):
        """Apply ``torch.compile`` and/or ``torch.jit.script`` to selected modules.

        Modules selected for both are compiled first; a successfully compiled
        module is not also scripted. If ``torch.compile`` raises for a module,
        a warning is logged and the module is left uncompiled (it may still be
        scripted if selected for JIT).

        Raises:
            ValueError: if ``compile_module_keys`` is set but this torch has no
                ``torch.compile``, or if a selected key is not in ``self.mods``.
        """
        if not hasattr(torch, "compile") and self.compile_module_keys is not None:
            raise ValueError(
                "'compile_module_keys' specified, but this install of torch "
                "does not support torch.compile"
            )

        compile_module_keys = set()
        if self.compile:
            keys = self.mods if self.compile_module_keys is None else self.compile_module_keys
            compile_module_keys = set(keys)

        jit_module_keys = set()
        if self.jit:
            keys = self.mods if self.jit_module_keys is None else self.jit_module_keys
            jit_module_keys = set(keys)

        for name in compile_module_keys | jit_module_keys:
            if name not in self.mods:
                raise ValueError(f"module '{name}' is not defined in your hparams file.")

        for name in compile_module_keys:
            try:
                module = torch.compile(
                    self.mods[name],
                    mode=self.compile_mode,
                    fullgraph=self.compile_using_fullgraph,
                    dynamic=self.compile_using_dynamic_shape_tracing,
                )
            except Exception as err:
                # Best effort: keep the eager module rather than failing load.
                logger.warning(
                    "torch.compile failed for module %r (%s); leaving it uncompiled",
                    name,
                    err,
                )
                continue
            self.mods[name] = module.to(self.device)
            jit_module_keys.discard(name)

        for name in jit_module_keys:
            module = torch.jit.script(self.mods[name])
            self.mods[name] = module.to(self.device)

    def _compile_jit(self):
        """Alias kept for compatibility; delegates to ``_compile``."""
        self._compile()

    def _wrap_distributed(self):
        """Wrap trainable modules in DDP or DataParallel when requested.

        Only modules with at least one parameter that has ``requires_grad``
        set are wrapped. With ``distributed_launch`` the module's batch-norm
        layers are converted to ``SyncBatchNorm`` before DDP wrapping.
        """
        if not self.distributed_launch and not self.data_parallel_backend:
            return
        if self.distributed_launch:
            for name, module in self.mods.items():
                if any(p.requires_grad for p in module.parameters()):
                    module = SyncBatchNorm.convert_sync_batchnorm(module)
                    self.mods[name] = DDP(module, device_ids=[self.device])
        else:
            for name, module in self.mods.items():
                if any(p.requires_grad for p in module.parameters()):
                    if self.data_parallel_count == -1:
                        self.mods[name] = DP(module)
                    else:
                        self.mods[name] = DP(module, list(range(self.data_parallel_count)))

    @classmethod
    def from_hparams(
        cls,
        source,
        hparams_file="hyperparams.yaml",
        overrides=None,
        download_only=False,
        overrides_must_match=True,
        **kwargs,
    ):
        """Build an instance from a hyperparameter YAML file in ``source``.

        Args:
            source: directory containing ``hparams_file``.
            hparams_file: YAML file name, resolved with ``fetch``.
            overrides: mapping of overrides for ``load_hyperpyyaml``
                (``None`` means no overrides).
            download_only: if True and the YAML defines a ``pretrainer``,
                only collect files and return ``None``.
            overrides_must_match: forwarded to ``load_hyperpyyaml``.
            **kwargs: forwarded to the class constructor (e.g. ``run_opts``).

        Returns:
            An instance built from ``hparams["modules"]`` and ``hparams``, or
            ``None`` when ``download_only`` is True and a pretrainer exists.
        """
        # Avoid a shared mutable default argument.
        if overrides is None:
            overrides = {}

        hparams_path = fetch(filename=hparams_file, source=source)
        with open(hparams_path, encoding="utf-8") as fin:
            hparams = load_hyperpyyaml(
                fin, overrides, overrides_must_match=overrides_must_match
            )

        pretrainer = hparams.get("pretrainer", None)
        if pretrainer is not None:
            run_on_main(pretrainer.collect_files, kwargs={"default_source": source})
            if download_only:
                return None
            pretrainer.load_collected()
        return cls(hparams["modules"], hparams, **kwargs)


class EncoderClassifier(Pretrained):
    """Pretrained wrapper that embeds waveforms and classifies the embeddings.

    Uses the modules ``compute_features``, ``mean_var_norm``,
    ``embedding_model`` and ``classifier`` from ``self.mods``.
    """

    MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model", "classifier"]

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for a batch of waveforms.

        Args:
            wavs: waveform tensor; a 1-D tensor is treated as a batch of one.
                Presumably ``(batch, time)`` otherwise — TODO confirm.
            wav_lens: one length value per batch item; defaults to all ones
                (presumably relative lengths, 1.0 = full length — confirm).
            normalize: if True, apply ``self.hparams.mean_var_norm_emb`` to
                the embeddings with all-ones lengths.

        Returns:
            The output of ``embedding_model`` (optionally normalized).
        """
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)

        if normalize:
            ones = torch.ones(embeddings.shape[0], device=self.device)
            embeddings = self.hparams.mean_var_norm_emb(embeddings, ones)
        return embeddings

    def classify_batch(self, wavs, wav_lens=None):
        """Classify a batch of waveforms.

        Returns:
            ``(out_prob, score, index, text_lab)``: the classifier output with
            dim 1 squeezed, the max value and argmax over the last dim, and
            the labels decoded by ``self.hparams.label_encoder.decode_torch``.
        """
        embeddings = self.encode_batch(wavs, wav_lens)
        out_prob = self.mods.classifier(embeddings).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None):
        """Alias for ``classify_batch``."""
        return self.classify_batch(wavs, wav_lens)