from packaging import version

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import fairseq
from s3prl.upstream.interfaces import UpstreamBase

SAMPLE_RATE = 16000
EXAMPLE_SEC = 5

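
# NOTE: this expert wraps a fairseq speech encoder checkpoint (e.g. a
# wav2vec 2.0 / HuBERT style model) for s3prl. The hooks registered in
# __init__ expose the batch-first input of every Transformer encoder layer
# plus the final encoder output as layer-wise hidden states, in addition to
# the "default" features returned by forward().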
class UpstreamExpert(UpstreamBase):
    def __init__(self, ckpt, **kwargs):
        super().__init__(**kwargs)
        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        # load_model_ensemble_and_task returns (models, saved config, task);
        # keep the first (and only) model and the task, whose config tells
        # forward() whether the waveforms should be layer-normalized.
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = model[0]
        self.task = task

        # If no hooks were supplied, register one per Transformer encoder
        # layer plus one on the encoder output (see NOTE above). add_hook
        # takes the module path as a string, so eval is used here only to
        # count the layers.
        if len(self.hooks) == 0:
            module_name = "self.model.encoder.layers"
            for module_id in range(len(eval(module_name))):
                self.add_hook(
                    f"{module_name}[{module_id}]",
                    lambda input, output: input[0].transpose(0, 1),
                )
            self.add_hook("self.model.encoder", lambda input, output: output[0])

    def forward(self, wavs):
        # wavs: list of 1-D waveform tensors (SAMPLE_RATE Hz), possibly of
        # different lengths.
        if self.task.cfg.normalize:
            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]

        device = wavs[0].device
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        # Boolean mask that is True at padded positions: a position counts as
        # padding when its index is not less than that waveform's length.
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
        padded_wav = pad_sequence(wavs, batch_first=True)

        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )
        return {
            "default": features,
        }
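

# Minimal usage sketch (illustrative only, not part of the s3prl runner): the
# checkpoint path below is a placeholder, and the exact set of keys in the
# returned dict depends on how UpstreamBase post-processes the hooked hidden
# states; forward() itself only guarantees the "default" entry.
if __name__ == "__main__":
    example_ckpt = "/path/to/fairseq_checkpoint.pt"  # placeholder path
    expert = UpstreamExpert(example_ckpt)
    expert.eval()

    # Two fake utterances of EXAMPLE_SEC seconds at SAMPLE_RATE Hz.
    wavs = [torch.randn(SAMPLE_RATE * EXAMPLE_SEC) for _ in range(2)]
    with torch.no_grad():
        outputs = expert(wavs)
    print(outputs["default"].shape)  # (batch, num_frames, hidden_size)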