Spaces:
Running
Running
"""SpecAugment module.""" | |
from typing import Optional | |
from typing import Sequence | |
from typing import Union | |
from funasr_detach.models.specaug.mask_along_axis import MaskAlongAxis | |
from funasr_detach.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth | |
from funasr_detach.models.specaug.mask_along_axis import MaskAlongAxisLFR | |
from funasr_detach.models.specaug.time_warp import TimeWarp | |
from funasr_detach.register import tables | |
import torch.nn as nn | |
class SpecAug(nn.Module):
    """Implementation of SpecAug.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data
         Augmentation Method for Automatic Speech Recognition"

    .. warning::
        When using cuda mode, time_warp doesn't have reproducibility
        due to `torch.nn.functional.interpolate`.

    Args:
        apply_time_warp: Build and apply a ``TimeWarp`` stage.
        time_warp_window: Passed to ``TimeWarp`` as ``window``.
        time_warp_mode: Passed to ``TimeWarp`` as ``mode``.
        apply_freq_mask: Build and apply a frequency ``MaskAlongAxis`` stage.
        freq_mask_width_range: ``mask_width_range`` of the frequency mask.
        num_freq_mask: ``num_mask`` of the frequency mask.
        apply_time_mask: Build and apply a time-masking stage.
        time_mask_width_range: Width range for a time ``MaskAlongAxis``.
            Mutually exclusive with ``time_mask_width_ratio_range``.
        time_mask_width_ratio_range: Ratio range passed to
            ``MaskAlongAxisVariableMaxWidth`` as ``mask_width_ratio_range``.
            Mutually exclusive with ``time_mask_width_range``.
        num_time_mask: ``num_mask`` of the time mask.

    Raises:
        ValueError: If no stage is enabled, or if ``apply_time_mask`` is set
            and not exactly one of ``time_mask_width_range`` /
            ``time_mask_width_ratio_range`` is given.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
    ):
        # Validate the whole configuration before building any sub-module,
        # so a bad config fails fast instead of after partial construction.
        if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
            raise ValueError(
                "Either one of time_warp, time_mask, or freq_mask should be applied"
            )
        if apply_time_mask:
            has_width = time_mask_width_range is not None
            has_ratio = time_mask_width_ratio_range is not None
            if has_width and has_ratio:
                raise ValueError(
                    'Either one of "time_mask_width_range" or '
                    '"time_mask_width_ratio_range" can be used'
                )
            if not has_width and not has_ratio:
                raise ValueError(
                    'Either one of "time_mask_width_range" or '
                    '"time_mask_width_ratio_range" should be used.'
                )

        super().__init__()
        # Flags are kept as public attributes; forward() itself only checks
        # whether each stage is None.
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        if apply_time_warp:
            self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
        else:
            self.time_warp = None

        if apply_freq_mask:
            self.freq_mask = MaskAlongAxis(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
            )
        else:
            self.freq_mask = None

        if apply_time_mask:
            # Validation above guarantees exactly one of the two is set.
            if time_mask_width_range is not None:
                self.time_mask = MaskAlongAxis(
                    dim="time",
                    mask_width_range=time_mask_width_range,
                    num_mask=num_time_mask,
                )
            else:
                self.time_mask = MaskAlongAxisVariableMaxWidth(
                    dim="time",
                    mask_width_ratio_range=time_mask_width_ratio_range,
                    num_mask=num_time_mask,
                )
        else:
            self.time_mask = None

    def forward(self, x, x_lengths=None):
        """Apply the enabled stages: time warp, then freq mask, then time mask.

        Args:
            x: Input features. NOTE(review): presumably (batch, time, freq);
                confirm against the sub-module implementations.
            x_lengths: Optional per-utterance lengths, threaded through every
                stage.

        Returns:
            The ``(x, x_lengths)`` pair produced by the last enabled stage.
        """
        if self.time_warp is not None:
            x, x_lengths = self.time_warp(x, x_lengths)
        if self.freq_mask is not None:
            x, x_lengths = self.freq_mask(x, x_lengths)
        if self.time_mask is not None:
            x, x_lengths = self.time_mask(x, x_lengths)
        return x, x_lengths
class SpecAugLFR(nn.Module):
    """Implementation of SpecAug for low-frame-rate (LFR) features.

    Same pipeline as SpecAug (time warp -> freq mask -> time mask), but the
    frequency mask and the width-range time mask use ``MaskAlongAxisLFR``.

    Args:
        apply_time_warp: Build and apply a ``TimeWarp`` stage.
        time_warp_window: Passed to ``TimeWarp`` as ``window``.
        time_warp_mode: Passed to ``TimeWarp`` as ``mode``.
        apply_freq_mask: Build and apply a frequency ``MaskAlongAxisLFR`` stage.
        freq_mask_width_range: ``mask_width_range`` of the frequency mask.
        num_freq_mask: ``num_mask`` of the frequency mask.
        lfr_rate: Low frame rate setting. It is forwarded to
            ``MaskAlongAxisLFR`` as ``lfr_rate + 1``. NOTE(review): presumably
            ``lfr_rate + 1`` is the number of stacked frames; confirm.
        apply_time_mask: Build and apply a time-masking stage.
        time_mask_width_range: Width range for a time ``MaskAlongAxisLFR``.
            Mutually exclusive with ``time_mask_width_ratio_range``.
        time_mask_width_ratio_range: Ratio range passed to
            ``MaskAlongAxisVariableMaxWidth``. This path does not receive
            ``lfr_rate``. Mutually exclusive with ``time_mask_width_range``.
        num_time_mask: ``num_mask`` of the time mask.

    Raises:
        ValueError: If no stage is enabled, or if ``apply_time_mask`` is set
            and not exactly one of ``time_mask_width_range`` /
            ``time_mask_width_ratio_range`` is given.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        lfr_rate: int = 0,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
    ):
        # Validate the whole configuration before building any sub-module,
        # so a bad config fails fast instead of after partial construction.
        if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
            raise ValueError(
                "Either one of time_warp, time_mask, or freq_mask should be applied"
            )
        if apply_time_mask:
            has_width = time_mask_width_range is not None
            has_ratio = time_mask_width_ratio_range is not None
            if has_width and has_ratio:
                raise ValueError(
                    'Either one of "time_mask_width_range" or '
                    '"time_mask_width_ratio_range" can be used'
                )
            if not has_width and not has_ratio:
                raise ValueError(
                    'Either one of "time_mask_width_range" or '
                    '"time_mask_width_ratio_range" should be used.'
                )

        super().__init__()
        # Flags are kept as public attributes; forward() itself only checks
        # whether each stage is None.
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        if apply_time_warp:
            self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
        else:
            self.time_warp = None

        if apply_freq_mask:
            self.freq_mask = MaskAlongAxisLFR(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
                lfr_rate=lfr_rate + 1,
            )
        else:
            self.freq_mask = None

        if apply_time_mask:
            # Validation above guarantees exactly one of the two is set.
            if time_mask_width_range is not None:
                self.time_mask = MaskAlongAxisLFR(
                    dim="time",
                    mask_width_range=time_mask_width_range,
                    num_mask=num_time_mask,
                    lfr_rate=lfr_rate + 1,
                )
            else:
                self.time_mask = MaskAlongAxisVariableMaxWidth(
                    dim="time",
                    mask_width_ratio_range=time_mask_width_ratio_range,
                    num_mask=num_time_mask,
                )
        else:
            self.time_mask = None

    def forward(self, x, x_lengths=None):
        """Apply the enabled stages: time warp, then freq mask, then time mask.

        Args:
            x: Input features. NOTE(review): presumably (batch, time, freq)
                with LFR-stacked frequency bins; confirm.
            x_lengths: Optional per-utterance lengths, threaded through every
                stage.

        Returns:
            The ``(x, x_lengths)`` pair produced by the last enabled stage.
        """
        if self.time_warp is not None:
            x, x_lengths = self.time_warp(x, x_lengths)
        if self.freq_mask is not None:
            x, x_lengths = self.freq_mask(x, x_lengths)
        if self.time_mask is not None:
            x, x_lengths = self.time_mask(x, x_lengths)
        return x, x_lengths