lmzjms's picture
Upload 1162 files
0b32ad6 verified
"""
Model interfaces
Authors:
* Leo 2022
"""
from typing import List, Tuple
import torch
import torch.nn as nn
# Public names exported by ``from <this module> import *``: the four abstract
# interface classes defined in this module.
__all__ = [
    "AbsUpstream",
    "AbsFeaturizer",
    "AbsFrameModel",
    "AbsUtteranceModel",
]
class AbsUpstream(nn.Module):
    """Interface that every upstream model is expected to subclass.

    Every member defined here raises :class:`NotImplementedError`; a concrete
    subclass overrides the three read-only properties and :meth:`forward`.
    """

    @property
    def num_layer(self) -> int:
        """int: How many hidden states the upstream produces."""
        raise NotImplementedError

    @property
    def hidden_sizes(self) -> List[int]:
        """List[int]: The hidden size of each hidden state."""
        raise NotImplementedError

    @property
    def downsample_rates(self) -> List[int]:
        """List[int]: Per-hidden-state downsample rate, measured from
        16 KHz waveforms."""
        raise NotImplementedError

    def forward(
        self,
        wavs: torch.FloatTensor,
        wavs_len: torch.LongTensor,
    ) -> Tuple[List[torch.FloatTensor], List[torch.LongTensor]]:
        """Turn a batch of waveforms into all of the hidden states.

        Args:
            wavs (torch.FloatTensor): waveforms, (batch_size, seq_len, 1)
            wavs_len (torch.LongTensor): waveform lengths, (batch_size, )

        Returns:
            tuple:

            1. all_hs (List[torch.FloatTensor]): all the hidden states
            2. all_hs_len (List[torch.LongTensor]): the lengths for all the
               hidden states

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
class AbsFeaturizer(nn.Module):
    """Interface that every featurizer is expected to subclass.

    A featurizer reduces (standardizes) the multiple hidden states produced by
    an :obj:`AbsUpstream` into one single hidden state, which the downstream
    model can then treat as a conventional representation.

    Every member defined here raises :class:`NotImplementedError`.
    """

    @property
    def output_size(self) -> int:
        """int: The output size after the hidden states are reduced."""
        raise NotImplementedError

    @property
    def downsample_rate(self) -> int:
        """int: Downsample rate of the reduced single hidden state, measured
        from the 16 KHz waveform."""
        raise NotImplementedError

    def forward(
        self,
        all_hs: List[torch.FloatTensor],
        all_hs_len: List[torch.LongTensor],
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Reduce all the hidden states into a single hidden state.

        Args:
            all_hs (List[torch.FloatTensor]): all the hidden states
            all_hs_len (List[torch.LongTensor]): the lengths for all the
                hidden states

        Returns:
            tuple:

            1. hs (torch.FloatTensor)
            2. hs_len (torch.LongTensor)

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
class AbsFrameModel(nn.Module):
    """Interface for frame-level models.

    Every member defined here raises :class:`NotImplementedError`; a concrete
    subclass overrides both properties and :meth:`forward`.
    """

    @property
    def input_size(self) -> int:
        """int: Input size of the model (presumably the feature dimension of
        ``x`` -- not established by this interface)."""
        raise NotImplementedError

    @property
    def output_size(self) -> int:
        """int: Output size of the model (presumably the feature dimension of
        the returned tensor -- not established by this interface)."""
        raise NotImplementedError

    def forward(
        self,
        x: torch.FloatTensor,
        x_len: torch.LongTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Run the frame-level model.

        Args:
            x (torch.FloatTensor): input tensor; its shape is not specified
                by this interface.
            x_len (torch.LongTensor): presumably the valid length of each
                item in ``x`` -- confirm against concrete subclasses.

        Returns:
            tuple: a float tensor followed by a long tensor, as annotated
            (presumably the outputs and their lengths).

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
class AbsUtteranceModel(nn.Module):
    """Interface for utterance-level models, which pool the temporal dimension.

    Every member defined here raises :class:`NotImplementedError`; a concrete
    subclass overrides both properties and :meth:`forward`.
    """

    @property
    def input_size(self) -> int:
        """int: Input size of the model (presumably the feature dimension of
        ``x`` -- not established by this interface)."""
        raise NotImplementedError

    @property
    def output_size(self) -> int:
        """int: Output size of the model (presumably the feature dimension of
        the returned tensor -- not established by this interface)."""
        raise NotImplementedError

    def forward(
        self,
        x: torch.FloatTensor,
        x_len: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Run the utterance-level model.

        Args:
            x (torch.FloatTensor): input tensor; its shape is not specified
                by this interface.
            x_len (torch.LongTensor): presumably the valid length of each
                item in ``x`` -- confirm against concrete subclasses.

        Returns:
            torch.FloatTensor: a single float tensor, as annotated. Unlike
            the frame-level interface, no lengths are returned.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError