from typing import Optional

import torch
from torch import nn

from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES


class TransducerJoint(torch.nn.Module):

    def __init__(self,
                 vocab_size: int,
                 enc_output_size: int,
                 pred_output_size: int,
                 join_dim: int,
                 prejoin_linear: bool = True,
                 postjoin_linear: bool = False,
                 joint_mode: str = 'add',
                 activation: str = "tanh",
                 hat_joint: bool = False,
                 dropout_rate: float = 0.1,
                 hat_activation: str = 'tanh'):
        # TODO(Mddct): concat in future
        assert joint_mode in ['add']
        super().__init__()

        self.activation = WENET_ACTIVATION_CLASSES[activation]()
        self.prejoin_linear = prejoin_linear
        self.postjoin_linear = postjoin_linear
        self.joint_mode = joint_mode

        if not self.prejoin_linear and not self.postjoin_linear:
            assert enc_output_size == pred_output_size == join_dim
        # torchscript compatibility
        self.enc_ffn: Optional[nn.Linear] = None
        self.pred_ffn: Optional[nn.Linear] = None
        if self.prejoin_linear:
            self.enc_ffn = nn.Linear(enc_output_size, join_dim)
            self.pred_ffn = nn.Linear(pred_output_size, join_dim)
        # torchscript compatibility
        self.post_ffn: Optional[nn.Linear] = None
        if self.postjoin_linear:
            self.post_ffn = nn.Linear(join_dim, join_dim)

        # NOTE: <blank> is included in vocab_size
        self.hat_joint = hat_joint
        self.vocab_size = vocab_size
        self.ffn_out: Optional[torch.nn.Linear] = None
        if not self.hat_joint:
            self.ffn_out = nn.Linear(join_dim, vocab_size)

        # HAT (Hybrid Autoregressive Transducer) joint: blank and token
        # distributions are predicted by separate heads.
        self.blank_pred: Optional[torch.nn.Module] = None
        self.token_pred: Optional[torch.nn.Module] = None
        if self.hat_joint:
            self.blank_pred = torch.nn.Sequential(
                torch.nn.Tanh(), torch.nn.Dropout(dropout_rate),
                torch.nn.Linear(join_dim, 1), torch.nn.LogSigmoid())
            self.token_pred = torch.nn.Sequential(
                WENET_ACTIVATION_CLASSES[hat_activation](),
                torch.nn.Dropout(dropout_rate),
                torch.nn.Linear(join_dim, self.vocab_size - 1))

    def forward(self,
                enc_out: torch.Tensor,
                pred_out: torch.Tensor,
                pre_project: bool = True) -> torch.Tensor:
        """
        Args:
            enc_out (torch.Tensor): [B, T, E]
            pred_out (torch.Tensor): [B, U, P]
        Return:
            [B, T, U, V]; raw logits when ``hat_joint`` is False,
            log-probabilities when ``hat_joint`` is True.
        """
        if (pre_project and self.prejoin_linear and self.enc_ffn is not None
                and self.pred_ffn is not None):
            enc_out = self.enc_ffn(enc_out)  # [B,T,E] -> [B,T,D]
            pred_out = self.pred_ffn(pred_out)  # [B,U,P] -> [B,U,D]
        if enc_out.ndim != 4:
            enc_out = enc_out.unsqueeze(2)  # [B,T,D] -> [B,T,1,D]
        if pred_out.ndim != 4:
            pred_out = pred_out.unsqueeze(1)  # [B,U,D] -> [B,1,U,D]

        # TODO(Mddct): concat joint
        _ = self.joint_mode
        out = enc_out + pred_out  # broadcast add -> [B,T,U,D]

        if self.postjoin_linear and self.post_ffn is not None:
            out = self.post_ffn(out)

        if not self.hat_joint and self.ffn_out is not None:
            out = self.activation(out)
            out = self.ffn_out(out)  # [B,T,U,D] -> [B,T,U,V]
            return out
        else:
            assert self.blank_pred is not None
            assert self.token_pred is not None
            blank_logp = self.blank_pred(out)  # [B,T,U,1]

            # scale for token logp: log(1 - p_blank), clamped for stability
            scale_logp = torch.clamp(1 - torch.exp(blank_logp), min=1e-6)
            label_logp = self.token_pred(out).log_softmax(
                dim=-1)  # [B,T,U,vocab-1]
            # scale token logp
            label_logp = torch.log(scale_logp) + label_logp

            out = torch.cat((blank_logp, label_logp), dim=-1)  # [B,T,U,vocab]
            return out
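

# Minimal usage sketch (illustrative only, not part of the wenet test suite;
# batch/time/label sizes below are arbitrary assumptions).
if __name__ == "__main__":
    B, T, U, E, P, D, V = 2, 5, 3, 8, 8, 16, 10
    joint = TransducerJoint(vocab_size=V,
                            enc_output_size=E,
                            pred_output_size=P,
                            join_dim=D)
    enc_out = torch.randn(B, T, E)  # encoder output
    pred_out = torch.randn(B, U, P)  # predictor output
    logits = joint(enc_out, pred_out)
    # With prejoin_linear=True (default) and hat_joint=False, the joint
    # projects both inputs to join_dim, broadcasts over (T, U) and emits
    # per-frame, per-label vocabulary logits.
    assert logits.shape == (B, T, U, V)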