chaanks committed on
Commit
cb8f733
1 Parent(s): e67bf97

Upload 5 files

Browse files
Files changed (5) hide show
  1. custom_interface.py +37 -0
  2. hyperparams.yaml +83 -0
  3. model.ckpt +3 -0
  4. normalizer.ckpt +3 -0
  5. test.flac +0 -0
custom_interface.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.inference.interfaces import Pretrained
3
+
4
+
5
class CustomEncoderBestRQ(Pretrained):
    """Inference interface that returns encoder outputs of a Best-RQ model.

    A batch of waveforms goes through, in order:
    ``hparams.compute_features`` -> ``mods.normalizer`` ->
    ``mods.extractor`` -> ``mods.encoder``.
    """

    # Declared so the base-class constructor can reject an incomplete
    # hyperparameter file at load time instead of failing on the first call.
    HPARAMS_NEEDED = ["compute_features"]
    MODULES_NEEDED = ["normalizer", "extractor", "encoder"]

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encode a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Waveforms. A 1-D tensor is treated as a single waveform and
            gets a leading batch dimension added.
            (Presumably ``(batch, time)`` otherwise -- confirm with caller.)
        wav_lens : torch.Tensor, optional
            One length per batch item, passed to the normalizer and the
            encoder. When ``None``, a vector of ones is used (every item
            is considered full length).
        normalize : bool
            Accepted for signature compatibility only; this method does
            not read it, so no output normalization is performed.

        Returns
        -------
        The value returned by ``self.mods.encoder``.
        """
        # Promote a single waveform to a batch of one.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Without explicit lengths, treat every item as full length.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Inputs must live on the model's device; features expect floats.
        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        feats = self.hparams.compute_features(wavs)
        feats = self.mods.normalizer(feats, wav_lens)
        src = self.mods.extractor(feats)
        return self.mods.encoder(src, wav_lens)

    def encode_file(self, path, normalize=False):
        """Load one audio file and encode it as a batch of one.

        Arguments
        ---------
        path : str
            Location of the audio file, handed to ``self.load_audio``.
        normalize : bool
            Forwarded to ``encode_batch`` (which currently ignores it).

        Returns
        -------
        The value returned by ``encode_batch`` for the one-item batch.
        """
        waveform = self.load_audio(path)
        # Fake a batch: one item, at full relative length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        return self.encode_batch(batch, rel_length, normalize=normalize)

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Alias of ``encode_batch`` so the object can be called directly."""
        return self.encode_batch(
            wavs=wavs, wav_lens=wav_lens, normalize=normalize
        )
hyperparams.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ################################
# Model: Best-RQ
# Authors: Jarod Duret 2024
# ################################

# Feature extraction settings, referenced by `compute_features` below.
sample_rate: 16000  # Hz
n_fft: 512
n_mels: 80
# NOTE(review): SpeechBrain's Fbank documents win_length / hop_length in
# milliseconds (not samples) -- confirm against the installed version.
win_length: 32
hop_length: 10

####################### Model parameters ###########################

# Transformer
d_model: 768
nhead: 8
num_encoder_layers: 12
num_decoder_layers: 0  # encoder-only model; see `encoder_wrapper` below
d_ffn: 2048
transformer_dropout: 0.1
activation: !name:torch.nn.GELU
output_neurons: 5000
encoder_layerdrop: 0.0

compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
    hop_length: !ref <hop_length>
    win_length: !ref <win_length>

normalizer: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
    update_until_epoch: 0

############################## Models ################################

latent_extractor: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
    num_blocks: 2
    num_layers_per_block: 1
    out_channels: (64, 32)
    kernel_sizes: (3, 3)
    strides: (2, 2)
    residuals: (False, False)

latent_encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR
    # Presumably 32 output channels x (80 mels / 2 / 2) = 640 features per
    # frame coming out of `latent_extractor` -- confirm if that block changes.
    input_size: 640
    tgt_vocab: !ref <output_neurons>
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    d_ffn: !ref <d_ffn>
    dropout: !ref <transformer_dropout>
    activation: !ref <activation>
    conformer_activation: !ref <activation>
    encoder_module: conformer
    attention_type: RelPosMHAXL
    normalize_before: True
    causal: False
    layerdrop_prob: !ref <encoder_layerdrop>

# The transformer is built with zero decoder layers, so it is wrapped in
# EncoderWrapper to run only its encoder.
encoder_wrapper: !new:speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper
    transformer: !ref <latent_encoder>

# Unused alternative that chains both stages into a single module. The
# inference interface instead calls `extractor` and `encoder` separately
# (see `modules` below).
# encoder: !new:speechbrain.nnet.containers.LengthsCapableSequential
#     latent_extractor: !ref <latent_extractor>
#     encoder_wrapper: !ref <encoder_wrapper>

# Single container for the trainable parts, used as the `model` loadable.
# NOTE(review): element order presumably has to match the checkpoint.
model: !new:torch.nn.ModuleList
    - [!ref <latent_extractor>, !ref <encoder_wrapper>]

# Modules exposed to the inference interface as `self.mods.<name>`.
modules:
    normalizer: !ref <normalizer>
    extractor: !ref <latent_extractor>
    encoder: !ref <encoder_wrapper>

# Loadable names presumably map to the `model.ckpt` and `normalizer.ckpt`
# files shipped with this configuration (default Pretrainer naming).
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        model: !ref <model>
        normalizer: !ref <normalizer>
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0637e00b93015b37d3d946c791b6110441540c494c818b98815be21f149be68
3
+ size 540502386
normalizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92244ada292c7d670d1dc88549e74ed24b3e25e70f27fe443420cf4832d6811b
3
+ size 1578
test.flac ADDED
Binary file (74.1 kB). View file