from typing import Dict, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

HIDDEN_DIM = 8


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # The model needs to be an nn.Module for fine-tuning;
        # this is not required for representation extraction.
        self.model1 = nn.Linear(1, HIDDEN_DIM)
        self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

    def forward(self, wavs):
        hidden = self.model1(wavs)
        # hidden: (batch_size, max_len, hidden_dim)

        feature = self.model2(hidden)
        # feature: (batch_size, max_len, hidden_dim)

        return [hidden, feature]


class UpstreamExpert(nn.Module):
    def __init__(self, ckpt: str = "./model.pt", **kwargs):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for the SUPERB Challenge.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"

        # You can use ckpt to load your pretrained weights.
        ckpt = torch.load(ckpt, map_location="cpu")
        self.model = Model()
        self.model.load_state_dict(ckpt)

    def get_downsample_rates(self, key: str) -> int:
        """
        Since this example upstream does no downsampling, every key's
        representation has a downsample rate of 1. For reference, a
        representation with a 10ms stride would have a downsample rate
        of 160, since input wavs are all in 16kHz.
        """
        return 1

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should have the same shape so that a weighted sum
        can be trained over them.
        """
        wavs = pad_sequence(wavs, batch_first=True).unsqueeze(-1)
        # wavs: (batch_size, max_len, 1)

        hidden_states = self.model(wavs)

        # The "hidden_states" key is used as the default in many cases.
        # The other keys in this example are presented for the SUPERB Challenge.
        return {
            "hidden_states": hidden_states,
            "PR": hidden_states,
            "SID": hidden_states,
            "ER": hidden_states,
            "ASR": hidden_states,
            "QbE": hidden_states,
            "ASV": hidden_states,
            "SD": hidden_states,
            "ST": hidden_states,
            "SE": hidden_states,
            "SS": hidden_states,
            "secret": hidden_states,
        }
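
if __name__ == "__main__":
    # Minimal usage sketch, not part of the SUPERB interface. It assumes
    # nothing beyond the code above: since no pretrained ./model.pt may
    # exist locally, we save a freshly initialized Model's weights first so
    # the UpstreamExpert constructor has a valid checkpoint to load. The
    # 16000/12000-sample lengths are illustrative, not required.
    torch.save(Model().state_dict(), "./model.pt")

    expert = UpstreamExpert(ckpt="./model.pt")

    # Variable-length mono waveforms; forward() pads them to the longest one.
    wavs = [torch.randn(16000), torch.randn(12000)]

    with torch.no_grad():
        outputs = expert(wavs)

    # Every task key maps to the same [hidden, feature] list here, and both
    # tensors share the shape (batch_size, max_len, HIDDEN_DIM), as required
    # for training a weighted sum over them.
    for key, states in outputs.items():
        print(key, [tuple(s.shape) for s in states])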