import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from ..interfaces import UpstreamBase
from .convert import load_converted_model


class UpstreamExpert(UpstreamBase):
    def __init__(self, ckpt, **kwargs):
        super().__init__(**kwargs)
        model, task_cfg = load_converted_model(ckpt)
        self.model = model
        self.wav_normalize = task_cfg.normalize

        # Stop gradients flowing into the conv feature extractor (grad scale 0)
        # and disable LayerDrop so every encoder layer runs (and is hooked) on
        # each forward pass.
        self.model.feature_grad_mult = 0.0
        self.model.encoder.layerdrop = 0.0

        if len(self.hooks) == 0:
            # Hook every encoder layer plus the final encoder output so that
            # UpstreamBase can collect per-layer hidden states. Layer inputs
            # are (seq_len, batch, dim), hence the transpose to batch-first.
            module_name = "self.model.encoder.layers"
            for module_id in range(len(eval(module_name))):
                self.add_hook(
                    f"{module_name}[{module_id}]",
                    lambda input, output: input[0].transpose(0, 1),
                )
            self.add_hook("self.model.encoder", lambda input, output: output[0])

            # Trim every hooked hidden state to the shortest sequence length
            # so the layers stay aligned despite padding differences.
            def postprocess(xs):
                names, hiddens = zip(*xs)
                unpad_len = min(hidden.size(1) for hidden in hiddens)
                hiddens = [hidden[:, :unpad_len, :] for hidden in hiddens]
                return list(zip(names, hiddens))

            self.hook_postprocess = postprocess

        # Saved after the zeroing above, so set_layer_drop(None) restores 0.0.
        self._init_layerdrop = self.model.encoder.layerdrop
    @property
    def layer_drop(self):
        return self.model.encoder.layerdrop

    def set_layer_drop(self, layerdrop: float = None):
        if isinstance(layerdrop, float):
            self.model.encoder.layerdrop = layerdrop
        elif layerdrop is None:
            self.model.encoder.layerdrop = self._init_layerdrop
        else:
            raise ValueError("layerdrop can only be float or None")
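
    # Example: expert.set_layer_drop(0.1) enables stochastic layer skipping
    # during training; calling expert.set_layer_drop() with no argument
    # restores the value saved at init (0.0 here).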
    def get_downsample_rates(self, key: str) -> int:
        # The conv feature extractor reduces the waveform by 320 samples per
        # output frame.
        return 320
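    # For intuition: assuming 16 kHz input, one second of audio (16000
    # samples) maps to roughly 16000 // 320 = 50 frames.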
    def forward(self, wavs):
        device = wavs[0].device
        if self.wav_normalize:
            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]

        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
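        # The mask is True at padded positions, e.g. for lengths [3, 2]:
        #   [[False, False, False],
        #    [False, False, True ]]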
        padded_wav = pad_sequence(wavs, batch_first=True)

        results = self.model.extract_features(padded_wav, wav_padding_mask)

        # This forward only runs the model; the per-layer hidden states are
        # collected and returned through UpstreamBase's hooks.
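

# Minimal usage sketch ("example.ckpt" is a hypothetical converted checkpoint
# path; this assumes UpstreamBase's call wrapper merges the hooked hidden
# states into the returned dict, as in s3prl):
#
#     expert = UpstreamExpert("example.ckpt")
#     wavs = [torch.randn(16000), torch.randn(8000)]  # variable lengths
#     outputs = expert(wavs)
#     for hidden in outputs["hidden_states"]:
#         print(hidden.shape)  # (batch, frames, feature_dim) per layer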