leo19941227's picture
Upload Upstream: commit message: my best model
c9d4907
raw
history blame
2.59 kB
from collections import OrderedDict
from typing import List, Union, Dict
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
# Output width shared by both linear layers of the toy Model below.
HIDDEN_DIM = 8


class Model(nn.Module):
    """Toy two-layer network used by the example upstream.

    Both layers are ``nn.Linear`` maps applied along the last axis: the first
    lifts a size-1 feature axis to ``HIDDEN_DIM``, the second maps
    ``HIDDEN_DIM`` to ``HIDDEN_DIM``.
    """

    def __init__(self):
        super().__init__()
        # Subclassing nn.Module is what makes finetuning possible; plain
        # representation extraction would not need it.
        # The attribute names below become the state_dict key prefixes
        # ("model1.*", "model2.*"), so they must match the saved checkpoint.
        self.model1 = nn.Linear(1, HIDDEN_DIM)
        self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

    def forward(self, wavs):
        """Return ``[hidden, feature]``, the outputs of both layers.

        Args:
            wavs: Input whose last axis has size 1, e.g.
                (batch_size, max_len, 1).

        Returns:
            A two-element list; each entry is (batch_size, max_len, hidden_dim)
            for such an input.
        """
        first_layer_out = self.model1(wavs)
        second_layer_out = self.model2(first_layer_out)
        return [first_layer_out, second_layer_out]
class UpstreamExpert(nn.Module):
    """Example SUPERB upstream that wraps the toy ``Model``.

    Loads a ``Model`` state dict from a checkpoint file and exposes the
    model's output under a fixed set of task keys.
    """

    def __init__(self, ckpt: str = "./model.pt", **kwargs):
        """
        Args:
            ckpt:
                Path of the checkpoint that holds the pretrained weights.
                Keep it as model.pt for the SUPERB Challenge.
            **kwargs:
                Accepted and ignored.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"

        # The checkpoint is deserialised onto the CPU and must be a state
        # dict whose keys match ``Model`` (load_state_dict is strict by
        # default, so mismatched keys raise).
        state_dict = torch.load(ckpt, map_location="cpu")
        self.model = Model()
        self.model.load_state_dict(state_dict)

    def get_downsample_rates(self, key: str) -> int:
        """Return the downsample rate of the representation under *key*.

        This example upstream does no downsampling, so every key has rate 1.
        For comparison, a representation with a 10 ms stride has a downsample
        rate of 160, because the input wavs are all 16 kHz.
        """
        # ``key`` is not consulted: every output keeps the input's frame rate.
        no_downsampling = 1
        return no_downsampling

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """Run the wrapped model on a batch of variable-length waveforms.

        Args:
            wavs: Waveform tensors of possibly different lengths; shorter
                ones are zero-padded to the longest.

        Returns:
            A dict that maps each task key to the model's output list. When
            a list holds more than one Tensor, those Tensors must share a
            shape so that a weighted sum can be trained over them.
        """
        # Right-pad to the longest waveform -> (batch_size, max_len), then
        # append a feature axis -> (batch_size, max_len, 1).
        padded = pad_sequence(wavs, batch_first=True)
        batch = padded.unsqueeze(-1)
        hidden_states = self.model(batch)

        # "hidden_states" is the key used by default in many cases. The other
        # keys are per-task entries used for the SUPERB Challenge. Every key
        # maps to the very same object.
        output_keys = (
            "hidden_states",
            "PR",
            "SID",
            "ER",
            "ASR",
            "QbE",
            "ASV",
            "SD",
            "ST",
            "SE",
            "SS",
            "secret",
        )
        return {output_key: hidden_states for output_key in output_keys}