#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : [email protected]
"""
Implementation of SpecAugment++,
Adapted from Qiuqiang Kong's torchlibrosa:
https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
"""
import torch
import torch.nn as nn
class DropStripes:
    def __init__(self, dim, drop_width, stripes_num):
        """Randomly zero out stripes along one spectrogram dimension (SpecAugment).

        args:
            dim: int, dimension along which to drop (2: time, 3: frequency)
            drop_width: int, maximum width of each stripe to drop
            stripes_num: int, how many stripes to drop per example
        """
        super(DropStripes, self).__init__()
        assert dim in [2, 3]  # dim 2: time; dim 3: frequency
        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channels, time_steps, freq_bins)

        Returns the same tensor, modified in place.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        for n in range(batch_size):
            self.transform_slice(input[n], total_width)
        return input

    def transform_slice(self, e, total_width):
        """ e: (channels, time_steps, freq_bins); zeroes stripes in place."""
        # Clamp so a sampled stripe always fits inside the axis. Without this,
        # drop_width >= total_width lets distance reach total_width, making
        # torch.randint(high=total_width - distance) fail with high <= 0;
        # drop_width == 0 would likewise crash randint directly.
        max_width = min(self.drop_width, total_width)
        if max_width <= 0:
            return
        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
            if self.dim == 2:
                e[:, bgn: bgn + distance, :] = 0
            elif self.dim == 3:
                e[:, :, bgn: bgn + distance] = 0
class MixStripes:
    def __init__(self, dim, mix_width, stripes_num):
        """Mix stripes with a randomly chosen partner sample from the mini-batch.

        args:
            dim: int, dimension along which to mix (2: time, 3: frequency)
            mix_width: int, maximum width of each stripe to mix
            stripes_num: int, how many stripes to mix per example
        """
        super(MixStripes, self).__init__()
        assert dim in [2, 3]
        self.dim = dim
        self.mix_width = mix_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channel, time_steps, freq_bins)

        Returns the same tensor, modified in place.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        # Advanced indexing copies, so partners are a snapshot of the batch
        # taken before any in-place mixing below.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """Average stripes of `input` with `random_sample` (50/50), in place."""
        # Clamp so a sampled stripe always fits inside the axis. Without this,
        # mix_width >= total_width lets distance reach total_width, making
        # torch.randint(high=total_width - distance) fail with high <= 0.
        max_width = min(self.mix_width, total_width)
        if max_width <= 0:
            return
        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
            if self.dim == 2:
                input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \
                    0.5 * random_sample[:, bgn: bgn + distance, :]
            elif self.dim == 3:
                input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \
                    0.5 * random_sample[:, :, bgn: bgn + distance]
class CutStripes:
    def __init__(self, dim, cut_width, stripes_num):
        """ Cutting stripes with another randomly selected sample in mini-batch.

        args:
            dim: int, dimension along which to cut (2: time, 3: frequency)
            cut_width: int, maximum width of each stripe to cut
            stripes_num: int, how many stripes to cut per example
        """
        super(CutStripes, self).__init__()
        assert dim in [2, 3]
        self.dim = dim
        self.cut_width = cut_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channel, time_steps, freq_bins)

        Returns the same tensor, modified in place.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        # Advanced indexing copies, so partners are a snapshot of the batch
        # taken before any in-place cutting below.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """Overwrite stripes of `input` with stripes of `random_sample`, in place."""
        # Clamp so a sampled stripe always fits inside the axis. Without this,
        # cut_width >= total_width lets distance reach total_width, making
        # torch.randint(high=total_width - distance) fail with high <= 0.
        max_width = min(self.cut_width, total_width)
        if max_width <= 0:
            return
        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
            if self.dim == 2:
                input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :]
            elif self.dim == 3:
                input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance]
class SpecAugmentation:
    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
                 mask_type='mixture'):
        """Spec augmentation and SpecAugment++.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.

        [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
        Data Augmentation Method for Acoustic Scene Classification. arXiv
        preprint arXiv:2103.16858.

        Args:
            time_drop_width: int, maximum width of time-axis stripes
            time_stripes_num: int, stripes per example on the time axis
            freq_drop_width: int, maximum width of frequency-axis stripes
            freq_stripes_num: int, stripes per example on the frequency axis
            mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)

        Raises:
            NameError: if mask_type is not one of the supported values.
        """
        super(SpecAugmentation, self).__init__()
        # Dispatch table: every mask type builds the same pair of augmentators,
        # only the stripe class differs (each takes dim, width, stripes_num).
        stripe_classes = {
            'zero_value': DropStripes,
            'mixture': MixStripes,
            'cutting': CutStripes,
        }
        if mask_type not in stripe_classes:
            raise NameError('No such mask type in SpecAugment++')
        stripe_cls = stripe_classes[mask_type]
        self.time_augmentator = stripe_cls(2, time_drop_width, time_stripes_num)
        self.freq_augmentator = stripe_cls(3, freq_drop_width, freq_stripes_num)

    def __call__(self, inputs):
        """inputs: (batch_size, channel, time_steps, freq_bins); time mask, then freq mask."""
        return self.freq_augmentator(self.time_augmentator(inputs))