File size: 3,437 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Auxiliary task implementation for transducer models."""

from itertools import chain
from typing import List
from typing import Tuple
from typing import Union

import torch
import torch.nn.functional as F

from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface


class AuxiliaryTask(torch.nn.Module):
    """Auxiliary task module for transducer models.

    Each intermediate encoder output is projected by a small MLP, combined
    with the decoder output through the joint network, and scored with an
    auxiliary transducer loss and/or a symmetric KL divergence against the
    main joint output.

    """

    def __init__(
        self,
        decoder: Union[torch.nn.Module, TransducerDecoderInterface],
        joint_network: torch.nn.Module,
        rnnt_criterion: torch.nn.Module,
        aux_task_type: str,
        aux_task_weight: float,
        encoder_out_dim: int,
        joint_dim: int,
    ):
        """Auxiliary task initialization.

        Args:
            decoder: Decoder module
            joint_network: Joint network module
            rnnt_criterion: Transducer loss module, called as
                ``rnnt_criterion(joint_out, target, pred_len, target_len)``
            aux_task_type: Auxiliary task type. "default" computes only the
                transducer loss, "symm_kl_div" computes only the symmetric
                KL divergence, any other value computes both.
            aux_task_weight: Auxiliary task weight (scales both returned losses)
            encoder_out_dim: Encoder output dimension
            joint_dim: Joint space dimension

        """
        super().__init__()

        self.rnnt_criterion = rnnt_criterion

        # Projects intermediate encoder outputs into the joint space.
        self.mlp_net = torch.nn.Sequential(
            torch.nn.Linear(encoder_out_dim, joint_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(joint_dim, joint_dim),
        )

        self.decoder = decoder
        self.joint_network = joint_network

        self.aux_task_type = aux_task_type
        self.aux_task_weight = aux_task_weight

    def forward(
        self,
        enc_out_aux: List,
        dec_out: torch.Tensor,
        main_joint: torch.Tensor,
        target: torch.Tensor,
        pred_len: torch.Tensor,
        target_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward auxiliary task.

        Decoder and joint network parameters are frozen while the auxiliary
        outputs are computed. Their previous ``requires_grad`` state is
        restored afterwards, even if the computation raises.

        Args:
            enc_out_aux: List of encoder intermediate outputs
            dec_out: Decoder outputs
            main_joint: Joint output for main task
            target: Target labels
            pred_len: Prediction lengths
            target_len: Target lengths

        Returns:
            : (Weighted auxiliary transducer loss, Weighted auxiliary symmetric KL loss)
              A loss that is disabled by ``aux_task_type`` (or every loss when
              ``enc_out_aux`` is empty) is the number 0 scaled by the weight.

        """
        aux_trans = 0
        aux_symm_kl = 0

        use_trans = self.aux_task_type != "symm_kl_div"
        use_symm_kl = self.aux_task_type != "default"

        # Remember every flag before touching any of them, so parameters that
        # the caller froze on purpose are not unfrozen on the way out.
        shared_params = list(
            chain(self.decoder.parameters(), self.joint_network.parameters())
        )
        prev_requires_grad = [p.requires_grad for p in shared_params]

        for p in shared_params:
            p.requires_grad = False

        try:
            # Loop-invariant terms, computed once instead of per aux output.
            # NOTE(review): the unsqueeze pattern suggests (B, T, D) encoder
            # and (B, U, D) decoder tensors broadcast to (B, T, U, D) - confirm.
            dec_out_exp = dec_out.unsqueeze(1)

            if use_symm_kl:
                main_log_prob = F.log_softmax(main_joint, dim=-1)
                main_prob = F.softmax(main_joint, dim=-1)

            for enc_aux in enc_out_aux:
                aux_mlp = self.mlp_net(enc_aux)

                aux_joint = self.joint_network(
                    aux_mlp.unsqueeze(2),
                    dec_out_exp,
                    is_aux=True,
                )

                if use_trans:
                    aux_trans += self.rnnt_criterion(
                        aux_joint,
                        target,
                        pred_len,
                        target_len,
                    )

                if use_symm_kl:
                    # NOTE(review): reduction="mean" averages over all elements
                    # (PyTorch recommends "batchmean" for a true KL value). It
                    # is kept as-is so the loss scale does not change.
                    aux_symm_kl += F.kl_div(
                        main_log_prob,
                        F.softmax(aux_joint, dim=-1),
                        reduction="mean",
                    ) + F.kl_div(
                        F.log_softmax(aux_joint, dim=-1),
                        main_prob,
                        reduction="mean",
                    )
        finally:
            for p, flag in zip(shared_params, prev_requires_grad):
                p.requires_grad = flag

        return self.aux_task_weight * aux_trans, self.aux_task_weight * aux_symm_kl