""" NOTE(Mddct): This file is experimental and is used to export paraformer
"""
import math
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint as ckpt
from wenet.paraformer.attention import (DummyMultiHeadSANM,
MultiHeadAttentionCross,
MultiHeadedAttentionSANM)
from wenet.paraformer.embedding import ParaformerPositinoalEncoding
from wenet.paraformer.subsampling import IdentitySubsampling
from wenet.transformer.encoder import BaseEncoder
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import make_non_pad_mask
class LFR(torch.nn.Module):
def __init__(self, m: int = 7, n: int = 6) -> None:
"""
Actually, this implements stacking frames and skipping frames.
if m = 1 and n = 1, just return the origin features.
if m = 1 and n > 1, it works like skipping.
if m > 1 and n = 1, it works like stacking but only support right frames.
if m > 1 and n > 1, it works like LFR.
"""
super().__init__()
self.m = m
self.n = n
self.left_padding_nums = math.ceil((self.m - 1) // 2)
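
    # A minimal shape sketch (illustrative values, assuming the defaults
    # m=7, n=6): a (B, T, D) feature tensor is turned into windows of m
    # frames taken every n frames, i.e. roughly (B, ceil(T / n), D * m),
    # and the returned lengths shrink accordingly.
    #
    #     lfr = LFR(m=7, n=6)
    #     feats = torch.randn(2, 100, 80)            # (B, T, D) fbank
    #     feat_lens = torch.tensor([100, 73])
    #     lfr_feats, lfr_lens = lfr(feats, feat_lens)
    #     # lfr_feats: (2, 17, 560), lfr_lens: tensor([17, 13])
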
def forward(self, input: torch.Tensor,
input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = input_lens.dtype
input_lens = input_lens.to(torch.int64)
B, _, D = input.size()
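        # Pad with copies of the first frame on the left and the last frame on
        # the right so that every output position sees exactly m input frames,
        # then take windows of m frames every n frames with unfold.
        # n_lfr is the number of LFR (output) frames per utterance.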
n_lfr = torch.ceil(input_lens / self.n).to(input_lens.dtype)
# right_padding_nums >= 0
prepad_nums = input_lens + self.left_padding_nums
right_padding_nums = torch.where(
self.m >= (prepad_nums - self.n * (n_lfr - 1)),
self.m - (prepad_nums - self.n * (n_lfr - 1)),
0,
)
T_all = self.left_padding_nums + input_lens + right_padding_nums
new_len = T_all // self.n
T_all_max = T_all.max().int()
tail_frames_index = (input_lens - 1).view(B, 1, 1).repeat(1, 1,
D) # [B,1,D]
tail_frames = torch.gather(input, 1, tail_frames_index)
tail_frames = tail_frames.repeat(1, right_padding_nums.max().int(), 1)
head_frames = input[:, 0:1, :].repeat(1, self.left_padding_nums, 1)
# stack
input = torch.cat([head_frames, input, tail_frames], dim=1)
index = torch.arange(T_all_max,
device=input.device,
dtype=input_lens.dtype).unsqueeze(0).repeat(
B, 1) # [B, T_all_max]
# [B, T_all_max]
index_mask = index < (self.left_padding_nums + input_lens).unsqueeze(1)
tail_index_mask = torch.logical_not(
index >= (T_all.unsqueeze(1))) & index_mask
tail = torch.ones(T_all_max,
dtype=input_lens.dtype,
device=input.device).unsqueeze(0).repeat(B, 1) * (
T_all_max - 1) # [B, T_all_max]
indices = torch.where(torch.logical_or(index_mask, tail_index_mask),
index, tail)
input = torch.gather(input, 1, indices.unsqueeze(2).repeat(1, 1, D))
input = input.unfold(1, self.m, step=self.n).transpose(2, 3)
# new len
        new_len = new_len.to(orig_dtype)
return input.reshape(B, -1, D * self.m), new_len
class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
"""Positionwise feed forward layer.
Args:
idim (int): Input dimenstion.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
idim,
hidden_units,
dropout_rate,
adim=None,
activation=torch.nn.ReLU()):
"""Construct an PositionwiseFeedForward object."""
super(PositionwiseFeedForwardDecoderSANM, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.w_2 = torch.nn.Linear(hidden_units,
idim if adim is None else adim,
bias=False)
self.dropout = torch.nn.Dropout(dropout_rate)
self.activation = activation
self.norm = torch.nn.LayerNorm(hidden_units)
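
    # Differs from wenet's standard PositionwiseFeedForward: a LayerNorm is
    # applied just before the second linear layer, w_2 has no bias, and
    # ``adim`` lets the output dimension differ from ``idim``.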
def forward(self, x):
"""Forward function."""
return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
class AliParaformerEncoderLayer(TransformerEncoderLayer):
def __init__(self,
size: int,
self_attn: torch.nn.Module,
feed_forward: torch.nn.Module,
dropout_rate: float,
normalize_before: bool = True,
in_size: int = 256):
""" Resize input in_size to size
"""
super().__init__(size, self_attn, feed_forward, dropout_rate,
normalize_before)
self.in_size = in_size
self.size = size
del self.norm1
self.norm1 = torch.nn.LayerNorm(in_size)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: Optional[torch.Tensor] = None,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
residual = x
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(
x,
x,
x,
mask,
cache=att_cache,
mask_pad=mask_pad,
)
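        # When in_size != size (only the first encoder layer), the attention
        # output already has the new width, so the residual connection is
        # skipped because the shapes would not match.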
if self.in_size == self.size:
x = residual + self.dropout(x_att)
else:
x = self.dropout(x_att)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
return x, mask, new_att_cache, fake_cnn_cache
class SanmEncoder(BaseEncoder):
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
use_dynamic_left_chunk: bool = False,
kernel_size: int = 11,
sanm_shfit: int = 0,
gradient_checkpointing: bool = False,
):
super().__init__(input_size,
output_size,
attention_heads,
linear_units,
num_blocks,
dropout_rate,
positional_dropout_rate,
attention_dropout_rate,
input_layer,
pos_enc_layer_type,
normalize_before,
static_chunk_size,
use_dynamic_chunk,
global_cmvn,
use_dynamic_left_chunk,
gradient_checkpointing=gradient_checkpointing)
del self.embed
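        # LFR already lowers the frame rate, so BaseEncoder's conv subsampling
        # is replaced by IdentitySubsampling, which keeps the features as-is
        # and only applies Paraformer-style positional encoding.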
self.embed = IdentitySubsampling(
input_size,
output_size,
dropout_rate,
ParaformerPositinoalEncoding(input_size,
output_size,
positional_dropout_rate,
max_len=5000),
)
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
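        # encoders0 holds the single layer that consumes the raw LFR feature
        # dimension (input_size); the remaining num_blocks - 1 layers in
        # encoders operate entirely at output_size.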
self.encoders0 = torch.nn.ModuleList([
AliParaformerEncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
PositionwiseFeedForward(output_size, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
in_size=input_size,
)
])
self.encoders = torch.nn.ModuleList([
AliParaformerEncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
PositionwiseFeedForward(
output_size,
linear_units,
dropout_rate,
),
dropout_rate,
normalize_before,
in_size=output_size) for _ in range(num_blocks - 1)
])
if self.normalize_before:
self.after_norm = torch.nn.LayerNorm(output_size)
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
for layer in self.encoders0:
xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
for layer in self.encoders:
xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs
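
    # Same as forward_layers, but the layers in self.encoders are wrapped in
    # torch.utils.checkpoint so their activations are recomputed during the
    # backward pass (trading compute for memory); encoders0 runs normally.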
@torch.jit.unused
def forward_layers_checkpointed(self, xs: torch.Tensor,
chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
for layer in self.encoders0:
xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
for layer in self.encoders:
xs, _, _, _ = ckpt.checkpoint(layer.__call__,
xs,
chunk_masks,
pos_emb,
mask_pad,
use_reentrant=False)
return xs
class _Decoders3(torch.nn.Module):
"""Paraformer has a decoder3"""
def __init__(self, hidden: int, pos_clss: torch.nn.Module) -> None:
super().__init__()
self.feed_forward = pos_clss
self.norm1 = torch.nn.LayerNorm(hidden)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.feed_forward(self.norm1(x))
class SanmDecoderLayer(DecoderLayer):
def __init__(self,
size: int,
self_attn: Optional[torch.nn.Module],
src_attn: Optional[torch.nn.Module],
feed_forward: torch.nn.Module,
dropout_rate: float,
normalize_before: bool = True):
super().__init__(size, self_attn, src_attn, feed_forward, dropout_rate,
normalize_before)
        # NOTE(Mddct): ali-Paraformer needs eps=1e-12
self.norm1 = torch.nn.LayerNorm(size, eps=1e-12)
self.norm2 = torch.nn.LayerNorm(size, eps=1e-12)
self.norm3 = torch.nn.LayerNorm(size, eps=1e-12)
def forward(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
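        # Unlike the base DecoderLayer (self-attention, cross-attention, then
        # feed-forward), this layer applies the feed-forward module first,
        # then the SANM self-attention, and finally the cross-attention over
        # the encoder output.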
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = tgt_mask[:, -1:, :]
x = tgt
if self.self_attn is not None:
if self.normalize_before:
tgt = self.norm2(tgt)
tgt_q = tgt
x = self.self_attn(tgt_q,
tgt,
tgt,
tgt_q_mask,
mask_pad=tgt_q_mask)[0]
x = residual + self.dropout(x)
if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(
self.src_attn(
x, memory, memory, memory_mask, mask_pad=memory_mask)[0])
return x, tgt_mask, memory, memory_mask
class SanmDecoder(TransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0,
src_attention_dropout_rate: float = 0,
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
src_attention: bool = True,
att_layer_num: int = 16,
kernel_size: int = 11,
sanm_shfit: int = 0,
gradient_checkpointing: bool = False,
):
super().__init__(vocab_size,
encoder_output_size,
attention_heads,
linear_units,
num_blocks,
dropout_rate,
positional_dropout_rate,
self_attention_dropout_rate,
src_attention_dropout_rate,
input_layer,
use_output_layer,
normalize_before,
src_attention,
gradient_checkpointing=gradient_checkpointing)
del self.embed, self.decoders
self.decoders = torch.nn.ModuleList([
SanmDecoderLayer(
encoder_output_size,
DummyMultiHeadSANM(attention_heads, encoder_output_size,
encoder_output_size, dropout_rate,
kernel_size, sanm_shfit),
MultiHeadAttentionCross(attention_heads, encoder_output_size,
encoder_output_size, dropout_rate,
kernel_size, sanm_shfit,
encoder_output_size),
PositionwiseFeedForwardDecoderSANM(encoder_output_size,
linear_units, dropout_rate),
dropout_rate,
normalize_before,
) for _ in range(att_layer_num)
])
        # NOTE(Mddct): att_layer_num == num_blocks in the released Paraformer model
assert att_layer_num == num_blocks
        # NOTE(Mddct): Paraformer has a decoder3
self.decoders3 = torch.nn.ModuleList([
_Decoders3(
encoder_output_size,
PositionwiseFeedForwardDecoderSANM(encoder_output_size,
linear_units, dropout_rate))
])
def forward(
self,
encoder_out: torch.Tensor,
encoder_out_mask: torch.Tensor,
sematic_embeds: torch.Tensor,
ys_pad_lens: torch.Tensor,
r_ys_in_pad: torch.Tensor = torch.empty(0),
reverse_weight: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
ys_pad_mask = make_non_pad_mask(ys_pad_lens).unsqueeze(1)
x = sematic_embeds
if self.gradient_checkpointing and self.training:
x = self.forward_layers_checkpointed(x, ys_pad_mask, encoder_out,
encoder_out_mask)
else:
x = self.forward_layers(x, ys_pad_mask, encoder_out,
encoder_out_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
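        # The second return value is a placeholder for the right-to-left
        # decoder output, keeping the return signature compatible with
        # wenet's TransformerDecoder.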
return x, torch.tensor(0.0), ys_pad_lens
def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor) -> torch.Tensor:
for layer in self.decoders:
x, _, _, _ = layer(x, tgt_mask, memory, memory_mask)
for layer in self.decoders3:
x = layer(x)
return x
@torch.jit.unused
def forward_layers_checkpointed(self, x: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor) -> torch.Tensor:
for i, layer in enumerate(self.decoders):
if i == 0:
x, _, _, _ = layer(x, tgt_mask, memory, memory_mask)
else:
x, _, _, _ = ckpt.checkpoint(layer.__call__,
x,
tgt_mask,
memory,
memory_mask,
use_reentrant=False)
for layer in self.decoders3:
x = layer(x)
return x
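

# A minimal end-to-end sketch of how these modules fit together. The values
# below are illustrative (not the released Paraformer configuration), and the
# predictor that produces ``sematic_embeds`` lives outside this file:
#
#     lfr = LFR(m=7, n=6)
#     encoder = SanmEncoder(input_size=80 * 7, output_size=256, num_blocks=6)
#     decoder = SanmDecoder(vocab_size=5000, encoder_output_size=256,
#                           num_blocks=6, att_layer_num=6)
#
#     feats = torch.randn(2, 100, 80)              # (B, T, 80) fbank
#     feat_lens = torch.tensor([100, 73])
#     lfr_feats, lfr_lens = lfr(feats, feat_lens)  # (B, T', 80 * 7)
#     encoder_out, encoder_mask = encoder(lfr_feats, lfr_lens)
#
#     # sematic_embeds: (B, L, 256) acoustic embeddings from the predictor,
#     # ys_pad_lens: (B,) target lengths
#     logits, _, _ = decoder(encoder_out, encoder_mask,
#                            sematic_embeds, ys_pad_lens)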