from typing import List, Optional, Tuple

import torch
from torch import nn

from modules.wenet_extractor.utils.common import get_activation, get_rnn


def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
    """Keep `pad_value` at padded positions and `input` elsewhere.

    Args:
        input: [bs, max_time_step, dim]
        padding: [bs, max_time_step], 1 marks a padded position
    """
    return padding * pad_value + input * (1 - padding)
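
# Illustrative example (not from the original source): for a single-step
# state update with padding = [[1], [0]], the first row keeps `pad_value`
# (e.g. the previous RNN state) and the second row takes the new `input`:
#   ApplyPadding(torch.tensor([[2.0], [2.0]]),
#                torch.tensor([[1.0], [0.0]]),
#                torch.tensor([[5.0], [5.0]]))  # -> [[5.0], [2.0]]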


class PredictorBase(torch.nn.Module):
    # NOTE(Mddct): We could use an ABC abstract base here, but
    # keep this class simple enough for now.
    def __init__(self) -> None:
        super().__init__()

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        _, _, _ = batch_size, method, device
        raise NotImplementedError("this is a base predictor")

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        _ = cache
        raise NotImplementedError("this is a base predictor")

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        _ = cache
        raise NotImplementedError("this is a base predictor")

    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ):
        _, _ = input, cache
        raise NotImplementedError("this is a base predictor")

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        _, _, _ = input, padding, cache
        raise NotImplementedError("this is a base predictor")
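
# How a transducer decoder is expected to drive a predictor (a sketch
# inferred from the method signatures above, shown here for orientation):
#   cache = predictor.init_state(batch_size, device)             # fresh state
#   out, cache = predictor.forward_step(token, padding, cache)   # one token
#   per_hyp = predictor.batch_to_cache(cache)    # split state per hypothesis
#   cache = predictor.cache_to_batch(per_hyp)    # merge hypotheses back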


class RNNPredictor(PredictorBase):
    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        output_size: int,
        embed_dropout: float,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        rnn_type: str = "lstm",
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_layers = num_layers
        self.hidden_size = hidden_size
        # the RNN's built-in output projection is disabled; self.projection
        # below is used instead
        self.embed = nn.Embedding(voca_size, embed_size)
        self.dropout = nn.Dropout(embed_dropout)
        # NOTE(Mddct): torch's built-in RNNs do not support layer norm;
        # layer norm and pruning of cell and layer values will be added later,
        # ref: https://github.com/Mddct/neural-lm/blob/main/models/gru_cell.py
        self.rnn = get_rnn(rnn_type=rnn_type)(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=True,
            dropout=dropout,
        )
        self.projection = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): [batch, max_time]
            cache: rnn predictor cache[0] == state_m
                   cache[1] == state_c
        Returns:
            output: [batch, max_time, output_size]
        """
        # NOTE(Mddct): we don't use the packed-sequence input format
        embed = self.embed(input)  # [batch, max_time, emb_size]
        embed = self.dropout(embed)
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        if cache is None:
            state = self.init_state(batch_size=input.size(0), device=input.device)
            states = (state[0], state[1])
        else:
            assert len(cache) == 2
            states = (cache[0], cache[1])
        out, (m, c) = self.rnn(embed, states)
        out = self.projection(out)
        # NOTE(Mddct): although the state is unused in transducer training,
        # it must be handled correctly at padded positions, so we provide
        # forward_step for inference and forward for training.
        _, _ = m, c
        return out

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Args:
            cache: [state_m, state_c]
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        Returns:
            new_cache: [[state_m_1, state_c_1], [state_m_2, state_c_2], ...]
        """
        assert len(cache) == 2
        state_ms = cache[0]
        state_cs = cache[1]
        assert state_ms.size(1) == state_cs.size(1)
        new_cache: List[List[torch.Tensor]] = []
        for state_m, state_c in zip(
            torch.split(state_ms, 1, dim=1), torch.split(state_cs, 1, dim=1)
        ):
            new_cache.append([state_m, state_c])
        return new_cache

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """
        Args:
            cache: [[state_m_1, state_c_1], [state_m_2, state_c_2], ...]
        Returns:
            new_cache: [state_ms, state_cs],
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        """
        state_ms = torch.cat([states[0] for states in cache], dim=1)
        state_cs = torch.cat([states[1] for states in cache], dim=1)
        return [state_ms, state_cs]

    def init_state(
        self,
        batch_size: int,
        device: torch.device,
        method: str = "zero",
    ) -> List[torch.Tensor]:
        assert batch_size > 0
        # TODO(Mddct): xavier init method
        _ = method
        return [
            torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
            torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
        ]

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], where 1 marks padding
            cache: rnn predictor cache[0] == state_m
                   cache[1] == state_c
        """
        assert len(cache) == 2
        state_m, state_c = cache[0], cache[1]
        embed = self.embed(input)  # [batch, 1, emb_size]
        embed = self.dropout(embed)
        out, (m, c) = self.rnn(embed, (state_m, state_c))
        out = self.projection(out)
        # keep the previous state at padded positions
        m = ApplyPadding(m, padding.unsqueeze(0), state_m)
        c = ApplyPadding(c, padding.unsqueeze(0), state_c)
        return (out, [m, c])
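
# Minimal streaming sketch for RNNPredictor (hyperparameters are made up
# for illustration; not from the original source):
#   predictor = RNNPredictor(voca_size=100, embed_size=16, output_size=32,
#                            embed_dropout=0.1, hidden_size=64, num_layers=2)
#   cache = predictor.init_state(batch_size=1, device=torch.device("cpu"))
#   token = torch.zeros(1, 1, dtype=torch.long)   # last emitted label
#   padding = torch.zeros(1, 1)                   # no padded positions
#   out, cache = predictor.forward_step(token, padding, cache)
#   # out: [1, 1, 32]; cache: updated [state_m, state_c]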


class EmbeddingPredictor(PredictorBase):
    """Embedding predictor.

    Described in:
    https://arxiv.org/pdf/2109.07513.pdf

    embed -> proj -> layer norm -> swish
    """

    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        embed_dropout: float,
        n_head: int,
        history_size: int = 2,
        activation: str = "swish",
        bias: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        # multi head
        self.num_heads = n_head
        self.embed_size = embed_size
        self.context_size = history_size + 1
        self.pos_embed = torch.nn.Linear(
            embed_size * self.context_size, self.num_heads, bias=bias
        )
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        self.ffn = nn.Linear(self.embed_size, self.embed_size)
        self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
        self.activation = get_activation(activation)

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        assert batch_size > 0
        _ = method
        return [
            torch.zeros(
                batch_size, self.context_size - 1, self.embed_size, device=device
            ),
        ]

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Args:
            cache: [history]
                history: [bs, ...]
        Returns:
            new_cache: [[history_1], [history_2], [history_3], ...]
        """
        assert len(cache) == 1
        cache_0 = cache[0]
        history: List[List[torch.Tensor]] = []
        for h in torch.split(cache_0, 1, dim=0):
            history.append([h])
        return history

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """
        Args:
            cache: [[history_1], [history_2], [history_3], ...]
        Returns:
            new_cache: [history],
                history: [bs, ...]
        """
        history = torch.cat([h[0] for h in cache], dim=0)
        return [history]

    def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
        """forward for training"""
        input = self.embed(input)  # [bs, seq_len, embed]
        input = self.embed_dropout(input)
        if cache is None:
            zeros = self.init_state(input.size(0), device=input.device)[0]
        else:
            assert len(cache) == 1
            zeros = cache[0]
        input = torch.cat(
            (zeros, input), dim=1
        )  # [bs, context_size-1 + seq_len, embed]
        input = input.unfold(1, self.context_size, 1).permute(
            0, 1, 3, 2
        )  # [bs, seq_len, context_size, embed]
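        # What follows is a position-wise attention over each sliding window:
        # each head h holds one learned vector per context position (the rows
        # of self.pos_embed.weight reshaped below); the score for position c
        # is the dot product of that vector with the embedding at c, and each
        # head's output is the score-weighted sum of the window's embeddings.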
        # multi head pos: [n_head, embed, context_size]
        multi_head_pos = self.pos_embed.weight.view(
            self.num_heads, self.embed_size, self.context_size
        )
        # broadcast dot attention
        input_expand = input.unsqueeze(2)  # [bs, seq_len, 1, context_size, embed]
        multi_head_pos = multi_head_pos.permute(
            0, 2, 1
        )  # [num_heads, context_size, embed]
        # [bs, seq_len, num_heads, context_size, embed]
        weight = input_expand * multi_head_pos
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3
        )  # [bs, seq_len, num_heads, 1, context_size]
        output = weight.matmul(input_expand).squeeze(
            dim=3
        )  # [bs, seq_len, num_heads, embed]
        output = output.sum(dim=2)  # [bs, seq_len, embed]
        output = output / (self.num_heads * self.context_size)
        output = self.ffn(output)
        output = self.norm(output)
        output = self.activation(output)
        return output

    def forward_step(
        self,
        input: torch.Tensor,
        padding: torch.Tensor,
        cache: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """forward step for inference
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], where 1 marks padding
            cache: for embedding predictor, cache[0] == history
        """
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        input = self.embed(input)  # [bs, 1, embed]
        input = self.embed_dropout(input)
        context_input = torch.cat((history, input), dim=1)
        input_expand = context_input.unsqueeze(1).unsqueeze(
            2
        )  # [bs, 1, 1, context_size, embed]
        # multi head pos: [n_head, embed, context_size]
        multi_head_pos = self.pos_embed.weight.view(
            self.num_heads, self.embed_size, self.context_size
        )
        multi_head_pos = multi_head_pos.permute(
            0, 2, 1
        )  # [num_heads, context_size, embed]
        # [bs, 1, num_heads, context_size, embed]
        weight = input_expand * multi_head_pos
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3
        )  # [bs, 1, num_heads, 1, context_size]
        output = weight.matmul(input_expand).squeeze(dim=3)  # [bs, 1, num_heads, embed]
        output = output.sum(dim=2)  # [bs, 1, embed]
        output = output / (self.num_heads * self.context_size)
        output = self.ffn(output)
        output = self.norm(output)
        output = self.activation(output)
        new_cache = context_input[:, 1:, :]
        # TODO(Mddct): we need to pad new_cache in the future
        # new_cache = ApplyPadding(history, padding, new_cache)
        return (output, [new_cache])
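
# Shape sketch for EmbeddingPredictor.forward (hyperparameters are made up
# for illustration; not from the original source):
#   predictor = EmbeddingPredictor(voca_size=100, embed_size=16,
#                                  embed_dropout=0.1, n_head=4)
#   tokens = torch.randint(0, 100, (2, 7))   # [bs=2, seq_len=7]
#   out = predictor(tokens)                  # [2, 7, 16]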


class ConvPredictor(PredictorBase):
    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        embed_dropout: float,
        history_size: int = 2,
        activation: str = "relu",
        bias: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        assert history_size >= 0
        self.embed_size = embed_size
        self.context_size = history_size + 1
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        self.conv = nn.Conv1d(
            in_channels=embed_size,
            out_channels=embed_size,
            kernel_size=self.context_size,
            padding=0,
            groups=embed_size,
            bias=bias,
        )
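        # With groups == embed_size the convolution is depthwise: each
        # embedding channel is filtered independently over the last
        # context_size tokens, and because the history is prepended in
        # forward(), the valid (padding=0) convolution keeps seq_len intact.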
        self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
        self.activation = get_activation(activation)

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        assert batch_size > 0
        assert method == "zero"
        return [
            torch.zeros(
                batch_size, self.context_size - 1, self.embed_size, device=device
            )
        ]

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """
        Args:
            cache: [[history_1], [history_2], [history_3], ...]
        Returns:
            new_cache: [history],
                history: [bs, ...]
        """
        history = torch.cat([h[0] for h in cache], dim=0)
        return [history]

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Args:
            cache: [history]
                history: [bs, ...]
        Returns:
            new_cache: [[history_1], [history_2], [history_3], ...]
        """
        assert len(cache) == 1
        cache_0 = cache[0]
        history: List[List[torch.Tensor]] = []
        for h in torch.split(cache_0, 1, dim=0):
            history.append([h])
        return history

    def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
        """forward for training"""
        input = self.embed(input)  # [bs, seq_len, embed]
        input = self.embed_dropout(input)
        if cache is None:
            zeros = self.init_state(input.size(0), device=input.device)[0]
        else:
            assert len(cache) == 1
            zeros = cache[0]
        input = torch.cat(
            (zeros, input), dim=1
        )  # [bs, context_size-1 + seq_len, embed]
        input = input.permute(0, 2, 1)  # [bs, embed, context_size-1 + seq_len]
        out = self.conv(input).permute(0, 2, 1)  # [bs, seq_len, embed]
        out = self.activation(self.norm(out))
        return out

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """forward step for inference
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], where 1 marks padding
            cache: for conv predictor, cache[0] == history
        """
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        input = self.embed(input)  # [bs, 1, embed]
        input = self.embed_dropout(input)
        context_input = torch.cat((history, input), dim=1)
        input = context_input.permute(0, 2, 1)  # [bs, embed, context_size]
        out = self.conv(input).permute(0, 2, 1)  # [bs, 1, embed]
        out = self.activation(self.norm(out))
        new_cache = context_input[:, 1:, :]
        # TODO(Mddct): apply padding in the future
        return (out, [new_cache])
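
# Shape sketch for ConvPredictor.forward_step (hyperparameters are made up
# for illustration; not from the original source):
#   predictor = ConvPredictor(voca_size=100, embed_size=16, embed_dropout=0.1)
#   cache = predictor.init_state(batch_size=2, device=torch.device("cpu"))
#   token = torch.zeros(2, 1, dtype=torch.long)
#   padding = torch.zeros(2, 1)
#   out, cache = predictor.forward_step(token, padding, cache)
#   # out: [2, 1, 16]; cache[0]: [2, context_size - 1, 16]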