from collections import OrderedDict
from typing import List, Union, Dict

import torch
import torch.nn as nn
from packaging import version
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

import fairseq


class UpstreamExpert(nn.Module):
    def __init__(
        self,
        ckpt: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs,
    ):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for the SUPERB Challenge.
            upstream_feature_selection:
                The value can be 'hidden_states', 'PR', 'SID', 'ER', 'ASR',
                'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others
                (new tasks). You can use it to control which task-specific
                pre- / post-processing to do.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # This example wraps a fairseq HuBERT checkpoint, which requires a
        # fairseq version newer than the 0.10.2 release.
        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        # load_model_ensemble_and_task returns a list of models, the config,
        # and the task; only the first (and only) model is needed here.
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = model[0]
        self.task = task

    def get_downsample_rates(self, key: str) -> int:
        """
        The downsample rate is the number of input waveform samples that
        correspond to one frame of the representation, e.g. a 10 ms stride
        representation has a downsample rate of 160 (input wavs are all in
        16 kHz). The HuBERT model used here has a 20 ms stride, so every
        key's representation has a downsample rate of 320.
        """
        return 320

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should all have the same shape, so that a weighted sum
        can be trained on them.
        """

        # The blocks below optionally prepend or append silence to every
        # waveform before feature extraction. Only one variant should be
        # active at a time; by default the waveforms are left unchanged.
        wavs_silence = wavs

        # # Prepend silence of 1/5 of the waveform length
        # wavs_silence = []
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 5).to(wav.device)
        #     wavs_silence.append(torch.cat((temp_wav, wav)))

        # # Prepend silence of 1/10 of the waveform length
        # wavs_silence = []
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 10).to(wav.device)
        #     wavs_silence.append(torch.cat((temp_wav, wav)))

        # # Prepend silence of 1/20 of the waveform length
        # wavs_silence = []
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 20).to(wav.device)
        #     wavs_silence.append(torch.cat((temp_wav, wav)))

        # # Append silence of 1/5 of the waveform length
        # wavs_silence = []
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 5).to(wav.device)
        #     wavs_silence.append(torch.cat((wav, temp_wav)))

        # # Append silence of 1/10 of the waveform length
        # wavs_silence = []
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 10).to(wav.device)
        #     wavs_silence.append(torch.cat((wav, temp_wav)))

        # # Append silence of 1/20 of the waveform length
        # wavs_silence = []
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 20).to(wav.device)
        #     wavs_silence.append(torch.cat((wav, temp_wav)))

        wavs = wavs_silence

        device = wavs[0].device
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        # Sample-level padding mask: True marks padded positions, False marks
        # real samples.
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
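        # For instance, with wav_lengths = [3, 5] the mask above would be
        # [[False, False, False, True,  True ],
        #  [False, False, False, False, False]].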
        # Zero-pad all waveforms to the longest one: (batch, max_num_samples)
        padded_wav = pad_sequence(wavs, batch_first=True)

        # Frame-level features of shape (batch, num_frames, hidden_size),
        # plus a frame-level padding mask.
        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )

        # Wrap the single Tensor in a list so the return value matches the
        # declared Dict[str, List[Tensor]] interface.
        return {
            "hidden_states": [features],
        }
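

# A minimal usage sketch (not part of the upstream interface). It assumes a
# local checkpoint file named "model.pt" exists; the path and the dummy
# waveform lengths below are only illustrative.
if __name__ == "__main__":
    upstream = UpstreamExpert(ckpt="model.pt")
    wavs = [torch.randn(16000), torch.randn(16000 * 2)]  # 1 s and 2 s of fake audio
    with torch.no_grad():
        outputs = upstream(wavs)
    # Each hidden state has shape (batch, num_frames, hidden_size).
    print(outputs["hidden_states"][0].shape)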