# superb_submit/expert.py
from collections import OrderedDict
from typing import Dict, List, Union

import torch
import torch.nn as nn
from packaging import version
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

import fairseq

# class Model(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # The model needs to be an nn.Module for fine-tuning; this is not
#         # required for representation extraction.
#         self.model1 = nn.Linear(1, HIDDEN_DIM)
#         self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
#
#     def forward(self, wavs, upstream_feature_selection="hidden_states"):
#         # You can do task-specific pre-/post-processing based on
#         # upstream_feature_selection.
#         hidden = self.model1(wavs)
#         # hidden: (batch_size, max_len, hidden_dim)
#         feature = self.model2(hidden)
#         # feature: (batch_size, max_len, hidden_dim)
#         return [hidden, feature]


class UpstreamExpert(nn.Module):
    def __init__(
        self,
        ckpt: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs,
    ):
"""
Args:
ckpt:
The checkpoint path for loading your pretrained weights.
Should be fixed as model.pt for SUPERB Challenge.
upstream_feature_selection:
The value could be
'hidden_states', 'PR', 'SID', 'ER', 'ASR', 'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others(new tasks).
You can use it to control which task-specified pre- / post-processing to do.
"""
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # # You can use ckpt to load your pretrained weights:
        # ckpt = torch.load(ckpt, map_location="cpu")
        # self.model = Model()
        # self.model.load_state_dict(ckpt)

        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = models[0]
        self.task = task

    def get_downsample_rates(self, key: str) -> int:
        """
        HuBERT Base's convolutional frontend emits one feature frame for
        every 320 input samples (a 20 ms stride on 16 kHz input), so every
        key's representation has a downsample rate of 320.
        """
        return 320
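
    # For example (assuming 16 kHz input), a 1-second waveform of 16000
    # samples maps to roughly 16000 // 320 = 50 feature frames, i.e. one
    # frame per 20 ms.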

    def forward(self, wavs: List[Tensor]) -> Dict[str, Tensor]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should have the same shape so that a weighted sum
        can be trained over them.
        """
        # Total 7 settings: the original waveforms plus copies padded with
        # silence (zeros) at the front or the end, with silence lengths of
        # 1/5, 1/10, and 1/20 of each waveform.
        # Copy rather than alias: appending to an alias of `wavs` would make
        # the loops below iterate over a growing list and never terminate.
        wavs_silence = list(wavs)
        # silence at the front
        for divisor in (5, 10, 20):
            for wav in wavs:
                silence = torch.zeros(len(wav) // divisor, device=wav.device)
                wavs_silence.append(torch.cat((silence, wav)))
        # silence at the end
        for divisor in (5, 10, 20):
            for wav in wavs:
                silence = torch.zeros(len(wav) // divisor, device=wav.device)
                wavs_silence.append(torch.cat((wav, silence)))
        wavs = wavs_silence

        device = wavs[0].device
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        # True at padded positions, i.e. where the position index is >= the
        # waveform's length.
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
        padded_wav = pad_sequence(wavs, batch_first=True)

        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,  # falsy: no span masking during feature extraction
        )
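
        # If task-specific post-processing were needed, the sanctioned hook is
        # the upstream_feature_selection flag set in __init__, e.g. (a
        # hypothetical sketch; sd_postprocess is an invented placeholder):
        #
        #     if self.upstream_feature_selection == "SD":
        #         features = sd_postprocess(features)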

        # Deprecated! Do not hardcode any task-specific post-processing below.
        # Use the init arg "upstream_feature_selection" to control which
        # task-specific pre-/post-processing to do. The "hidden_states" key
        # is used as the default in many cases; other keys are reserved for
        # the SUPERB Challenge.
        return {
            "hidden_states": features,
        }
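

# A minimal smoke test: a sketch, not part of the SUPERB submission API. It
# assumes a local copy of the HuBERT Base checkpoint (fairseq's checkpoint
# loader expects a filesystem path, so the default URL above would need to
# be downloaded first); the path below is an assumption.
if __name__ == "__main__":
    expert = UpstreamExpert(ckpt="hubert_base_ls960.pt")
    expert.eval()
    # Two example 16 kHz waveforms: 1 s and 2 s of random noise.
    wavs = [torch.randn(16000), torch.randn(32000)]
    with torch.no_grad():
        outputs = expert(wavs)
    # forward expands the batch to 7x its size (original + 6 silence
    # settings), so the leading dimension here should be 14.
    print(outputs["hidden_states"].shape)  # (7 * batch, max_frames, hidden_dim)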