# OSUM: wenet/transducer/joint.py
from typing import Optional

import torch
from torch import nn

from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES
class TransducerJoint(torch.nn.Module):
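    """Joint network of a neural transducer (RNN-T).

    Combines encoder frames and prediction-network states into logits for
    every (frame, label) pair. With ``prejoin_linear`` both inputs are first
    projected to ``join_dim``; ``postjoin_linear`` adds one more linear layer
    after the additive join. With ``hat_joint`` the output follows the HAT
    (Hybrid Autoregressive Transducer) factorization: a dedicated blank
    log-probability plus label log-probabilities scaled by log(1 - P(blank)).
    """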
def __init__(self,
vocab_size: int,
enc_output_size: int,
pred_output_size: int,
join_dim: int,
prejoin_linear: bool = True,
postjoin_linear: bool = False,
joint_mode: str = 'add',
activation: str = "tanh",
hat_joint: bool = False,
dropout_rate: float = 0.1,
hat_activation: str = 'tanh'):
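        """Construct the joint network.

        Args:
            vocab_size: output vocabulary size, including <blank>.
            enc_output_size: feature dimension of the encoder output.
            pred_output_size: feature dimension of the prediction network
                output.
            join_dim: dimension in which the two streams are joined.
            prejoin_linear: project both streams to ``join_dim`` before
                joining.
            postjoin_linear: apply an extra ``join_dim -> join_dim`` linear
                layer after joining.
            joint_mode: how the streams are combined; only 'add' is supported.
            activation: activation applied before the final output projection.
            hat_joint: use the HAT-style blank/label factorized output heads.
            dropout_rate: dropout used inside the HAT heads.
            hat_activation: activation used in the HAT token head.
        """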
# TODO(Mddct): concat in future
assert joint_mode in ['add']
super().__init__()
        self.activation = WENET_ACTIVATION_CLASSES[activation]()
self.prejoin_linear = prejoin_linear
self.postjoin_linear = postjoin_linear
self.joint_mode = joint_mode
if not self.prejoin_linear and not self.postjoin_linear:
assert enc_output_size == pred_output_size == join_dim
# torchscript compatibility
self.enc_ffn: Optional[nn.Linear] = None
self.pred_ffn: Optional[nn.Linear] = None
if self.prejoin_linear:
self.enc_ffn = nn.Linear(enc_output_size, join_dim)
self.pred_ffn = nn.Linear(pred_output_size, join_dim)
# torchscript compatibility
self.post_ffn: Optional[nn.Linear] = None
if self.postjoin_linear:
self.post_ffn = nn.Linear(join_dim, join_dim)
        # NOTE: vocab_size includes <blank>
self.hat_joint = hat_joint
self.vocab_size = vocab_size
self.ffn_out: Optional[torch.nn.Linear] = None
if not self.hat_joint:
self.ffn_out = nn.Linear(join_dim, vocab_size)
self.blank_pred: Optional[torch.nn.Module] = None
self.token_pred: Optional[torch.nn.Module] = None
if self.hat_joint:
self.blank_pred = torch.nn.Sequential(
torch.nn.Tanh(), torch.nn.Dropout(dropout_rate),
torch.nn.Linear(join_dim, 1), torch.nn.LogSigmoid())
self.token_pred = torch.nn.Sequential(
WENET_ACTIVATION_CLASSES[hat_activation](),
torch.nn.Dropout(dropout_rate),
torch.nn.Linear(join_dim, self.vocab_size - 1))

    def forward(self,
enc_out: torch.Tensor,
pred_out: torch.Tensor,
pre_project: bool = True) -> torch.Tensor:
"""
Args:
enc_out (torch.Tensor): [B, T, E]
pred_out (torch.Tensor): [B, T, P]
Return:
[B,T,U,V]
"""
if (pre_project and self.prejoin_linear and self.enc_ffn is not None
and self.pred_ffn is not None):
enc_out = self.enc_ffn(enc_out) # [B,T,E] -> [B,T,D]
pred_out = self.pred_ffn(pred_out)
if enc_out.ndim != 4:
enc_out = enc_out.unsqueeze(2) # [B,T,D] -> [B,T,1,D]
if pred_out.ndim != 4:
pred_out = pred_out.unsqueeze(1) # [B,U,D] -> [B,1,U,D]
# TODO(Mddct): concat joint
_ = self.joint_mode
        out = enc_out + pred_out  # [B,T,1,D] + [B,1,U,D] -> [B,T,U,D]
if self.postjoin_linear and self.post_ffn is not None:
out = self.post_ffn(out)
if not self.hat_joint and self.ffn_out is not None:
            out = self.activation(out)
out = self.ffn_out(out)
return out
else:
assert self.blank_pred is not None
assert self.token_pred is not None
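            # HAT factorization: the blank head emits log P(blank) through a
            # LogSigmoid; the remaining mass 1 - P(blank) (clamped for
            # numerical stability) scales the non-blank label distribution.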
blank_logp = self.blank_pred(out) # [B,T,U,1]
            # probability mass left for non-blank labels: 1 - P(blank)
scale_logp = torch.clamp(1 - torch.exp(blank_logp), min=1e-6)
label_logp = self.token_pred(out).log_softmax(
dim=-1) # [B,T,U,vocab-1]
# scale token logp
label_logp = torch.log(scale_logp) + label_logp
out = torch.cat((blank_logp, label_logp), dim=-1) # [B,T,U,vocab]
return out
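

if __name__ == "__main__":
    # Minimal usage sketch (not part of the upstream module). The sizes below
    # are arbitrary assumptions chosen only to illustrate the expected shapes.
    joint = TransducerJoint(vocab_size=100,
                            enc_output_size=256,
                            pred_output_size=128,
                            join_dim=256)
    enc_out = torch.randn(2, 50, 256)   # [B, T, E] encoder frames
    pred_out = torch.randn(2, 10, 128)  # [B, U, P] prediction network states
    logits = joint(enc_out, pred_out)   # [B, T, U, vocab_size]
    print(logits.shape)                 # torch.Size([2, 50, 10, 100])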