File size: 2,076 Bytes
62e9ca6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

"""
    Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/audio/multi_modality_dataset.py
"""


from typing import Optional

import numpy as np
import torch
from fairseq.data import (
    LanguagePairDataset,
)
from fairseq.data.audio.multi_modality_dataset import LangPairMaskDataset as FairseqLangPairMaskDataset

class LangPairMaskDataset(FairseqLangPairMaskDataset):
    """Language-pair dataset wrapper that overwrites source tokens with a noise id.

    Thin subclass of fairseq's ``LangPairMaskDataset``. It overrides
    ``mask_src_tokens`` so that every key of the incoming sample is kept
    (only ``"source"`` is replaced), and ``collater`` so that
    ``pad_to_length`` is forwarded by keyword.

    The attributes read below (``dataset``, ``src_eos``, ``src_bos``,
    ``noise_id``, ``mask_ratio``, ``mask_type``) are presumably stored by the
    parent constructor -- it is not visible in this file, confirm against
    fairseq.
    """

    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        """Forward all arguments unchanged to the fairseq parent class.

        Args:
            dataset: wrapped dataset; its ``collater`` is reused by this class.
            src_eos: token id that is never masked when it is the last
                source token.
            src_bos: token id that is never masked when it is the first
                source token.
            noise_id: value written into the masked source positions.
            mask_ratio: per-token masking probability for ``"random"``
                masking, otherwise the fraction of trailing tokens to mask.
            mask_type: ``"random"`` selects random masking; any other value
                is handled as tail masking by ``mask_src_tokens``.
        """
        # ``super`` must be *called*: the bare ``super.__init__`` is the
        # unbound initializer of the built-in ``super`` type and raises
        # TypeError when handed a dataset as its receiver.
        super().__init__(
            dataset,
            src_eos,
            src_bos,
            noise_id,
            mask_ratio,
            mask_type,
        )

    def mask_src_tokens(self, sample):
        """Return a copy of ``sample`` whose ``"source"`` has masked positions.

        Args:
            sample: mapping with a ``"source"`` entry (used here as a 1-D
                tensor of token ids); all other entries are passed through.

        Returns:
            A new dict with the same entries as ``sample`` except that
            ``"source"`` holds the masked tensor. ``sample`` itself is not
            modified.
        """
        src_item = sample["source"]
        # Shallow copy so the caller's dict is not overwritten in place.
        masked_sample = dict(sample)

        num_tokens = len(src_item)
        if num_tokens == 0:
            # Nothing to mask; also avoids indexing [0] / [-1] on an empty tensor.
            return masked_sample

        if self.mask_type == "random":
            # Each position is masked independently with probability mask_ratio.
            mask = torch.rand(num_tokens).le(self.mask_ratio)
        else:
            # Tail masking: keep the leading (1 - mask_ratio) share, mask the rest.
            mask = torch.ones(num_tokens, dtype=torch.bool)
            mask[: int(num_tokens * (1 - self.mask_ratio))] = False

        # Never overwrite the sentence-boundary symbols.
        if src_item[0] == self.src_bos:
            mask[0] = False
        if src_item[-1] == self.src_eos:
            mask[-1] = False

        masked_sample["source"] = src_item.masked_fill(mask, self.noise_id)
        return masked_sample

    def collater(self, samples, pad_to_length=None):
        """Collate ``samples`` with the wrapped dataset's collater.

        ``pad_to_length`` is forwarded as a keyword argument.
        """
        return self.dataset.collater(samples, pad_to_length=pad_to_length)