|
"""Auxiliary task implementation for transducer models.""" |
|
|
|
from itertools import chain |
|
from typing import List |
|
from typing import Tuple |
|
from typing import Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface |
|
|
|
|
|
class AuxiliaryTask(torch.nn.Module):
    """Auxiliary task module.

    Projects intermediate encoder outputs through a small MLP, feeds them to
    the (temporarily frozen) joint network together with the decoder outputs
    and computes an auxiliary transducer loss and/or a symmetric KL divergence
    loss against the main joint output.

    """

    def __init__(
        self,
        decoder: Union[torch.nn.Module, "TransducerDecoderInterface"],
        joint_network: torch.nn.Module,
        rnnt_criterion: torch.nn.Module,
        aux_task_type: str,
        aux_task_weight: float,
        encoder_out_dim: int,
        joint_dim: int,
    ):
        """Auxiliary task initialization.

        Args:
            decoder: Decoder module
            joint_network: Joint network module
            rnnt_criterion: Transducer loss criterion, called as
                ``rnnt_criterion(joint_out, target, pred_len, target_len)``
            aux_task_type: Auxiliary task type. "symm_kl_div" disables the
                auxiliary transducer loss, "default" disables the symmetric
                KL divergence loss, any other value enables both.
            aux_task_weight: Auxiliary task weight (scales both losses)
            encoder_out_dim: Encoder output dimension
            joint_dim: Joint space dimension

        """
        super().__init__()

        self.rnnt_criterion = rnnt_criterion

        # Maps an intermediate encoder output to the joint space.
        self.mlp_net = torch.nn.Sequential(
            torch.nn.Linear(encoder_out_dim, joint_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(joint_dim, joint_dim),
        )

        self.decoder = decoder
        self.joint_network = joint_network

        self.aux_task_type = aux_task_type
        self.aux_task_weight = aux_task_weight

    def forward(
        self,
        enc_out_aux: List,
        dec_out: torch.Tensor,
        main_joint: torch.Tensor,
        target: torch.Tensor,
        pred_len: torch.Tensor,
        target_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward auxiliary task.

        Args:
            enc_out_aux: List of encoder intermediate outputs
            dec_out: Decoder outputs
            main_joint: Joint output for main task
            target: Target labels
            pred_len: Prediction lengths
            target_len: Target lengths

        Returns:
            : (Weighted auxiliary transducer loss, Weighted auxiliary symmetric KL loss)
              A loss that is disabled by ``aux_task_type`` (or every loss when
              ``enc_out_aux`` is empty) is returned as the plain number 0
              scaled by the weight instead of a tensor.

        """
        aux_trans = 0
        aux_symm_kl = 0

        use_trans = self.aux_task_type != "symm_kl_div"
        use_symm_kl = self.aux_task_type != "default"

        if use_symm_kl:
            # Loop-invariant: the main branch distributions do not depend on
            # the auxiliary output being processed.
            main_log_probs = F.log_softmax(main_joint, dim=-1)
            main_probs = F.softmax(main_joint, dim=-1)

        # The auxiliary losses must not update the decoder and joint network,
        # so their parameters are frozen while the auxiliary graph is built.
        # The previous flags are recorded first so that parameters frozen on
        # purpose by the caller are not unfrozen as a side effect.
        shared_params = list(
            chain(self.decoder.parameters(), self.joint_network.parameters())
        )
        prev_requires_grad = [p.requires_grad for p in shared_params]

        for p in shared_params:
            p.requires_grad = False

        try:
            for enc_aux in enc_out_aux:
                aux_mlp = self.mlp_net(enc_aux)

                # Extra axes are inserted so both inputs can be combined by
                # the joint network.
                aux_joint = self.joint_network(
                    aux_mlp.unsqueeze(2),
                    dec_out.unsqueeze(1),
                    is_aux=True,
                )

                if use_trans:
                    aux_trans += self.rnnt_criterion(
                        aux_joint,
                        target,
                        pred_len,
                        target_len,
                    )

                if use_symm_kl:
                    # NOTE(review): reduction="mean" averages over every
                    # element; PyTorch documents "batchmean" as the reduction
                    # matching the mathematical KL definition. Kept as-is so
                    # existing loss scales are unchanged.
                    aux_symm_kl += F.kl_div(
                        main_log_probs,
                        F.softmax(aux_joint, dim=-1),
                        reduction="mean",
                    ) + F.kl_div(
                        F.log_softmax(aux_joint, dim=-1),
                        main_probs,
                        reduction="mean",
                    )
        finally:
            # Always restore the original flags, even if a sub-module raised.
            for p, flag in zip(shared_params, prev_requires_grad):
                p.requires_grad = flag

        return self.aux_task_weight * aux_trans, self.aux_task_weight * aux_symm_kl
|
|