File size: 6,948 Bytes
e4bd7f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Author  : Xinhao Mei @CVSSP, University of Surrey
# @E-mail  : [email protected]

"""
    Implemenation of SpecAugment++,
    Adapated from Qiuqiang Kong's trochlibrosa:
    https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
"""

import torch
import torch.nn as nn


class DropStripes:

    def __init__(self, dim, drop_width, stripes_num):
        """Zero out random stripes along one axis of a spectrogram batch.

        Args:
            dim: int, dimension along which to drop (2: time, 3: frequency)
            drop_width: int, maximum width of stripes to drop
            stripes_num: int, how many stripes to drop per example
        """
        super(DropStripes, self).__init__()

        assert dim in [2, 3]  # dim 2: time; dim 3: frequency

        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channels, time_steps, freq_bins).

        Modifies `input` in place and returns it.
        """

        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        for n in range(batch_size):
            self.transform_slice(input[n], total_width)

        return input

    def transform_slice(self, e, total_width):
        """e: (channels, time_steps, freq_bins)"""

        # Nothing to drop; also avoids torch.randint(high<=0) RuntimeError.
        if self.drop_width <= 0 or total_width <= 0:
            return

        # Clamp so `total_width - distance` below stays strictly positive
        # even when drop_width exceeds the axis length.
        max_width = min(self.drop_width, total_width)

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]

            if self.dim == 2:
                e[:, bgn: bgn + distance, :] = 0
            elif self.dim == 3:
                e[:, :, bgn: bgn + distance] = 0


class MixStripes:

    def __init__(self, dim, mix_width, stripes_num):
        """Mix random stripes with another randomly selected sample in mini-batch.

        Args:
            dim: int, dimension along which to mix (2: time, 3: frequency)
            mix_width: int, maximum width of stripes to mix
            stripes_num: int, how many stripes to mix per example
        """

        super(MixStripes, self).__init__()

        assert dim in [2, 3]

        self.dim = dim
        self.mix_width = mix_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channel, time_steps, freq_bins).

        Modifies `input` in place and returns it.
        """

        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        # Pair every example with a randomly permuted partner from the batch.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):

        # Nothing to mix; also avoids torch.randint(high<=0) RuntimeError.
        if self.mix_width <= 0 or total_width <= 0:
            return

        # Clamp so `total_width - distance` below stays strictly positive
        # even when mix_width exceeds the axis length.
        max_width = min(self.mix_width, total_width)

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]

            if self.dim == 2:
                input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \
                                                   0.5 * random_sample[:, bgn: bgn + distance, :]
            elif self.dim == 3:
                input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \
                                                   0.5 * random_sample[:, :, bgn: bgn + distance]


class CutStripes:

    def __init__(self, dim, cut_width, stripes_num):
        """Cut random stripes and paste from another randomly selected sample
        in the mini-batch.

        Args:
            dim: int, dimension along which to cut (2: time, 3: frequency)
            cut_width: int, maximum width of stripes to cut
            stripes_num: int, how many stripes to cut per example
        """

        super(CutStripes, self).__init__()

        assert dim in [2, 3]

        self.dim = dim
        self.cut_width = cut_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channel, time_steps, freq_bins).

        Modifies `input` in place and returns it.
        """

        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        # Pair every example with a randomly permuted partner from the batch.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):

        # Nothing to cut; also avoids torch.randint(high<=0) RuntimeError.
        if self.cut_width <= 0 or total_width <= 0:
            return

        # Clamp so `total_width - distance` below stays strictly positive
        # even when cut_width exceeds the axis length.
        max_width = min(self.cut_width, total_width)

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]

            if self.dim == 2:
                input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :]
            elif self.dim == 3:
                input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance]


class SpecAugmentation:

    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
                 mask_type='mixture'):
        """Spec augmentation and SpecAugment++.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.
        [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
        Data Augmentation Method for Acoustic Scene Classification. arXiv
        preprint arXiv:2103.16858.

        Args:
            time_drop_width: int
            time_stripes_num: int
            freq_drop_width: int
            freq_stripes_num: int
            mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)
        """

        super(SpecAugmentation, self).__init__()

        # Each mask type maps to a factory building the stripe transform for
        # one axis; the width keyword differs per class, hence the lambdas.
        factories = {
            'zero_value': lambda dim, width, num: DropStripes(
                dim=dim, drop_width=width, stripes_num=num),
            'mixture': lambda dim, width, num: MixStripes(
                dim=dim, mix_width=width, stripes_num=num),
            'cutting': lambda dim, width, num: CutStripes(
                dim=dim, cut_width=width, stripes_num=num),
        }

        if mask_type not in factories:
            raise NameError('No such mask type in SpecAugment++')

        build = factories[mask_type]
        self.time_augmentator = build(2, time_drop_width, time_stripes_num)   # dim 2: time
        self.freq_augmentator = build(3, freq_drop_width, freq_stripes_num)   # dim 3: frequency

    def __call__(self, inputs):
        # inputs: (batch_size, channel, time_steps, freq_bins); masked along
        # the time axis first, then the frequency axis.
        masked = self.time_augmentator(inputs)
        masked = self.freq_augmentator(masked)
        return masked