File size: 4,067 Bytes
568e264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Optional

import torch
from torch import nn
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES


class TransducerJoint(torch.nn.Module):
    """Joint network that fuses encoder and predictor outputs.

    The encoder output ``[B, T, E]`` and the predictor output ``[B, U, P]``
    are optionally projected to ``join_dim``, broadcast-added into a
    ``[B, T, U, D]`` lattice, optionally projected once more, and finally
    mapped to ``vocab_size`` scores for every ``(t, u)`` position.

    Two output heads are available:

    * ``hat_joint=False``: ``activation`` followed by a single linear layer;
      the result is raw (unnormalized) logits.
    * ``hat_joint=True``: a blank head (Tanh -> Dropout -> Linear ->
      LogSigmoid) yields the blank log-probability and a token head yields a
      log-softmax over the ``vocab_size - 1`` remaining tokens, shifted by
      ``log(1 - p_blank)``. The result is log-probabilities with blank at
      index 0 of the last dimension. ``activation`` is not used in this mode.

    Args:
        vocab_size: Output vocabulary size, ``<blank>`` included.
        enc_output_size: Feature size ``E`` of the encoder output.
        pred_output_size: Feature size ``P`` of the predictor output.
        join_dim: Feature size ``D`` of the joint space.
        prejoin_linear: Create linear projections ``E -> D`` and ``P -> D``
            that are applied before the addition.
        postjoin_linear: Create a linear layer ``D -> D`` that is applied
            after the addition.
        joint_mode: How the two streams are combined; only ``'add'`` is
            implemented.
        activation: Key into ``WENET_ACTIVATION_CLASSES`` selecting the
            activation used by the non-hat output head.
        hat_joint: Select the blank/token two-head output described above.
        dropout_rate: Dropout probability inside the two hat heads.
        hat_activation: Key into ``WENET_ACTIVATION_CLASSES`` selecting the
            activation of the hat token head.

    Raises:
        ValueError: If ``joint_mode`` is not supported, or if neither
            projection is enabled while ``enc_output_size``,
            ``pred_output_size`` and ``join_dim`` are not all equal.
    """

    def __init__(self,
                 vocab_size: int,
                 enc_output_size: int,
                 pred_output_size: int,
                 join_dim: int,
                 prejoin_linear: bool = True,
                 postjoin_linear: bool = False,
                 joint_mode: str = 'add',
                 activation: str = "tanh",
                 hat_joint: bool = False,
                 dropout_rate: float = 0.1,
                 hat_activation: str = 'tanh'):
        # Validate the configuration with real exceptions (asserts vanish
        # under ``python -O``) and do it before any layer is built.
        # TODO(Mddct): concat in future
        if joint_mode not in ('add', ):
            raise ValueError(
                f"unsupported joint_mode {joint_mode!r}; only 'add' is "
                "implemented")
        if (not prejoin_linear and not postjoin_linear
                and not enc_output_size == pred_output_size == join_dim):
            # With no projection at all, the inputs are added as they are
            # and fed straight to the output head, so all sizes must agree.
            raise ValueError(
                "without prejoin_linear/postjoin_linear, enc_output_size "
                f"({enc_output_size}), pred_output_size ({pred_output_size}) "
                f"and join_dim ({join_dim}) must be equal")
        super().__init__()

        # NOTE: the attribute name is misspelled; it is kept unchanged so
        # that code reaching into this attribute from outside keeps working.
        self.activatoin = WENET_ACTIVATION_CLASSES[activation]()
        self.prejoin_linear = prejoin_linear
        self.postjoin_linear = postjoin_linear
        self.joint_mode = joint_mode

        # Optional sub-modules are first declared as None so that the
        # attribute always exists (torchscript compatibility).
        self.enc_ffn: Optional[nn.Linear] = None
        self.pred_ffn: Optional[nn.Linear] = None
        if self.prejoin_linear:
            self.enc_ffn = nn.Linear(enc_output_size, join_dim)
            self.pred_ffn = nn.Linear(pred_output_size, join_dim)
        # torchscript compatibility
        self.post_ffn: Optional[nn.Linear] = None
        if self.postjoin_linear:
            self.post_ffn = nn.Linear(join_dim, join_dim)

        # NOTE: <blank> in vocab_size
        self.hat_joint = hat_joint
        self.vocab_size = vocab_size
        # Single output layer, only used by the non-hat head.
        self.ffn_out: Optional[torch.nn.Linear] = None
        if not self.hat_joint:
            self.ffn_out = nn.Linear(join_dim, vocab_size)

        # Blank / token heads, only used by the hat head.
        self.blank_pred: Optional[torch.nn.Module] = None
        self.token_pred: Optional[torch.nn.Module] = None
        if self.hat_joint:
            # One scalar per lattice position: log(sigmoid(.)) = log p_blank.
            self.blank_pred = torch.nn.Sequential(
                torch.nn.Tanh(), torch.nn.Dropout(dropout_rate),
                torch.nn.Linear(join_dim, 1), torch.nn.LogSigmoid())
            # Scores for every token except blank, hence vocab_size - 1.
            self.token_pred = torch.nn.Sequential(
                WENET_ACTIVATION_CLASSES[hat_activation](),
                torch.nn.Dropout(dropout_rate),
                torch.nn.Linear(join_dim, self.vocab_size - 1))

    def forward(self,
                enc_out: torch.Tensor,
                pred_out: torch.Tensor,
                pre_project: bool = True) -> torch.Tensor:
        """Combine encoder and predictor outputs into per-token scores.

        Args:
            enc_out (torch.Tensor): [B, T, E]; a 4-dim tensor is used as is
                (no axis is inserted).
            pred_out (torch.Tensor): [B, U, P]; a 4-dim tensor is used as is
                (no axis is inserted).
            pre_project (bool): apply the pre-join projections when the
                module has them; pass False if the inputs were already
                projected to the joint dimension.
        Return:
            [B, T, U, V]: raw logits for the non-hat head, or
            log-probabilities with blank at index 0 for the hat head.
        """
        if (pre_project and self.prejoin_linear and self.enc_ffn is not None
                and self.pred_ffn is not None):
            enc_out = self.enc_ffn(enc_out)  # [B,T,E] -> [B,T,D]
            pred_out = self.pred_ffn(pred_out)  # [B,U,P] -> [B,U,D]
        if enc_out.ndim != 4:
            enc_out = enc_out.unsqueeze(2)  # [B,T,D] -> [B,T,1,D]
        if pred_out.ndim != 4:
            pred_out = pred_out.unsqueeze(1)  # [B,U,D] -> [B,1,U,D]

        # TODO(Mddct): concat joint
        _ = self.joint_mode
        # Broadcasting over the singleton axes builds the full lattice.
        out = enc_out + pred_out  # [B,T,U,D]

        if self.postjoin_linear and self.post_ffn is not None:
            out = self.post_ffn(out)

        if not self.hat_joint and self.ffn_out is not None:
            out = self.activatoin(out)
            out = self.ffn_out(out)  # [B,T,U,D] -> [B,T,U,vocab]
            return out
        else:
            assert self.blank_pred is not None
            assert self.token_pred is not None
            blank_logp = self.blank_pred(out)  # [B,T,U,1]

            # scale blank logp: probability mass left for non-blank tokens,
            # clamped away from zero so the log below stays finite
            scale_logp = torch.clamp(1 - torch.exp(blank_logp), min=1e-6)
            label_logp = self.token_pred(out).log_softmax(
                dim=-1)  # [B,T,U,vocab-1]
            # scale token logp
            label_logp = torch.log(scale_logp) + label_logp

            out = torch.cat((blank_logp, label_logp), dim=-1)  # [B,T,U,vocab]
            return out