tobiasc's picture
Initial commit
ad16788
"""Auxiliary task implementation for transducer models."""
from itertools import chain
from typing import List
from typing import Tuple
from typing import Union
import torch
import torch.nn.functional as F
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface
class AuxiliaryTask(torch.nn.Module):
    """Auxiliary task module for transducer models.

    Each intermediate encoder output is projected by a small MLP and passed,
    together with the decoder output, through the joint network. Depending on
    ``aux_task_type`` the resulting auxiliary joint output contributes to an
    auxiliary transducer loss, a symmetric KL divergence loss against the main
    joint output, or both.
    """

    def __init__(
        self,
        decoder: Union[torch.nn.Module, "TransducerDecoderInterface"],
        joint_network: torch.nn.Module,
        rnnt_criterion: torch.nn.Module,
        aux_task_type: str,
        aux_task_weight: float,
        encoder_out_dim: int,
        joint_dim: int,
    ):
        """Auxiliary task initialization.

        Args:
            decoder: Decoder module
            joint_network: Joint network module
            rnnt_criterion: Transducer loss module, called as
                ``rnnt_criterion(joint_out, target, pred_len, target_len)``
            aux_task_type: Auxiliary task type. ``"default"`` computes only the
                auxiliary transducer loss, ``"symm_kl_div"`` computes only the
                symmetric KL divergence loss, any other value computes both.
            aux_task_weight: Auxiliary task weight (scales both returned losses)
            encoder_out_dim: Encoder output dimension
            joint_dim: Joint space dimension

        """
        super().__init__()

        self.rnnt_criterion = rnnt_criterion

        # Projects an intermediate encoder output into the joint space.
        self.mlp_net = torch.nn.Sequential(
            torch.nn.Linear(encoder_out_dim, joint_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(joint_dim, joint_dim),
        )

        self.decoder = decoder
        self.joint_network = joint_network

        self.aux_task_type = aux_task_type
        self.aux_task_weight = aux_task_weight

    def forward(
        self,
        enc_out_aux: List,
        dec_out: torch.Tensor,
        main_joint: torch.Tensor,
        target: torch.Tensor,
        pred_len: torch.Tensor,
        target_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward auxiliary task.

        Args:
            enc_out_aux: List of encoder intermediate outputs
            dec_out: Decoder outputs
            main_joint: Joint output for main task
            target: Target labels
            pred_len: Prediction lengths
            target_len: Target lengths

        Returns:
            : (Weighted auxiliary transducer loss, Weighted auxiliary symmetric KL loss)
              A loss that was not accumulated (task type excludes it, or
              ``enc_out_aux`` is empty) is returned as ``aux_task_weight * 0``
              rather than as a tensor.

        """
        aux_trans = 0
        aux_symm_kl = 0

        use_trans = self.aux_task_type != "symm_kl_div"
        use_symm_kl = self.aux_task_type != "default"

        # main_joint does not change across auxiliary outputs: normalise it
        # once (on first use) instead of once per auxiliary output.
        main_log_probs = None
        main_probs = None

        # The decoder and joint network parameters are excluded from gradient
        # tracking while the auxiliary joint outputs are computed. Remember
        # each parameter's own flag so that parameters the caller froze
        # beforehand are not unfrozen as a side effect.
        shared_params = list(
            chain(self.decoder.parameters(), self.joint_network.parameters())
        )
        prev_requires_grad = [p.requires_grad for p in shared_params]

        for p in shared_params:
            p.requires_grad = False

        try:
            for enc_aux in enc_out_aux:
                aux_mlp = self.mlp_net(enc_aux)

                aux_joint = self.joint_network(
                    aux_mlp.unsqueeze(2),
                    dec_out.unsqueeze(1),
                    is_aux=True,
                )

                if use_trans:
                    aux_trans += self.rnnt_criterion(
                        aux_joint,
                        target,
                        pred_len,
                        target_len,
                    )

                if use_symm_kl:
                    if main_log_probs is None:
                        main_log_probs = F.log_softmax(main_joint, dim=-1)
                        main_probs = F.softmax(main_joint, dim=-1)

                    # F.kl_div(input, target) expects log-probabilities as
                    # input and probabilities as target; summing both
                    # directions makes the divergence symmetric.
                    aux_symm_kl += F.kl_div(
                        main_log_probs,
                        F.softmax(aux_joint, dim=-1),
                        reduction="mean",
                    ) + F.kl_div(
                        F.log_softmax(aux_joint, dim=-1),
                        main_probs,
                        reduction="mean",
                    )
        finally:
            # Restore the original flags even if the joint network or the
            # criterion raised, so the shared modules are never left frozen.
            for p, flag in zip(shared_params, prev_requires_grad):
                p.requires_grad = flag

        return self.aux_task_weight * aux_trans, self.aux_task_weight * aux_symm_kl