import torch | |
import speechbrain as sb | |
class Custom(sb.pretrained.interfaces.Pretrained):
    """Custom pretrained interface that produces normalized acoustic features.

    Features are computed by ``hparams.feature_extractor`` and then passed
    through the ``normalizer`` module together with per-item relative lengths.
    """

    # Module names this interface uses via ``self.mods``.
    MODULES_NEEDED = ["normalizer"]
    # Hyperparameter names this interface uses via ``self.hparams``.
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Extract and normalize features from a batch of audio.

        Arguments
        ---------
        audio : torch.Tensor
            Batched waveforms with the batch on dim 0. NOTE(review): assumes
            (batch, time) -- confirm against the configured feature_extractor.
        lengths : torch.Tensor, optional
            Relative length of each batch item, as passed to the normalizer.
            When omitted, every item is treated as full length (1.0).

        Returns
        -------
        torch.Tensor
            The normalizer's output for the extracted features.
        """
        feats = self.hparams.feature_extractor(audio)
        if lengths is None:
            # Build the default per call rather than at definition time. Size
            # it to the actual batch and place it on the features' device, so
            # batches larger than 1 and non-CPU inputs also work. For a batch
            # of 1 this equals the former default ``torch.tensor([1.0])``.
            lengths = torch.ones(feats.shape[0], device=feats.device)
        normalized = self.mods.normalizer(feats, lengths)
        return normalized

    def feats_from_file(self, path):
        """Load an audio file and return its normalized features.

        The waveform from ``self.load_audio`` gets a leading batch dimension
        of size 1 for feature extraction, and that dimension is removed from
        the result. NOTE(review): assumes ``load_audio`` returns a single
        un-batched waveform -- confirm.

        Arguments
        ---------
        path : str
            Location of the audio file, forwarded to ``self.load_audio``.

        Returns
        -------
        torch.Tensor
            Normalized features with the batch dimension squeezed out.
        """
        audio = self.load_audio(path)
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)