|
""" |
|
RNN models used in Superb Benchmark |
|
|
|
Authors: |
|
* Heng-Jui Chang 2022 |
|
* Leo 2022 |
|
""" |
|
|
|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence |
|
|
|
from s3prl.nn.interface import AbsFrameModel |
|
|
|
__all__ = ["RNNEncoder", "SuperbDiarizationModel", "RNNLayer"] |
|
|
|
|
|
def downsample(
    x: torch.Tensor, x_len: torch.LongTensor, sample_rate: int, sample_style: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Downsample a padded batch of sequences along the time axis.

    Both styles return ``timestamps // sample_rate`` frames, so the time axis
    of the returned tensor always agrees with the floor-divided lengths.

    Args:
        x (torch.Tensor): Sequence (batch, timestamps, hidden_size)
        x_len (torch.LongTensor): Sequence length (batch, )
        sample_rate (int): Downsample rate
        sample_style (str): Downsample style. ``"drop"`` keeps every
            ``sample_rate``-th frame; ``"concat"`` stacks each group of
            ``sample_rate`` consecutive frames along the feature axis.

    Raises:
        NotImplementedError: Sample style not supported.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            x (torch.Tensor): (batch, timestamps // sample_rate, output_size),
                where output_size is hidden_size for "drop" and
                hidden_size * sample_rate for "concat"
            x_len (torch.LongTensor): (batch, ), each length floor-divided by
                sample_rate
    """
    B, T, D = x.shape
    x_len = torch.div(x_len, sample_rate, rounding_mode="floor")
    out_T = T // sample_rate

    if sample_style == "drop":
        # Strided slicing alone yields ceil(T / sample_rate) frames; trim the
        # possible extra trailing frame so the tensor matches the floored
        # lengths (and the "concat" branch).
        x = x[:, ::sample_rate, :][:, :out_T, :].contiguous()
    elif sample_style == "concat":
        # Discard the trailing T % sample_rate frames, then fold every group
        # of sample_rate consecutive frames into the feature axis.
        x = x[:, : out_T * sample_rate, :].contiguous()
        x = x.view(B, out_T, D * sample_rate)
    else:
        raise NotImplementedError(f"Sample style={sample_style} not supported.")

    return x, x_len
|
|
|
|
|
class RNNLayer(nn.Module):
    """Single recurrent layer with optional post-processing.

    The forward pass applies, in order: the RNN, layer normalization,
    dropout, time downsampling and a tanh-activated linear projection. Each
    step after the RNN is only applied when enabled.

    Args:
        input_size (int): Input size.
        hidden_size (int): Hidden size.
        module (str): RNN module (RNN, GRU, LSTM), case-insensitive.
        dropout (float, optional): Dropout rate. Defaults to 0.0.
        bidirectional (bool, optional): Bidirectional. Defaults to False.
        proj (bool, optional): Projection layer. Defaults to False.
        layer_norm (bool, optional): Layer normalization. Defaults to False.
        sample_rate (int, optional): Downsampling rate. Defaults to 1.
        sample_style (str, optional): Downsampling style (**drop** or **concat**). Defaults to "drop".
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        module: str,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj: bool = False,
        layer_norm: bool = False,
        sample_rate: int = 1,
        sample_style: str = "drop",
    ):
        super().__init__()
        self._insize = input_size

        rnn_out_size = hidden_size * (2 if bidirectional else 1)
        # "concat" downsampling stacks ``sample_rate`` consecutive frames on
        # the feature axis, so the feature size grows by ``sample_rate``
        # (not by a fixed factor of 2).
        concat_factor = (
            sample_rate if sample_style == "concat" and sample_rate > 1 else 1
        )
        self.out_size = rnn_out_size * concat_factor

        self.dropout = dropout
        self.proj = proj
        self.layer_norm = layer_norm
        self.sample_rate = sample_rate
        self.sample_style = sample_style

        assert module.upper() in {
            "RNN",
            "GRU",
            "LSTM",
        }, f"Unsupported RNN module: {module}"
        assert sample_style in {
            "drop",
            "concat",
        }, f"Unsupported sample style: {sample_style}"

        self.layer = getattr(nn, module.upper())(
            input_size,
            hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        if self.layer_norm:
            # Layer norm runs before downsampling, i.e. on the raw RNN width.
            self.ln_layer = nn.LayerNorm(rnn_out_size)

        if self.dropout > 0:
            self.dp_layer = nn.Dropout(self.dropout)

        if self.proj:
            # The projection runs after downsampling, i.e. on the final width.
            self.pj_layer = nn.Linear(self.out_size, self.out_size)

    def forward(
        self, xs: torch.Tensor, xs_len: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        Args:
            xs (torch.FloatTensor): (batch_size, seq_len, input_size)
            xs_len (torch.LongTensor): (batch_size, )

        Returns:
            tuple:

            1. ys (torch.FloatTensor): (batch_size, out_len, output_size);
               the time axis comes from ``pad_packed_sequence`` and is then
               reduced by downsampling when ``sample_rate > 1``
            2. ys_len (torch.LongTensor): (batch_size, ); equal to ``xs_len``
               unless downsampling is applied
        """
        if not self.training:
            # Outside training, re-compact the RNN weights
            # (see ``nn.RNNBase.flatten_parameters``).
            self.layer.flatten_parameters()

        # Pack so the RNN skips padded frames; lengths are moved to the CPU
        # for packing and the batch does not need to be sorted by length.
        packed = pack_padded_sequence(
            xs, xs_len.cpu(), batch_first=True, enforce_sorted=False
        )
        output, _ = self.layer(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)

        if self.layer_norm:
            output = self.ln_layer(output)

        if self.dropout > 0:
            output = self.dp_layer(output)

        if self.sample_rate > 1:
            output, xs_len = downsample(
                output, xs_len, self.sample_rate, self.sample_style
            )

        if self.proj:
            output = torch.tanh(self.pj_layer(output))

        return output, xs_len

    @property
    def input_size(self) -> int:
        """Feature size expected on the last axis of the input."""
        return self._insize

    @property
    def output_size(self) -> int:
        """Feature size on the last axis of the output."""
        return self.out_size
|
|
|
|
|
class RNNEncoder(AbsFrameModel):
    """Stacked RNN encoder for sequence to sequence modeling, e.g., ASR.

    The input is first mapped to ``proj_size`` by a linear layer, then passed
    through one ``RNNLayer`` per entry of ``hidden_size``, and finally mapped
    to ``output_size`` by a linear layer.

    The per-layer lists (``dropout``, ``layer_norm``, ``proj``,
    ``sample_rate``) are indexed by layer position, so each needs at least
    ``len(hidden_size)`` entries.

    Args:
        input_size (int): Input size.
        output_size (int): Output size.
        module (str, optional): RNN module type. Defaults to "LSTM".
        proj_size (int, optional): Size of the linear projection applied to
            the input before the first RNN layer. Defaults to 1024.
        hidden_size (List[int], optional): Hidden sizes for each layer. Defaults to [1024].
        dropout (List[float], optional): Dropout rates for each layer. Defaults to [0.0].
        layer_norm (List[bool], optional): Whether to use layer norm for each layer. Defaults to [False].
        proj (List[bool], optional): Whether to use projection for each layer. Defaults to [True].
        sample_rate (List[int], optional): Downsample rates for each layer. Defaults to [1].
        sample_style (str, optional): Downsample style ("drop" or "concat"). Defaults to "drop".
        bidirectional (bool, optional): Whether RNN layers are bidirectional. Defaults to False.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        module: str = "LSTM",
        proj_size: int = 1024,
        hidden_size: List[int] = [1024],
        dropout: List[float] = [0.0],
        layer_norm: List[bool] = [False],
        proj: List[bool] = [True],
        sample_rate: List[int] = [1],
        sample_style: str = "drop",
        bidirectional: bool = False,
    ):
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size

        # Input projection feeding the first recurrent layer.
        self.proj = nn.Linear(input_size, proj_size)

        # Chain the recurrent layers, threading each layer's output width
        # into the next layer's input width.
        self.rnns = nn.ModuleList()
        feat_size = proj_size
        for idx, layer_hidden in enumerate(hidden_size):
            layer = RNNLayer(
                input_size=feat_size,
                hidden_size=layer_hidden,
                module=module,
                dropout=dropout[idx],
                bidirectional=bidirectional,
                proj=proj[idx],
                layer_norm=layer_norm[idx],
                sample_rate=sample_rate[idx],
                sample_style=sample_style,
            )
            self.rnns.append(layer)
            feat_size = layer.output_size

        # Output head on top of the last recurrent layer.
        self.linear = nn.Linear(feat_size, output_size)

    @property
    def input_size(self) -> int:
        """Input feature size given at construction."""
        return self._input_size

    @property
    def output_size(self) -> int:
        """Output feature size given at construction."""
        return self._output_size

    def forward(self, x: torch.Tensor, x_len: torch.LongTensor):
        """
        Args:
            x (torch.FloatTensor): (batch_size, seq_len, input_size)
            x_len (torch.LongTensor): (batch_size, )

        Returns:
            tuple:

            1. ys (torch.FloatTensor): (batch_size, seq_len, output_size)
            2. ys_len (torch.LongTensor): (batch_size, ), the lengths as
               returned by the last RNN layer
        """
        hidden = self.proj(x)
        lengths = x_len

        for layer in self.rnns:
            hidden, lengths = layer(hidden, lengths)

        return self.linear(hidden), lengths
|
|
|
|
|
class SuperbDiarizationModel(AbsFrameModel):
    """
    The exact RNN model used in SUPERB Benchmark for Speaker Diarization

    A unidirectional multi-layer LSTM followed by a linear layer. When
    ``rnn_layers`` is not positive, the LSTM is omitted and the linear layer
    is applied directly to the input.

    Args:
        input_size (int): input_size
        output_size (int): output_size
        rnn_layers (int): number of rnn layers
        hidden_size (int): the hidden size across all rnn layers
    """

    def __init__(
        self, input_size: int, output_size: int, rnn_layers: int, hidden_size: int
    ):
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size
        self.use_rnn = rnn_layers > 0

        # The linear head consumes the LSTM output when an LSTM exists,
        # otherwise the raw input features.
        head_in = input_size
        if self.use_rnn:
            self.rnn = nn.LSTM(
                input_size, hidden_size, num_layers=rnn_layers, batch_first=True
            )
            head_in = hidden_size
        self.linear = nn.Linear(head_in, output_size)

    @property
    def input_size(self) -> int:
        """Input feature size given at construction."""
        return self._input_size

    @property
    def output_size(self) -> int:
        """Output feature size given at construction."""
        return self._output_size

    def forward(self, xs, xs_len):
        """
        Args:
            xs (torch.FloatTensor): (batch_size, seq_len, input_size)
            xs_len (torch.LongTensor): (batch_size, )

        Returns:
            tuple:

            1. ys (torch.FloatTensor): (batch_size, seq_len, output_size)
            2. ys_len (torch.LongTensor): (batch_size, ), ``xs_len`` returned
               unchanged; the lengths are not used to pack or mask the input
        """
        frames = xs.float()
        if self.use_rnn:
            frames, _ = self.rnn(frames)
        return self.linear(frames), xs_len
|
|