|
import math |
|
import numbers |
|
from typing import Optional |
|
|
|
import numpy as np |
|
from fairseq.data.audio.feature_transforms import ( |
|
AudioFeatureTransform, |
|
register_audio_feature_transform, |
|
) |
|
|
|
|
|
@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
    """SpecAugment (https://arxiv.org/abs/1904.08779)

    Applies, in order, time warping, frequency masking and time masking to a
    2-D spectrogram indexed as ``(frame, frequency)`` (axis 0 is treated as
    frames and axis 1 as frequency bins by ``__call__``).
    """

    @classmethod
    def from_config_dict(cls, config=None):
        """Build a transform from a config dict.

        Recognised keys: ``time_warp_W``, ``freq_mask_N``, ``freq_mask_F``,
        ``time_mask_N``, ``time_mask_T``, ``time_mask_p`` and ``mask_value``.
        Missing keys default to 0 (0.0 for ``time_mask_p``). ``mask_value``
        defaults to ``None`` here, which means "fill masks with the
        spectrogram mean", unlike the constructor default of 0.0.
        """
        _config = {} if config is None else config
        return SpecAugmentTransform(
            _config.get("time_warp_W", 0),
            _config.get("freq_mask_N", 0),
            _config.get("freq_mask_F", 0),
            _config.get("time_mask_N", 0),
            _config.get("time_mask_T", 0),
            _config.get("time_mask_p", 0.0),
            _config.get("mask_value", None),
        )

    def __init__(
        self,
        time_warp_w: int = 0,
        freq_mask_n: int = 0,
        freq_mask_f: int = 0,
        time_mask_n: int = 0,
        time_mask_t: int = 0,
        time_mask_p: float = 0.0,
        mask_value: Optional[float] = 0.0,
    ):
        """
        Args:
            time_warp_w: time-warp parameter W; 0 disables time warping.
            freq_mask_n: number of frequency masks to apply.
            freq_mask_f: exclusive upper bound on the width of each
                frequency mask (must be > 0 when ``freq_mask_n > 0``).
            time_mask_n: number of time masks to apply.
            time_mask_t: exclusive upper bound on the width of each time
                mask (must be > 0 when ``time_mask_n > 0``).
            time_mask_p: fraction of the frame count that also bounds the
                time-mask width.
            mask_value: fill value for masked regions, or ``None`` to use
                the mean of the input spectrogram.
        """
        assert mask_value is None or isinstance(
            mask_value, numbers.Number
        ), f"mask_value (type: {type(mask_value)}) must be None or a number"
        if freq_mask_n > 0:
            assert freq_mask_f > 0, (
                f"freq_mask_F ({freq_mask_f}) "
                f"must be larger than 0 when doing freq masking."
            )
        if time_mask_n > 0:
            assert time_mask_t > 0, (
                f"time_mask_T ({time_mask_t}) must be larger than 0 when "
                f"doing time masking."
            )

        self.time_warp_w = time_warp_w
        self.freq_mask_n = freq_mask_n
        self.freq_mask_f = freq_mask_f
        self.time_mask_n = time_mask_n
        self.time_mask_t = time_mask_t
        self.time_mask_p = time_mask_p
        self.mask_value = mask_value

    def __repr__(self):
        params = [
            f"time_warp_w={self.time_warp_w}",
            f"freq_mask_n={self.freq_mask_n}",
            f"freq_mask_f={self.freq_mask_f}",
            f"time_mask_n={self.time_mask_n}",
            f"time_mask_t={self.time_mask_t}",
            f"time_mask_p={self.time_mask_p}",
            f"mask_value={self.mask_value}",
        ]
        return self.__class__.__name__ + "(" + ", ".join(params) + ")"

    def __call__(self, spectrogram: np.ndarray) -> np.ndarray:
        """Return an augmented copy of ``spectrogram``.

        Args:
            spectrogram: 2-D array; axis 0 is frames, axis 1 frequency bins.

        Returns:
            A new array of the same shape with augmentation applied. The
            input array itself is returned unchanged when it has no frames,
            or when frequency masking is enabled but there are fewer
            frequency bins than ``freq_mask_f``.
        """
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        num_frames, num_freqs = spectrogram.shape

        # Guard before computing the mean: the mean of an empty array is NaN
        # and numpy emits a RuntimeWarning.
        if num_frames == 0:
            return spectrogram

        # Frequency masking samples its start from [0, num_freqs - f); that
        # range can be empty when num_freqs < freq_mask_f. Only relevant if
        # frequency masking is actually enabled.
        if self.freq_mask_n > 0 and num_freqs < self.freq_mask_f:
            return spectrogram

        mask_value = self.mask_value
        if mask_value is None:
            # Fill masks with the mean of the original (unaugmented) input.
            mask_value = spectrogram.mean()

        distorted = spectrogram.copy()
        # Time warping needs room for an anchor at least W frames from both
        # ends, so it is skipped for utterances of <= 2 * W frames.
        if self.time_warp_w > 0 and 2 * self.time_warp_w < num_frames:
            distorted = self._time_warp(distorted)
        self._freq_mask(distorted, mask_value)
        self._time_mask(distorted, mask_value)
        return distorted

    def _time_warp(self, spectrogram: np.ndarray) -> np.ndarray:
        """Return a time-warped array with the same shape as ``spectrogram``.

        A random anchor frame ``w0`` is moved by a random offset ``w``:
        frames before the anchor are resized to ``w0 + w`` rows and the
        remaining frames to ``num_frames - w0 - w`` rows.
        """
        # Imported lazily so cv2 is only required when time warping is used.
        import cv2

        num_frames, num_freqs = spectrogram.shape
        w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
        w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
        upper, lower = spectrogram[:w0, :], spectrogram[w0:, :]
        # cv2.resize takes dsize as (width, height) == (num_freqs, rows).
        upper = cv2.resize(
            upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
        )
        lower = cv2.resize(
            lower,
            dsize=(num_freqs, num_frames - w0 - w),
            interpolation=cv2.INTER_LINEAR,
        )
        return np.concatenate((upper, lower), axis=0)

    def _freq_mask(self, spectrogram: np.ndarray, mask_value) -> None:
        """Fill ``freq_mask_n`` random frequency bands with ``mask_value`` in place."""
        num_freqs = spectrogram.shape[1]
        for _ in range(self.freq_mask_n):
            f = np.random.randint(0, self.freq_mask_f)
            f0 = np.random.randint(0, num_freqs - f)
            if f != 0:
                spectrogram[:, f0 : f0 + f] = mask_value

    def _time_mask(self, spectrogram: np.ndarray, mask_value) -> None:
        """Fill ``time_mask_n`` random frame spans with ``mask_value`` in place."""
        num_frames = spectrogram.shape[0]
        # Width bound: time_mask_t, the time_mask_p fraction of the frames,
        # and num_frames itself, so the start range below is never empty
        # even if time_mask_p > 1.
        max_time_mask_t = min(
            self.time_mask_t,
            math.floor(num_frames * self.time_mask_p),
            num_frames,
        )
        if max_time_mask_t < 1:
            return
        for _ in range(self.time_mask_n):
            t = np.random.randint(0, max_time_mask_t)
            t0 = np.random.randint(0, num_frames - t)
            if t != 0:
                spectrogram[t0 : t0 + t, :] = mask_value
|
|