# File size: 13,438 Bytes
# 568e264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import math
from typing import Dict, Optional, Tuple, Union
import torch

from wenet.ssl.bestrq.mask import compute_mask_indices_v2
from wenet.ssl.wav2vec2.quantizer import Wav2vecGumbelVectorQuantizer
from wenet.ssl.wav2vec2.wav2vec2_model import (_compute_contrastive_loss,
                                               _sample_negative_indices)
from wenet.transformer.attention import RelPositionMultiHeadedAttention

from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.utils.mask import make_non_pad_mask


class W2VBERTModel(torch.nn.Module):
    """Wrap a wenet encoder for W2V-BERT style self-supervised training.

    The encoder layers are split in two groups: the output of the first
    ``contrastive_blocks`` layers is used for a contrastive loss against
    Gumbel-quantized targets, and the output of the remaining
    ``masked_blocks`` layers is projected (one projection per codebook) to
    predict the quantizer's code ids on the masked frames (MLM loss).

    Described in:
    https://arxiv.org/pdf/2108.06209v2.pdf
    """

    def __init__(
        self,
        encoder: Union[ConformerEncoder, TransformerEncoder],
        embedding_dim: int = 256,
        num_embeddings: int = 320,
        num_codebooks: int = 1,
        mask_prob: float = 0.065,
        mask_length: int = 10,
        min_masks: int = 2,
        num_negatives: int = 100,
        features_regularization_weight: float = 0.01,
        max_gumbel_temperature: float = 2.0,
        min_gumbel_temperature: float = 0.1,
        gumbel_temperature_decay: float = 0.999995,
        contrastive_logits_temperature: float = 0.1,
        diversity_weight: float = 0.0,
        bias: bool = True,
        contrastive_blocks: int = 6,
        masked_blocks: int = 6,
        contrastive_weight: float = 1.0,
        mlm_weight: float = 1.0,
        warmup_steps: int = 25000,
    ) -> None:
        """ Wrap encoder to train using W2V-BERT's style

        Described in:
        https://arxiv.org/pdf/2108.06209v2.pdf

        Args:
            encoder: wenet's encoder,
                     only support conformer and transformer now
            embedding_dim: codebooks embedding dim
            num_embeddings: numbers of each codebook
            num_codebooks: numbers of codebooks i.e groups of codebook
            mask_prob: probs of mask
            mask_length: spans of masks
            min_masks: min masks for each audio
            num_negatives: numbers of negatives of each masks
            features_regularization_weight: l2 regularization weight
                (0.0 disables the feature penalty)
            max_gumbel_temperature: maximum temperature for gumbel softmax
            min_gumbel_temperature: minimum temperature for gumbel softmax
            gumbel_temperature_decay:
                decay of gumbel temperature during training
            contrastive_logits_temperature:
                the temperature in the contrastive loss.
            diversity_weight: weight of the codebook diversity loss
                (0.0 disables it)
            bias: whether to add a learnable bias to the per-codebook
                output projection of the MLM branch
            contrastive_blocks: number of leading encoder layers whose
                output feeds the contrastive loss
            masked_blocks: number of trailing encoder layers whose output
                feeds the MLM loss; ``contrastive_blocks + masked_blocks``
                must equal the number of encoder layers
            contrastive_weight: weight of the contrastive branch
                (contrastive + diversity + feature penalty) in the final
                loss
            mlm_weight: weight of the MLM loss in the final loss once
                warm-up is over
            warmup_steps: number of steps over which the MLM weight is
                ramped linearly from 10% to 100% of ``mlm_weight``
        """
        super().__init__()
        assert mask_prob > 0.0, "mask_prob must be positive"
        assert (contrastive_blocks > 0 and masked_blocks > 0
                and contrastive_blocks + masked_blocks == len(
                    encoder.encoders)), (
                        "contrastive_blocks and masked_blocks must be "
                        "positive and sum to the number of encoder layers")
        self.contrastive_blocks = contrastive_blocks
        self.masked_blocks = masked_blocks

        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.min_masks = min_masks
        self.num_negatives = num_negatives

        self.features_regularization_weight = features_regularization_weight
        self.diversity_weight = diversity_weight

        self.contrastive_weight = contrastive_weight
        self.mlm_weight = mlm_weight
        self.warmup_steps = warmup_steps
        # encoder
        self.encoder = encoder

        # quantizer
        self.num_codebooks = num_codebooks
        self.quantizer = Wav2vecGumbelVectorQuantizer(
            self.encoder.output_size(),
            num_codebooks=num_codebooks,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            hard=False,
        )
        self.max_gumbel_temp = max_gumbel_temperature
        self.min_gumbel_temp = min_gumbel_temperature
        self.gumbel_temp_decay = gumbel_temperature_decay

        self.num_codevectors_per_group = num_embeddings
        self.num_codevector_groups = num_codebooks

        self.contrastive_logits_temp = contrastive_logits_temperature

        # NOTE(Mddct): mask_emb is replaced by random value in W2V-BERT
        # self.mask_emb = torch.nn.parameter.Parameter(
        #     torch.empty(self.encoder.output_size()).uniform_(),
        #     requires_grad=True,
        # )
        # TODO(Mddct): support causal or lookahead mask or keep consistent with
        # wenet dynamic chunk training

        # One softmax projection per codebook:
        # [num_codebooks, dim, num_embeddings]
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(),
                        num_embeddings))
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
        self.bias = bias
        if bias:
            # [num_codebooks, num_embeddings]; only exists when bias=True
            self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
                torch.empty(num_codebooks, num_embeddings))
            torch.nn.init.zeros_(self.encoder_top_n_out_bias)

        # reset parameter
        self.reset_encoder_parameter()

    def reset_encoder_parameter(self):
        """Re-initialize attention and convolution weights of every encoder
        layer.

        For each layer the q/k/v/out attention projections are reset; the
        relative-position biases are reset when the attention is a
        ``RelPositionMultiHeadedAttention``; and for conformer layers the
        first pointwise conv and the depthwise conv are reset.
        """

        def _reset_parameter(module: Union[torch.nn.Module, torch.Tensor]):
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.trunc_normal_(module.weight.data,
                                            mean=0.0,
                                            std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    # uniform bound derived from the conv fan-in
                    k = math.sqrt(module.groups /
                                  (module.in_channels * module.kernel_size[0]))
                    torch.nn.init.uniform_(module.bias, a=-k, b=k)
            elif isinstance(module, torch.Tensor):
                # raw tensors/parameters, e.g. the positional biases
                torch.nn.init.trunc_normal_(module)
            else:
                raise NotImplementedError("other module not support now")

        for layer in self.encoder.encoders:
            self_attn = layer.self_attn
            for linear in (self_attn.linear_q, self_attn.linear_k,
                           self_attn.linear_v, self_attn.linear_out):
                _reset_parameter(linear)
            if isinstance(self_attn, RelPositionMultiHeadedAttention):
                _reset_parameter(self_attn.pos_bias_u)
                _reset_parameter(self_attn.pos_bias_v)
            if isinstance(layer, ConformerEncoderLayer):
                _reset_parameter(layer.conv_module.pointwise_conv1)
                _reset_parameter(layer.conv_module.depthwise_conv)

    @torch.jit.unused
    def forward(
        self,
        batch: Dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the W2V-BERT training losses for one batch.

        Args:
            batch: dict with keys ``'feats'``, ``'feats_lengths'`` and
                ``'steps'`` (the current training step, required).
            device: device the features are moved to.

        Returns:
            dict with keys:
                ``loss``: weighted sum of the contrastive branch and the
                    MLM loss,
                ``loss_contrastive``: contrastive loss divided by the
                    number of masked frames,
                ``loss_diversity``: diversity loss, or None when
                    ``diversity_weight`` is 0,
                ``loss_mlm``: masked prediction loss,
                ``features_l2``: feature penalty, or None when
                    ``features_regularization_weight`` is 0,
                ``code_ppl``: detached codevector perplexity,
                ``codes_acc``: detached code prediction accuracy on the
                    masked frames.
        """
        steps = batch.get('steps', None)
        xs = batch['feats'].to(device)
        xs_lens = batch['feats_lengths'].to(device)
        assert xs.size(0) == xs_lens.size(0), (
            "feats and feats_lengths must have the same batch size")
        assert steps is not None, "batch must provide 'steps'"

        # 1 forward subsampling
        # NOTE(Mddct): use subsampling as feature extraction
        unmasked_xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
        non_pad_masks = masks.squeeze(1)
        # 2 mask features
        masked_xs, masked_masks = self._apply_mask(unmasked_xs, non_pad_masks)
        # 3 forward encoder blocks
        contrastive_vec, mlm_vec, out_mask = self._forward_encoder_blocks(
            masked_xs, masks, pos_emb, masks)

        # 4 contrastive branch
        # exponentially decayed temperature, floored at the minimum
        gumbel_temperature = max(
            self.max_gumbel_temp * self.gumbel_temp_decay**steps,
            self.min_gumbel_temp)

        # targets come from the *unmasked* features
        quantized_features, codevector_perplexity, targets_ids = self.quantizer(
            unmasked_xs, non_pad_masks, gumbel_temperature)

        sampled_negative_indices = _sample_negative_indices(
            unmasked_xs.size()[:-1], self.num_negatives, masked_masks.device,
            masked_masks)

        loss_contrastive = _compute_contrastive_loss(
            quantized_features, contrastive_vec, sampled_negative_indices,
            masked_masks, self.contrastive_logits_temp, self.num_negatives)
        loss = loss_contrastive

        # scale by sample size
        # make sure that diversity loss is multiplied by `sample_size`
        # since contrastive_loss is `sum`-reduced instead of averaged
        sample_size = masked_masks.sum()
        # higher codevector_perplexity leads to lower diversity loss
        loss_diversity: Optional[torch.Tensor] = None
        if self.diversity_weight != 0.0:
            total_codevectors = (self.num_codevector_groups *
                                 self.num_codevectors_per_group)
            loss_diversity = (total_codevectors -
                              codevector_perplexity) / total_codevectors
            loss_diversity = loss_diversity * sample_size
            loss = loss + self.diversity_weight * loss_diversity
        loss = loss / sample_size

        features_pen: Optional[torch.Tensor] = None
        if self.features_regularization_weight != 0.0:
            features_pen = unmasked_xs.pow(2).mean()
            loss = loss + self.features_regularization_weight * features_pen

        # 5 masked lm branch
        out = mlm_vec.unsqueeze(1)
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0)  # [1, num_codebooks, dim, num_embeddings]
        out = torch.matmul(out,
                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
        if self.bias:
            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)
        num_codes = masked_masks.sum() * self.num_codebooks
        loss_mlm = self._compute_mlm_loss(out,
                                          targets_ids,
                                          mask=out_mask.squeeze(1) *
                                          masked_masks)
        ids_corr = out.argmax(dim=-1,
                              keepdim=False).transpose(1, 2) == targets_ids
        codes_acc = (ids_corr * masked_masks.unsqueeze(2)).sum() / num_codes
        # TODO(Mddct): support num codes used in batch, unique num codes
        # used in batch like bestrq

        # 6 final loss
        # Linear warm-up of the MLM weight from 10% to 100% of the configured
        # `mlm_weight`. The ramp is scaled by `self.mlm_weight` so that a
        # non-default weight is honoured during warm-up and there is no jump
        # at `warmup_steps` (identical to the old schedule for the default
        # mlm_weight=1.0).
        if steps >= self.warmup_steps:
            mlm_weight = self.mlm_weight
        else:
            mlm_weight = self.mlm_weight * (
                0.1 + 0.9 * (steps / self.warmup_steps))
        loss = self.contrastive_weight * loss + mlm_weight * loss_mlm
        return {
            "code_ppl": codevector_perplexity.detach(),
            "features_l2": features_pen,
            "codes_acc": codes_acc.detach(),
            "loss": loss,
            "loss_contrastive": loss_contrastive / sample_size,
            "loss_diversity": loss_diversity,
            "loss_mlm": loss_mlm,
        }

    def _apply_mask(
            self, xs: torch.Tensor,
            xs_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Replace randomly chosen spans of ``xs`` with Gaussian noise.

        Args:
            xs: features, indexed as [B, T, dim]
            xs_masks: non-padding mask for ``xs`` (inverted before being
                passed to the span sampler as the padding mask)

        Returns:
            (masked features, span mask); the span mask is True where a
            frame was replaced.
        """
        masks = compute_mask_indices_v2(xs.size()[:-1],
                                        ~xs_masks,
                                        self.mask_prob,
                                        self.mask_length,
                                        min_masks=self.min_masks,
                                        device=xs.device)
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]

        # Fresh noise on every call instead of a learned mask embedding.
        # Cast to xs.dtype so that torch.where keeps the feature dtype under
        # half/bfloat16 training (no-op for float32).
        mask_emb = torch.normal(mean=0,
                                std=0.1,
                                size=xs.size(),
                                device=xs.device).to(xs.dtype)
        xs = torch.where(masks_expand, mask_emb, xs)

        return xs, masks

    def _compute_mlm_loss(self, input: torch.Tensor, target: torch.Tensor,
                          mask: torch.Tensor) -> torch.Tensor:
        """Masked cross-entropy averaged over masked frames and codebooks.

        Args:
            input: logits, [B, num_codebooks, T', num_embeddings]
            target: code ids, [B, T', num_codebooks]
            mask: [B, T'], non-zero/True on frames that contribute

        Returns:
            scalar loss tensor
        """
        log_probs = torch.log_softmax(input, dim=-1).transpose(
            1, 2)  # [B, T', num_codebooks, num_embeddings]

        # negative log-likelihood of the target code in every codebook
        per_example_n_loss = -log_probs.gather(3, target.unsqueeze(3)).squeeze(
            3)  # [B, T', num_codebooks]

        numerator = torch.sum(per_example_n_loss * mask.unsqueeze(2))
        # epsilon guards against an all-zero mask
        denominator = torch.sum(mask) + 1e-5
        loss = numerator / (denominator * self.num_codebooks)
        return loss

    def _forward_subsampling(
        self, xs: torch.Tensor, xs_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply optional global CMVN and the encoder's embed (subsampling)
        front-end.

        Returns:
            (embedded features, positional embedding, mask) as returned by
            ``encoder.embed``.
        """
        masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, T)
        if self.encoder.global_cmvn is not None:
            xs = self.encoder.global_cmvn(xs)
        xs, pos_emb, masks = self.encoder.embed(xs, masks)
        return xs, pos_emb, masks

    def _forward_encoder_blocks(
        self, xs: torch.Tensor, xs_masks: torch.Tensor, pos_emb: torch.Tensor,
        mask_pad: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the encoder layers and tap the two representations.

        Returns:
            (contrastive_vec, masked_vec, masks): the output of the first
            ``contrastive_blocks`` layers, the output of the last layer
            (after ``encoder.after_norm`` when ``normalize_before`` is set),
            and the mask returned by the last layer.
        """
        masks = xs_masks
        split = self.contrastive_blocks

        # forward contrastive layers get context vector for Contrastive Loss
        for layer in self.encoder.encoders[:split]:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
        contrastive_vec = xs

        # remaining layers produce the context vector for the MLM loss
        for layer in self.encoder.encoders[split:]:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)

        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        masked_vec = xs
        # Here we assume the mask is not changed in encoder layers; the
        # returned mask is combined with the span mask by the caller to
        # select the frames used in the MLM loss.
        return contrastive_vec, masked_vec, masks