|
from collections import OrderedDict |
|
from typing import List, Union, Dict |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
HIDDEN_DIM = 8


class Model(nn.Module):
    """Minimal example network: two stacked linear layers.

    The first layer maps a trailing dimension of size 1 to ``HIDDEN_DIM``;
    the second maps ``HIDDEN_DIM`` to ``HIDDEN_DIM``.
    """

    def __init__(self):
        super().__init__()

        # The attribute names double as the state_dict key prefixes
        # ("model1.*", "model2.*"), so they must stay as they are.
        self.model1 = nn.Linear(in_features=1, out_features=HIDDEN_DIM)
        self.model2 = nn.Linear(in_features=HIDDEN_DIM, out_features=HIDDEN_DIM)

    def forward(self, wavs):
        """Run both layers and return every intermediate output.

        Args:
            wavs: Input tensor whose last dimension has size 1
                (required by the first linear layer).

        Returns:
            A two-element list ``[hidden, feature]``: the output of the
            first layer followed by the output of the second layer.
        """
        first_layer_out = self.model1(wavs)
        second_layer_out = self.model2(first_layer_out)
        return [first_layer_out, second_layer_out]
|
|
|
class UpstreamExpert(nn.Module):
    """Example upstream wrapper exposing a pretrained ``Model`` per task key."""

    def __init__(self, ckpt: str = "./model.pt", **kwargs):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for SUPERB Challenge.
            **kwargs:
                Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"

        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the torch version; only point this at trusted checkpoint files.
        state_dict = torch.load(ckpt, map_location="cpu")
        network = Model()
        network.load_state_dict(state_dict)
        self.model = network

    def get_downsample_rates(self, key: str) -> int:
        """
        Since we do not do any downsampling in this example upstream
        All keys' corresponding representations have downsample rate of 1
        Eg. 10ms stride representation has the downsample rate 160 (input wavs are all in 16kHz)
        """
        return 1

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returning Dict contains the List with more than one Tensor,
        those Tensors should be in the same shape to train a weighted-sum on them.

        Args:
            wavs: The list of tensors to batch; they are zero-padded to a
                common length and given a trailing dimension of size 1.

        Returns:
            A dict in which every key maps to the same object returned by
            ``self.model`` for the padded batch.
        """
        padded = pad_sequence(wavs, batch_first=True)
        batch = padded.unsqueeze(-1)

        layer_outputs = self.model(batch)

        # Every key shares the very same output object.
        output_keys = (
            "hidden_states",
            "PR",
            "SID",
            "ER",
            "ASR",
            "QbE",
            "ASV",
            "SD",
            "ST",
            "SE",
            "SS",
            "secret",
        )
        return dict.fromkeys(output_keys, layer_outputs)
|
|