|
"""Spec Augment module for preprocessing i.e., data augmentation""" |
|
|
|
import random |
|
|
|
import numpy |
|
from PIL import Image |
|
from PIL.Image import BICUBIC |
|
|
|
from espnet.transform.functional import FuncTrans |
|
|
|
|
|
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
    """Time warp for SpecAugment.

    In "PIL" mode a random center frame is chosen at least ``max_time_warp``
    frames away from both ends, and a warped position is drawn within
    ``max_time_warp`` frames of it.  The frames before the center are resized
    to end at the warped position and the remaining frames are resized to
    fill the rest, so the total number of frames is preserved.

    :param numpy.ndarray x: spectrogram (time, freq)
    :param int max_time_warp: maximum time frames to warp; 0 disables warping
    :param bool inplace: overwrite x with the result ("PIL" mode only; the
        "sparse_image_warp" branch does not read this flag)
    :param str mode: "PIL" (default, fast, not differentiable) or
        "sparse_image_warp" (slow, differentiable)
    :returns numpy.ndarray: time warped spectrogram (time, freq); ``x`` itself
        is returned untouched when ``max_time_warp`` is 0 or, in "PIL" mode,
        when ``x`` has no more than ``2 * max_time_warp`` frames
    :raises NotImplementedError: if ``mode`` is not one of the supported modes
    """
    # Validate the mode up front so a bad mode is reported even when the
    # warp itself would be skipped.  The f-string keeps the message identical
    # for str modes and avoids a TypeError for non-str ones.
    if mode not in ("PIL", "sparse_image_warp"):
        raise NotImplementedError(
            f"unknown resize mode: {mode}, choose one from (PIL, sparse_image_warp)."
        )

    window = max_time_warp
    # A zero window means "no warp"; without this guard the PIL branch would
    # call random.randrange(center, center), an empty range (ValueError).
    if window == 0:
        return x

    if mode == "PIL":
        t = x.shape[0]
        # Too short to place a center frame `window` away from both ends.
        if t - window <= window:
            return x

        # randrange(a, b) yields a, a + 1, ..., b - 1
        center = random.randrange(window, t - window)
        # warped lies in [center - window + 1, center + window]
        warped = random.randrange(center - window, center + window) + 1

        # PIL sizes are (width, height) == (freq, time).
        n_freq = x.shape[1]
        left = Image.fromarray(x[:center]).resize((n_freq, warped), BICUBIC)
        right = Image.fromarray(x[center:]).resize((n_freq, t - warped), BICUBIC)
        if inplace:
            x[:warped] = left
            x[warped:] = right
            return x
        return numpy.concatenate((left, right), 0)

    # mode == "sparse_image_warp": imported lazily so torch is only required
    # when this mode is actually requested.
    import torch

    from espnet.utils import spec_augment

    return spec_augment.time_warp(torch.from_numpy(x), window).numpy()
|
|
|
|
|
class TimeWarp(FuncTrans):
    # Transform wrapper around ``time_warp``.
    # NOTE(review): FuncTrans lives in espnet.transform.functional and is not
    # visible here; presumably its __call__ invokes ``_func`` -- confirm.
    _func = time_warp
    # Reuse the wrapped function's docstring as the class documentation.
    __doc__ = time_warp.__doc__

    def __call__(self, x, train):
        """Apply the transform in training mode, pass ``x`` through otherwise.

        :param x: input handed to ``FuncTrans.__call__``
        :param train: when falsy, ``x`` is returned untouched
        :returns: ``x`` itself, or the result of ``FuncTrans.__call__(x)``
        """
        if train:
            return super().__call__(x)
        return x
|
|
|
|
|
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
    """Frequency mask for SpecAugment.

    For each of the ``n_mask`` masks, a band of consecutive frequency bins is
    overwritten across all time frames.

    :param numpy.ndarray x: (time, freq)
    :param int F: exclusive upper bound of the random draws that control the
        mask position and width; 0 disables masking
    :param int n_mask: the number of masks
    :param bool replace_with_zero: pad zero on mask if true else use mean
    :param bool inplace: overwrite ``x`` instead of working on a copy
    :returns numpy.ndarray: the masked spectrogram (``x`` itself if inplace)
    """
    cloned = x if inplace else x.copy()

    # F == 0 means "no masking"; numpy.random.randint(0, 0) would raise.
    if F == 0:
        return cloned

    num_mel_channels = cloned.shape[1]
    # Two independent draws from [0, F) per mask.
    # NOTE(review): the first draw only bounds the start position while the
    # second one is used as the mask width; kept as-is to preserve the
    # existing sampling behaviour.
    fs = numpy.random.randint(0, F, size=(n_mask, 2))

    for f, width in fs:
        # Fewer frequency bins than the draw: randrange would get an empty
        # range and raise, so skip this mask instead.
        if num_mel_channels - f <= 0:
            continue

        f_zero = random.randrange(0, num_mel_channels - f)

        # A zero first draw skips the mask.
        if f == 0:
            continue

        mask_end = f_zero + width
        # The mean is taken at masking time, so it reflects earlier masks.
        cloned[:, f_zero:mask_end] = 0 if replace_with_zero else cloned.mean()
    return cloned
|
|
|
|
|
class FreqMask(FuncTrans):
    # Transform wrapper around ``freq_mask``.
    # NOTE(review): FuncTrans lives in espnet.transform.functional and is not
    # visible here; presumably its __call__ invokes ``_func`` -- confirm.
    _func = freq_mask
    # Reuse the wrapped function's docstring as the class documentation.
    __doc__ = freq_mask.__doc__

    def __call__(self, x, train):
        """Apply the transform in training mode, pass ``x`` through otherwise.

        :param x: input handed to ``FuncTrans.__call__``
        :param train: when falsy, ``x`` is returned untouched
        :returns: ``x`` itself, or the result of ``FuncTrans.__call__(x)``
        """
        if train:
            return super().__call__(x)
        return x
|
|
|
|
|
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
    """Time mask for SpecAugment.

    For each of the ``n_mask`` masks, a run of consecutive time frames is
    overwritten across all frequency bins.

    :param numpy.ndarray spec: (time, freq)
    :param int T: exclusive upper bound of the random draws that control the
        mask position and width; 0 disables masking
    :param int n_mask: the number of masks
    :param bool replace_with_zero: pad zero on mask if true else use mean
    :param bool inplace: overwrite ``spec`` instead of working on a copy
    :returns numpy.ndarray: the masked spectrogram (``spec`` itself if inplace)
    """
    cloned = spec if inplace else spec.copy()

    # T == 0 means "no masking"; numpy.random.randint(0, 0) would raise.
    if T == 0:
        return cloned

    len_spectro = cloned.shape[0]
    # Two independent draws from [0, T) per mask.
    # NOTE(review): the first draw only bounds the start position while the
    # second one is used as the mask width; kept as-is to preserve the
    # existing sampling behaviour.
    ts = numpy.random.randint(0, T, size=(n_mask, 2))

    for t, width in ts:
        # Spectrogram shorter than the draw: randrange would get an empty
        # range and raise, so skip this mask.
        if len_spectro - t <= 0:
            continue

        t_zero = random.randrange(0, len_spectro - t)

        # A zero first draw skips the mask.
        if t == 0:
            continue

        mask_end = t_zero + width
        # The mean is taken at masking time, so it reflects earlier masks.
        cloned[t_zero:mask_end] = 0 if replace_with_zero else cloned.mean()
    return cloned
|
|
|
|
|
class TimeMask(FuncTrans):
    # Transform wrapper around ``time_mask``.
    # NOTE(review): FuncTrans lives in espnet.transform.functional and is not
    # visible here; presumably its __call__ invokes ``_func`` -- confirm.
    _func = time_mask
    # Reuse the wrapped function's docstring as the class documentation.
    __doc__ = time_mask.__doc__

    def __call__(self, x, train):
        """Apply the transform in training mode, pass ``x`` through otherwise.

        :param x: input handed to ``FuncTrans.__call__``
        :param train: when falsy, ``x`` is returned untouched
        :returns: ``x`` itself, or the result of ``FuncTrans.__call__(x)``
        """
        if train:
            return super().__call__(x)
        return x
|
|
|
|
|
def spec_augment(
    x,
    resize_mode="PIL",
    max_time_warp=80,
    max_freq_width=27,
    n_freq_mask=2,
    max_time_width=100,
    n_time_mask=2,
    inplace=True,
    replace_with_zero=True,
):
    """Spec augment.

    Apply random time warping, then frequency masking, then time masking.
    Default setting is based on LD (Librispeech double) in Table 2 of
    https://arxiv.org/pdf/1904.08779.pdf

    :param numpy.ndarray x: (time, freq)
    :param str resize_mode: "PIL" (fast, nondifferentiable) or
        "sparse_image_warp" (slow, differentiable)
    :param int max_time_warp: maximum frames to warp the center frame in
        spectrogram (W)
    :param int max_freq_width: maximum width of the random freq mask (F)
    :param int n_freq_mask: the number of the random freq mask (m_F)
    :param int max_time_width: maximum width of the random time mask (T)
    :param int n_time_mask: the number of the random time mask (m_T)
    :param bool inplace: overwrite intermediate array
    :param bool replace_with_zero: pad zero on mask if true else use mean
    :returns numpy.ndarray: the augmented spectrogram returned by the final
        time-masking step
    """
    assert isinstance(x, numpy.ndarray)
    assert x.ndim == 2

    # Options shared by both masking stages.
    mask_options = dict(inplace=inplace, replace_with_zero=replace_with_zero)

    warped = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
    freq_masked = freq_mask(warped, max_freq_width, n_freq_mask, **mask_options)
    return time_mask(freq_masked, max_time_width, n_time_mask, **mask_options)
|
|
|
|
|
class SpecAugment(FuncTrans):
    # Transform wrapper around ``spec_augment``.
    # NOTE(review): FuncTrans lives in espnet.transform.functional and is not
    # visible here; presumably its __call__ invokes ``_func`` -- confirm.
    _func = spec_augment
    # Reuse the wrapped function's docstring as the class documentation.
    __doc__ = spec_augment.__doc__

    def __call__(self, x, train):
        """Apply the transform in training mode, pass ``x`` through otherwise.

        :param x: input handed to ``FuncTrans.__call__``
        :param train: when falsy, ``x`` is returned untouched
        :returns: ``x`` itself, or the result of ``FuncTrans.__call__(x)``
        """
        if train:
            return super().__call__(x)
        return x
|
|