import os
import sys

import torch
import numpy as np
import torch.nn.functional as F

from functools import cached_property
from torch.nn.utils.rnn import pad_sequence

sys.path.append(os.getcwd())

from main.library.speaker_diarization.speechbrain import EncoderClassifier


class BaseInference:
    pass


class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
    """Speaker embedding extractor backed by a pretrained SpeechBrain encoder."""

    def __init__(self, embedding="assets/models/speaker_diarization/models/speechbrain", device=None):
        super().__init__()
        self.embedding = embedding
        self.device = device or torch.device("cpu")
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})

    def to(self, device):
        if not isinstance(device, torch.device):
            raise TypeError(f"`device` must be a torch.device, got {type(device).__name__}")

        # SpeechBrain pins the device at load time, so reload the classifier on the new device.
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
        self.device = device
        return self

    @cached_property
    def sample_rate(self):
        return self.classifier_.audio_normalizer.sample_rate

    @cached_property
    def dimension(self):
        # Probe the encoder with one second of noise to discover the embedding size.
        *_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
        return dimension

    @cached_property
    def metric(self):
        return "cosine"

    @cached_property
    def min_num_samples(self):
        # Binary-search the shortest input (in samples) the encoder accepts
        # without raising, starting from the interval [2, 0.5 s].
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2

            while lower + 1 < upper:
                try:
                    _ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
                    upper = middle
                except RuntimeError:
                    lower = middle

                middle = (lower + upper) // 2

            return upper

    def __call__(self, waveforms, masks=None):
        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1
        waveforms = waveforms.squeeze(dim=1)

        if masks is None:
            signals = waveforms
            wav_lens = signals.shape[1] * torch.ones(batch_size)
        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            # Upsample frame-level masks to sample resolution, then keep only
            # the unmasked samples of each waveform (right-padded into a batch).
            imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
            signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()

        # If every chunk is shorter than the encoder can handle, return all NaNs.
        if max_len < self.min_num_samples:
            return np.full((batch_size, self.dimension), np.nan)

        too_short = wav_lens < self.min_num_samples
        wav_lens = wav_lens / max_len  # SpeechBrain expects relative lengths in (0, 1].
        wav_lens[too_short] = 1.0

        embeddings = self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy()
        embeddings[too_short.cpu().numpy()] = np.nan
        return embeddings
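

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): assumes the
# SpeechBrain checkpoint actually exists at the default `embedding` path and
# that input chunks are mono tensors shaped (batch, channel, sample). The
# mask shape (2, 100) is an arbitrary example of frame-level masks.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    inference = SpeechBrainPretrainedSpeakerEmbedding(device=torch.device("cpu"))

    # Two fake 3-second mono chunks at the model's native sample rate.
    chunks = torch.randn(2, 1, 3 * inference.sample_rate)

    # Frame-level masks in [0, 1]; `__call__` upsamples them to sample
    # resolution and embeds only the unmasked samples of each chunk.
    masks = torch.ones(2, 100)

    embeddings = inference(chunks, masks=masks)
    print(embeddings.shape)  # -> (2, inference.dimension); rows are NaN for
                             #    chunks shorter than `min_num_samples`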