# Provenance (HuggingFace page residue, preserved as a comment so the file parses):
#   uploader: NeoPy — "Upload 115 files", commit 96134ee (verified), file size ~3 kB
import os
import sys
import torch
import numpy as np
import torch.nn.functional as F
from functools import cached_property
from torch.nn.utils.rnn import pad_sequence
sys.path.append(os.getcwd())
from main.library.speaker_diarization.speechbrain import EncoderClassifier
class BaseInference:
    """Marker base class for inference wrappers; defines no behavior of its own."""
class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
    """Speaker-embedding extractor backed by a pretrained SpeechBrain encoder.

    Wraps ``EncoderClassifier`` so batches of (optionally masked) mono
    waveforms can be turned into fixed-size speaker embeddings; ``metric``
    advertises cosine distance as the appropriate comparison for them
    (typical for diarization clustering).
    """

    def __init__(self, embedding="assets/models/speaker_diarization/models/speechbrain", device=None):
        """Load the pretrained encoder.

        Args:
            embedding: local path (or SpeechBrain source spec) of the model.
            device: ``torch.device`` to run on; defaults to CPU.
        """
        super().__init__()
        self.embedding = embedding
        self.device = device or torch.device("cpu")
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})

    def to(self, device):
        """Move inference to *device* by reloading the classifier; returns self.

        Raises:
            TypeError: if *device* is not a ``torch.device``.
        """
        if not isinstance(device, torch.device):
            # Fix: original raised a bare TypeError with no message.
            raise TypeError(f"device must be a torch.device, got {type(device).__name__}")
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
        self.device = device
        return self

    @cached_property
    def sample_rate(self):
        """Sample rate (Hz) expected by the underlying model."""
        return self.classifier_.audio_normalizer.sample_rate

    @cached_property
    def dimension(self):
        """Embedding dimensionality, probed by encoding a dummy 1 s batch."""
        *_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
        return dimension

    @cached_property
    def metric(self):
        """Distance metric appropriate for comparing these embeddings."""
        return "cosine"

    @cached_property
    def min_num_samples(self):
        """Smallest input length (in samples) the encoder accepts.

        Binary search over [2, 0.5 s]: inputs that are too short make the
        encoder raise ``RuntimeError``, which is used as the probe signal
        (the broad RuntimeError catch is deliberate for that reason).
        """
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2
            while lower + 1 < upper:
                try:
                    _ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
                    upper = middle  # encoding worked: answer is <= middle
                except RuntimeError:
                    lower = middle  # too short: answer is > middle
                middle = (lower + upper) // 2
            return upper

    def __call__(self, waveforms, masks=None):
        """Embed a batch of mono waveforms.

        Args:
            waveforms: float tensor of shape (batch, 1, num_samples).
            masks: optional tensor of shape (batch, num_frames); after
                nearest-neighbour upsampling to num_samples, values > 0.5
                mark the samples kept for each item.

        Returns:
            numpy array of shape (batch, dimension); rows whose usable
            signal is shorter than ``min_num_samples`` are all-NaN.
        """
        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1
        waveforms = waveforms.squeeze(dim=1)

        if masks is None:
            # Fix: the original squeezed dim=1 a second time here, which
            # would wrongly drop the samples axis when num_samples == 1.
            signals = waveforms
            wav_lens = signals.shape[1] * torch.ones(batch_size)
        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks
            # Upsample frame-level masks to sample resolution, binarize,
            # then keep only the selected samples of each waveform and
            # right-pad to a rectangular batch.
            imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
            signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()
        # Nothing in the batch is long enough: return all-NaN embeddings.
        if max_len < self.min_num_samples:
            return np.nan * np.zeros((batch_size, self.dimension))

        too_short = wav_lens < self.min_num_samples
        wav_lens = wav_lens / max_len  # SpeechBrain expects relative lengths in (0, 1]
        wav_lens[too_short] = 1.0  # placeholder so the forward pass does not crash
        embeddings = self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy()
        embeddings[too_short.cpu().numpy()] = np.nan  # invalidate the placeholder rows
        return embeddings