import torch
import speechbrain as sb

class FeatureScaler(torch.nn.Module):
    """Multiply the input by a constant per-feature scale vector.

    Arguments
    ---------
    num_in : int
        Length of the scale vector; must broadcast against the last
        dimension of the input passed to ``forward``.
    scale : float
        Value every entry of the scale vector is set to.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Registered as a buffer so ``.to(device)`` / ``.double()`` etc. move
        # and cast it together with the module; a plain tensor attribute stays
        # on the CPU and breaks ``forward`` for inputs on another device.
        # ``persistent=False`` keeps it out of ``state_dict`` (it is fully
        # determined by the constructor arguments), so checkpoint loading is
        # unaffected.
        self.register_buffer(
            "scaler", torch.ones((num_in,)) * scale, persistent=False
        )

    def forward(self, x):
        """Return ``x`` multiplied elementwise (with broadcasting) by the scale vector."""
        return x * self.scaler

class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained interface producing normalized, scaled features.

    The pipeline is the ``feature_extractor`` hparam, followed by the
    ``normalizer`` module, followed by the ``feature_scaler`` module.
    """

    # Both modules are used by ``feats_from_audio``; listing them here makes a
    # missing module fail when the interface is built rather than on first use.
    MODULES_NEEDED = ["normalizer", "feature_scaler"]
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Compute normalized and scaled features for a batch of audio.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of waveforms, passed straight to ``feature_extractor``
            (assumed batch-first, e.g. ``(batch, time)`` -- TODO confirm).
        lengths : torch.Tensor, optional
            Per-item lengths forwarded to ``normalizer`` (presumably relative
            lengths in [0, 1], per SpeechBrain convention). When omitted,
            every item is treated as full length (all ones).

        Returns
        -------
        torch.Tensor
            The output of ``feature_scaler`` on the normalized features.
        """
        feats = self.hparams.feature_extractor(audio)
        if lengths is None:
            # Built per call instead of as a default argument so it has one
            # entry per batch item and lives on the same device as ``feats``.
            lengths = torch.ones(feats.shape[0], device=feats.device)
        normalized = self.mods.normalizer(feats, lengths)
        return self.mods.feature_scaler(normalized)

    def feats_from_file(self, path):
        """Load the audio at ``path`` and return its features.

        A batch dimension of size 1 is added before feature extraction and
        removed from the result.
        """
        audio = self.load_audio(path)
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)