import os
import sys
import torch

import numpy as np
import torch.nn.functional as F

from functools import cached_property
from torch.nn.utils.rnn import pad_sequence

# Make the project root importable so that `main.*` modules resolve when run as a script.
sys.path.append(os.getcwd())

from main.library.speaker_diarization.speechbrain import EncoderClassifier

class BaseInference:
    """Minimal marker base class for inference backends."""

class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
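    """Speaker embedding extractor wrapping a local SpeechBrain EncoderClassifier checkpoint."""
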
    def __init__(self, embedding = "assets/models/speaker_diarization/models/speechbrain", device = None):
        super().__init__()

        self.embedding = embedding
        self.device = device or torch.device("cpu")
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})

    def to(self, device):
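        """Move the extractor to `device` by re-loading the classifier with new run options."""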
        if not isinstance(device, torch.device): raise TypeError(f"`device` must be a torch.device, got {type(device).__name__}")

        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
        self.device = device
        return self

    @cached_property
    def sample_rate(self):
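        # Sample rate expected by the model's audio normalizer.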
        return self.classifier_.audio_normalizer.sample_rate

    @cached_property
    def dimension(self):
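        # Probe the encoder with a dummy one-second waveform to discover the embedding size.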
        *_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
        return dimension

    @cached_property
    def metric(self):
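        # Embeddings from this model are meant to be compared with cosine distance.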
        return "cosine"

    @cached_property
    def min_num_samples(self):
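        # Binary-search the shortest input length (in samples) that the encoder
        # accepts without raising, starting from an upper bound of half a second.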
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2

            while lower + 1 < upper:
                try:
                    _ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
                    upper = middle
                except RuntimeError:
                    lower = middle

                middle = (lower + upper) // 2

        return upper

    def __call__(self, waveforms, masks = None):
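        """Return one embedding per waveform; chunks shorter than the supported minimum come back as NaN rows."""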
        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        waveforms = waveforms.squeeze(dim=1)  # (batch, 1, samples) -> (batch, samples)

        if masks is None:
            signals = waveforms  # already (batch, samples) after the squeeze above
            wav_lens = signals.shape[1] * torch.ones(batch_size)
        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            # Up-sample frame-level masks to sample resolution, keep only the active
            # samples of each waveform, then right-pad into a fixed-size batch.
            imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
            signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()
        # If every chunk is shorter than the supported minimum, skip inference entirely.
        if max_len < self.min_num_samples: return np.nan * np.zeros((batch_size, self.dimension))

        # SpeechBrain expects lengths relative to the longest signal in the batch;
        # too-short chunks are fed at full length and overwritten with NaN afterwards.
        too_short = wav_lens < self.min_num_samples
        wav_lens = wav_lens / max_len
        wav_lens[too_short] = 1.0

        embeddings = self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy()
        embeddings[too_short.cpu().numpy()] = np.nan

        return embeddings
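
if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustration, not part of the original file). It
    # assumes the SpeechBrain checkpoint is present at the default `embedding` path
    # used above; the shapes and mask values below are arbitrary example inputs.
    inference = SpeechBrainPretrainedSpeakerEmbedding(device=torch.device("cpu"))

    # Two mono chunks of two seconds each, plus frame-level masks that keep the
    # whole first chunk and only the first half of the second one.
    chunks = torch.randn(2, 1, 2 * inference.sample_rate)
    masks = torch.ones(2, 100)
    masks[1, 50:] = 0.0

    embeddings = inference(chunks, masks=masks)
    print(embeddings.shape)  # (2, dimension)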