# File size: 1,322 bytes
# Revision: cb8f733
import torch
from speechbrain.inference.interfaces import Pretrained


class CustomEncoderBestRQ(Pretrained):
    """Inference interface that turns waveforms into encoder outputs.

    ``encode_batch`` runs the following chain, in this order:
    ``hparams.compute_features`` -> ``mods.normalizer`` ->
    ``mods.extractor`` -> ``mods.encoder``. All four objects must therefore
    be reachable through ``self.hparams`` / ``self.mods``.

    NOTE(review): the class name suggests a BEST-RQ style encoder; nothing in
    this class depends on that, it only wires the modules together.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to ``Pretrained.__init__``."""
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encode a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Waveforms. A 1-D tensor is treated as a single waveform and gets
            a leading batch dimension; otherwise the first dimension is used
            as the batch dimension.
        wav_lens : torch.Tensor, optional
            One length value per batch entry. When omitted, a tensor of ones
            (i.e. "full length" for every entry) is used.
            NOTE(review): presumably relative lengths in [0, 1] -- confirm.
        normalize : bool
            Currently unused by this implementation; accepted so that the
            signature matches ``encode_file`` and ``forward``.

        Returns
        -------
        The value returned by ``self.mods.encoder(src, wav_lens)``.
        """
        # Manage single waveforms in input.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Move to the model device first, then cast to float32.
        wavs = wavs.to(self.device).float()

        # Assign full length if wav_lens is not given; otherwise just make
        # sure the provided lengths live on the model device.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        else:
            wav_lens = wav_lens.to(self.device)

        feats = self.hparams.compute_features(wavs)
        feats = self.mods.normalizer(feats, wav_lens)
        src = self.mods.extractor(feats)
        return self.mods.encoder(src, wav_lens)

    def encode_file(self, path, normalize=False):
        """Load one audio file and encode it as a batch of size one.

        Arguments
        ---------
        path : str
            Location of the audio file, passed to the inherited
            ``load_audio``.
        normalize : bool
            Forwarded to ``encode_batch`` (which currently ignores it).

        Returns
        -------
        The value returned by ``encode_batch`` for the single-item batch.
        """
        waveform = self.load_audio(path)
        # Fake a batch: add a leading batch dimension and mark the single
        # entry as full length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        # Forward ``normalize`` so this entry point agrees with ``forward``.
        return self.encode_batch(batch, rel_length, normalize=normalize)

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Alias for ``encode_batch`` so the object can be called directly."""
        return self.encode_batch(wavs=wavs, wav_lens=wav_lens, normalize=normalize)