SLU-direct-SLURP-hubert-enc / custom_interface.py
aheba31's picture
add py_module + hparams
8789192
raw
history blame
3.43 kB
import torch
from speechbrain.pretrained import Pretrained
class CustomSLUDecoder(Pretrained):
    """An end-to-end spoken language understanding (SLU) model that uses a
    HuBERT self-supervised encoder.

    The class can be used either to run only the encoder (``encode_batch()``)
    to extract features, or to run the entire model (``decode_batch()`` /
    ``decode_file()``) to map speech directly to its semantics.

    The loaded hyperparameters must provide a ``tokenizer`` entry, and the
    modules must include ``hubert`` (the encoder) and ``beam_searcher``
    (the decoder).

    Example
    -------
    >>> from speechbrain.pretrained.interfaces import foreign_class
    >>> slu_model = foreign_class(  # doctest: +SKIP
    ...     source="speechbrain/SLU-direct-SLURP-hubert-enc",
    ...     pymodule_file="custom_interface.py",
    ...     classname="CustomSLUDecoder",
    ... )
    >>> semantics = slu_model.decode_file("path/to/audio.wav")  # doctest: +SKIP
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tokenizer used by decode_batch() to turn predicted token ids
        # back into text.
        self.tokenizer = self.hparams.tokenizer

    def decode_file(self, path):
        """Maps the given audio file to a string representing the
        semantic dictionary for the utterance.

        Arguments
        ---------
        path : str
            Path to audio file to decode.

        Returns
        -------
        str
            The predicted semantics.
        """
        waveform = self.load_audio(path)
        waveform = waveform.to(self.device)
        # Fake a batch of size one. The single utterance is the longest (and
        # only) item, so its relative length is 1.0 and there is no padding.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, _ = self.decode_batch(batch, rel_length)
        return predicted_words[0]

    def encode_batch(self, wavs):
        """Encodes the input audio into a sequence of hidden states.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.

        Returns
        -------
        torch.tensor
            The encoded batch (output of the ``hubert`` module).
        """
        wavs = wavs.float()
        wavs = wavs.to(self.device)
        # Only the waveforms are given to the encoder (no relative lengths),
        # so any padding in the batch is encoded as well. detach() keeps
        # gradients from flowing back into the input tensor.
        encoder_out = self.mods.hubert(wavs.detach())
        return encoder_out

    def decode_batch(self, wavs, wav_lens):
        """Maps the input audio to its semantics.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch decoded to a semantics string.
        predicted_tokens
            The predicted token ids for each waveform, exactly as returned
            by ``beam_searcher``.
        """
        with torch.no_grad():
            wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs)
            # NOTE(review): assumes beam_searcher returns a (tokens, scores)
            # pair; newer SpeechBrain searchers return more values. Confirm
            # against the installed version.
            predicted_tokens, _scores = self.mods.beam_searcher(
                encoder_out, wav_lens
            )
            predicted_words = [
                self.tokenizer.decode_ids(token_seq)
                for token_seq in predicted_tokens
            ]
        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full decoding - note: no gradients through decoding."""
        return self.decode_batch(wavs, wav_lens)