from collections import OrderedDict
from typing import List, Union, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

import fairseq
from packaging import version

# class Model(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # The model needs to be a nn.Module for finetuning, not required for representation extraction
#         self.model1 = nn.Linear(1, HIDDEN_DIM)
#         self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
#
#     def forward(self, wavs, upstream_feature_selection="hidden_states"):
#         # You can do task-specific pre- / post-processing based on upstream_feature_selection
#         hidden = self.model1(wavs)
#         # hidden: (batch_size, max_len, hidden_dim)
#         feature = self.model2(hidden)
#         # feature: (batch_size, max_len, hidden_dim)
#         return [hidden, feature]

class UpstreamExpert(nn.Module):
    def __init__(
        self,
        ckpt: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs,
    ):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for the SUPERB Challenge.
            upstream_feature_selection:
                The value can be 'hidden_states', 'PR', 'SID', 'ER', 'ASR',
                'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others
                (new tasks). You can use it to control which task-specific
                pre- / post-processing to do.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # # You can use ckpt to load your pretrained weights
        # ckpt = torch.load(ckpt, map_location="cpu")
        # self.model = Model()
        # self.model.load_state_dict(ckpt)

        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = model[0]
        self.task = task
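
    # A hypothetical sketch (not part of the template) of how
    # upstream_feature_selection could drive task-specific post-processing:
    #
    # def _postprocess(self, hidden_states: List[Tensor]) -> List[Tensor]:
    #     if self.upstream_feature_selection == "QbE":
    #         # e.g. query-by-example matching may want every layer
    #         return hidden_states
    #     # default: keep only the last layer
    #     return hidden_states[-1:]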
    def get_downsample_rates(self, key: str) -> int:
        """
        HuBERT's convolutional frontend downsamples the 16kHz input waveform
        by a factor of 320, i.e. one frame every 20 ms (16,000 Hz * 0.02 s =
        320 samples), so every key's representation has a downsample rate
        of 320.
        """
        return 320
    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should have the same shape so that a weighted sum
        can be trained on them.
        """
        # 7 settings in total: the original wavs, plus each wav with silence
        # of 1/5, 1/10, or 1/20 of its length prepended (front) or appended (end)
        wavs_silence = list(wavs)  # copy, so we never append to the list we iterate over
        # front padding
        for ratio in (5, 10, 20):
            for wav in wavs:
                silence = torch.zeros(len(wav) // ratio, device=wav.device)
                wavs_silence.append(torch.cat((silence, wav)))
        # end padding
        for ratio in (5, 10, 20):
            for wav in wavs:
                silence = torch.zeros(len(wav) // ratio, device=wav.device)
                wavs_silence.append(torch.cat((wav, silence)))
        wavs = wavs_silence
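        # The batch is now 7x the original size (1 original + 6 silence-padded
        # variants per wav); downstream consumers should expect this ordering.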
        device = wavs[0].device
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
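        # Build a boolean padding mask: True at padded positions, i.e. where
        # the sample index is >= that wav's true length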
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
        padded_wav = pad_sequence(wavs, batch_first=True)

        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )
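        # features: (batch_size, num_frames, hidden_dim) from the transformer;
        # feat_padding_mask marks the frames that correspond to padding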
        # Deprecated! Do not do any task-specific postprocessing below.
        # You can use the init arg "upstream_feature_selection" to control
        # which task-specific pre- / post-processing to do.
        # The "hidden_states" key is used as the default in many cases.
        # The other keys in this example are reserved for the SUPERB Challenge.
        return {
            # wrapped in a list to match the declared Dict[str, List[Tensor]]
            "hidden_states": [features],
        }
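

# A minimal, hypothetical usage sketch (not part of the SUPERB interface).
# It assumes a local HuBERT checkpoint file and 16kHz mono waveforms:
#
# if __name__ == "__main__":
#     expert = UpstreamExpert(ckpt="hubert_base_ls960.pt")  # hypothetical path
#     wavs = [torch.randn(16000), torch.randn(24000)]
#     with torch.no_grad():
#         outputs = expert(wavs)
#     # 2 wavs x 7 silence settings = 14 items in the batch dimension
#     print(outputs["hidden_states"][0].shape)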