aheba31 commited on
Commit
8789192
1 Parent(s): 7ab725a

add py_module + hparams

Browse files
Files changed (2) hide show
  1. custom_interface.py +100 -0
  2. hyperparams.yaml +86 -0
custom_interface.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.pretrained import Pretrained
3
+
4
+
5
class CustomSLUDecoder(Pretrained):
    """An end-to-end SLU model using a HuBERT self-supervised encoder.

    The class can be used either to run only the encoder (encode_batch()) to
    extract features or to run the entire model (decode_batch()) to map the
    speech to its semantics.

    Example
    -------
    >>> from speechbrain.pretrained.interfaces import foreign_class
    >>> slu_model = foreign_class(source="speechbrain/slu-timers-and-such-direct-librispeech-asr",
    pymodule_file="custom_interface.py", classname="CustomSLUDecoder")
    >>> slu_model.decode_file("samples/audio_samples/example6.wav")
    "{'intent': 'SimpleMath', 'slots': {'number1': 37.67, 'number2': 75.7, 'op': ' minus '}}"
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tokenizer is loaded through the hyperparameters (see hyperparams.yaml,
        # which declares a SentencePiece processor restored by the pretrainer).
        self.tokenizer = self.hparams.tokenizer

    def decode_file(self, path):
        """Maps the given audio file to a string representing the
        semantic dictionary for the utterance.

        Arguments
        ---------
        path : str
            Path to audio file to decode.

        Returns
        -------
        str
            The predicted semantics.
        """
        waveform = self.load_audio(path)
        waveform = waveform.to(self.device)
        # Fake a batch: a single full-length utterance, relative length 1.0.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, predicted_tokens = self.decode_batch(batch, rel_length)
        return predicted_words[0]

    def encode_batch(self, wavs):
        """Encodes the input audio into a sequence of hidden states.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        wavs = wavs.float()
        wavs = wavs.to(self.device)
        # detach(): no gradients are propagated into the (frozen) encoder input.
        encoder_out = self.mods.hubert(wavs.detach())
        return encoder_out

    def decode_batch(self, wavs, wav_lens):
        """Maps the input audio to its semantics.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch decoded.
        tensor
            Each predicted token id.
        """
        with torch.no_grad():
            wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs)
            # Beam search over the encoder states; scores are discarded here.
            predicted_tokens, scores = self.mods.beam_searcher(
                encoder_out, wav_lens
            )
            predicted_words = [
                self.tokenizer.decode_ids(token_seq)
                for token_seq in predicted_tokens
            ]
        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full decoding - note: no gradients through decoding."""
        return self.decode_batch(wavs, wav_lens)
hyperparams.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ############################################################################
2
 + # Model: HuBERT encoder for direct SLU (SLURP)
3
+ # ############################################################################
4
+
5
+
6
+ # Hparams NEEDED
7
+ HPARAMS_NEEDED: ["beam_searcher"]
8
+ # Modules Needed
9
+ MODULES_NEEDED: ["hubert", "decoder", "seq_lin"]
10
+
11
 + # URL for the wav2vec2 model; you can change it to benchmark different models
12
+ wav2vec2_hub: facebook/hubert-base-ls960
13
+
14
+ # Pretrain folder (HuggingFace)
15
+ pretrained_path: speechbrain/SLU-direct-SLURP-hubert-enc
16
+
17
+ # parameters
18
+ encoder_dim: 768
19
+ output_neurons: 58
20
+ emb_size: 128
21
+ dec_neurons: 512
22
+ dec_attn_dim: 512
23
+ dec_layer: 3
24
+
25
+ hubert: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
26
+ source: !ref <wav2vec2_hub>
27
+ output_norm: True
28
+ freeze: True
29
+ pretrain: False
30
+ save_path: wav2vec2_checkpoints
31
+
32
+ output_emb: !new:speechbrain.nnet.embedding.Embedding
33
+ num_embeddings: !ref <output_neurons>
34
+ embedding_dim: !ref <emb_size>
35
+
36
+ dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
37
+ enc_dim: !ref <encoder_dim>
38
+ input_size: !ref <emb_size>
39
+ rnn_type: lstm
40
+ attn_type: content
41
+ hidden_size: !ref <dec_neurons>
42
+ attn_dim: !ref <dec_attn_dim>
43
+ num_layers: !ref <dec_layer>
44
+ scaling: 1.0
45
+ dropout: 0.0
46
+
47
+ seq_lin: !new:speechbrain.nnet.linear.Linear
48
+ input_size: !ref <dec_neurons>
49
+ n_neurons: !ref <output_neurons>
50
+
51
+ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
52
+ embedding: !ref <output_emb>
53
+ decoder: !ref <dec>
54
+ linear: !ref <seq_lin>
55
+ bos_index: 0
56
+ eos_index: 0
57
+ min_decode_ratio: 0.0
58
+ max_decode_ratio: 10.0
59
+ beam_size: 80
60
+ eos_threshold: 1.5
61
+ temperature: 1.25
62
+ using_max_attn_shift: false
63
+ max_attn_shift: 30
64
+ coverage_penalty: 0.
65
+
66
+ model: !new:torch.nn.ModuleList
67
+ - [!ref <output_emb>, !ref <dec>, !ref <seq_lin>]
68
+
69
+ modules:
70
+ hubert: !ref <hubert>
71
+ beam_searcher: !ref <beam_searcher>
72
+
73
+ tokenizer: !new:sentencepiece.SentencePieceProcessor
74
+
75
+
76
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
77
+ loadables:
78
+ hubert: !ref <hubert>
79
+ model: !ref <model>
80
+ tokenizer: !ref <tokenizer>
81
+ paths:
82
+ hubert: !ref <pretrained_path>/hubert.ckpt
83
+ model: !ref <pretrained_path>/model.ckpt
84
+ tokenizer: !ref <pretrained_path>/tokenizer_58_unigram.model
85
+
86
+