import argparse
import logging
import sys
from typing import Dict, Optional, Tuple

import torch

from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM
from vita.model.multimodal_encoder.whale.module.component.subsampling import Subsampling
from vita.model.multimodal_encoder.whale.module.component.transformer import Transformer
from vita.model.multimodal_encoder.whale.utils import make_pad_mask


def add_encoder_args(group):
    """Add Encoder common arguments."""
    group.add_argument(
        "--encoder-layer-config",
        type=str,
        default="tdnn-dtc",
        help="Layer config of encoder, formatted as layername-layername-... (default: tdnn-dtc)",
    )
    group.add_argument(
        "--encoder-input-dim",
        type=int,
        default=256,
        help="Input dim of encoder. Must equal the input dim of the first component (default: 256)",
    )
    group.add_argument(
        "--encoder-output-dim",
        type=int,
        default=256,
        help="Output dim of encoder. Must equal the output dim of the last component (default: 256)",
    )
    # Add args of all component types.
    # If you add a new component, do not forget to register its args here as well.
    group = Transformer.add_arguments(group)
    group = Subsampling.add_arguments(group)
    group = MambaSSM.add_arguments(group)
    return group
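
# Illustrative usage sketch (mirrors how whaleEncoder.__init__ below calls this
# helper); parsing an empty argv here is an assumption made so that the
# defaults declared above are returned:
#
#   parser = argparse.ArgumentParser()
#   add_encoder_args(parser)
#   args, _ = parser.parse_known_args([])
#   # args.encoder_layer_config == "tdnn-dtc", args.encoder_input_dim == 256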


def assign_args_from_dict(args, conf_dict, prefix_key=None):
    """Copy matching keys from a config dict onto existing args attributes (hyphens become underscores)."""
    if prefix_key is not None:
        conf_dict = conf_dict[prefix_key]
    for k, v in conf_dict.items():
        k_args = k.replace("-", "_")
        if hasattr(args, k_args):
            setattr(args, k_args, v)
    return args
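
# Illustrative sketch: the helper copies matching keys from a config dict onto
# an existing argparse namespace, translating hyphens to underscores; keys the
# namespace does not already have are ignored. The dict and prefix below are
# made up for the example.
#
#   args, _ = argparse.ArgumentParser().parse_known_args([])
#   args.encoder_input_dim = 256
#   assign_args_from_dict(args, {"frontend": {"encoder-input-dim": 80}}, prefix_key="frontend")
#   # args.encoder_input_dim is now 80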


class whaleEncoder(torch.nn.Module):
    def __init__(self, input_dim, overview_conf=None, para_conf=None, global_cmvn=None):
        super(whaleEncoder, self).__init__()

        parser = argparse.ArgumentParser()
        add_encoder_args(parser)
        args, _ = parser.parse_known_args()

        assign_args_from_dict(args, overview_conf)
        # assign_args_from_dict(args, para_conf)

        self.config = args.encoder_layer_config.split("-")
        encoder_input_dim = args.encoder_input_dim
        encoder_output_dim = args.encoder_output_dim
        prev_output_dim = encoder_input_dim
        prev_component_name = "encoder"
        self.enc = torch.nn.ModuleList([])
        for name in self.config:
            assign_args_from_dict(args, para_conf[name])
            # A config entry may carry an underscore suffix (e.g. "transformer_1");
            # strip it to recover the component type.
            name_parts = name.split("_")
            if len(name_parts) <= 2:
                name = name_parts[0]
            else:
                logging.error("WRONG CONFIG! encoder layer name {} is not valid".format(name))
                sys.exit(1)

            if name == "transformer":
                self.enc.append(Transformer(args))
            elif name == "subsampling":
                self.enc.append(Subsampling(args))
            elif name == "mamba":
                self.enc.append(MambaSSM(args))
            else:
                raise NotImplementedError("encoder component {} is not supported yet".format(name))
            # Each component's input dim must match the previous component's output dim.
            component_input_dim = getattr(args, name + "_input_dim")
            if component_input_dim != prev_output_dim:
                logging.error(
                    "WRONG CONFIG! --{}-output-dim ({}) does not equal --{}-input-dim ({})".format(
                        prev_component_name, prev_output_dim, name, component_input_dim
                    )
                )
                sys.exit(1)
            prev_output_dim = getattr(args, name + "_output_dim")
            prev_component_name = name

        self.global_cmvn = global_cmvn
        if prev_output_dim != encoder_output_dim:
            logging.error(
                "WRONG CONFIG! --encoder-output-dim ({}) does not equal --{}-output-dim ({}) of the last component".format(
                    encoder_output_dim, name, prev_output_dim
                )
            )
            sys.exit(1)

        self._output_size = encoder_output_dim

        num_params = sum(p.numel() for p in self.parameters())
        print("the number of whale encoder params: {:.2f}M".format(num_params / 1e6))

    def output_size(self) -> int:
        return self._output_size

    @torch.jit.unused
    def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
        # type: (Tensor, Tensor, Optional[int], Optional[int]) -> Tuple[Tensor, Tensor]
        """Encoder forward.

        :param torch.Tensor xs: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param int decoding_chunk_size: optional chunk size for streaming decoding
        :param int num_decoding_left_chunks: optional number of left chunks for streaming decoding
        :return: batch of hidden state sequences (B, Tmax', encoder_output_dim) and the matching pad masks (B, 1, Tmax')
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """

        if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
            for layer in self.enc:
                if hasattr(layer, "chunk_size"):
                    layer.chunk_size = decoding_chunk_size
                if hasattr(layer, "left_chunks"):
                    layer.left_chunks = num_decoding_left_chunks
                if hasattr(layer, "transformer_dynamic_chunks"):
                    layer.transformer_dynamic_chunks = False

        assert xs.dim() == 3, "expected input of shape (B, Tmax, D)"
        T = xs.size(1)
        masks = ~make_pad_mask(ilens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        for module in self.enc:
            xs, ilens, masks = module(xs, ilens, masks)
        return xs, masks
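
    # Illustrative call sketch (shapes only; the sizes below are made up and
    # `encoder` is assumed to be a constructed whaleEncoder instance):
    #
    #   xs = torch.randn(2, 100, 256)              # (B, Tmax, D)
    #   ilens = torch.tensor([100, 80])            # (B,)
    #   hidden, masks = encoder(xs, ilens)
    #   # hidden: (B, Tmax', encoder_output_dim), masks: (B, 1, Tmax'),
    #   # where Tmax' reflects any subsampling done by the components.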

    @torch.jit.export
    def infer(self, xs_pad, buffer, buffer_index, buffer_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
        for module in self.enc:
            xs_pad, buffer, buffer_index, buffer_out = module.infer(
                xs_pad, buffer, buffer_index, buffer_out
            )
        return xs_pad, buffer, buffer_index, buffer_out

    @torch.jit.export
    def infer_hidden(self, xs_pad, buffer, buffer_index, buffer_out, hidden_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
        for module in self.enc:
            xs_pad, buffer, buffer_index, buffer_out, hidden_out = module.infer_hidden(
                xs_pad, buffer, buffer_index, buffer_out, hidden_out
            )
        return xs_pad, buffer, buffer_index, buffer_out, hidden_out

    @torch.jit.ignore(drop=True)
    def get_extra_loss(self) -> Optional[Dict[str, torch.Tensor]]:
        return None
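
# Illustrative end-to-end construction sketch. The overview_conf keys mirror the
# encoder-level arguments defined in add_encoder_args; the per-component keys in
# para_conf are assumptions (each component's real keys come from its own
# add_arguments), shown only for the *_input_dim/*_output_dim values this file
# reads back via getattr:
#
#   overview_conf = {
#       "encoder-layer-config": "subsampling-transformer",
#       "encoder-input-dim": 80,
#       "encoder-output-dim": 256,
#   }
#   para_conf = {
#       "subsampling": {"subsampling-input-dim": 80, "subsampling-output-dim": 256},
#       "transformer": {"transformer-input-dim": 256, "transformer-output-dim": 256},
#   }
#   encoder = whaleEncoder(80, overview_conf=overview_conf, para_conf=para_conf)
#   feats = torch.randn(4, 200, 80)                   # (B, Tmax, D)
#   feat_lens = torch.tensor([200, 180, 160, 150])    # (B,)
#   hidden, masks = encoder(feats, feat_lens)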