wavlm-large / s3prl_s3prl_main /s3prl /dataset /noise_augmentation_pipes.py
lmzjms's picture
Upload 1162 files
0b32ad6 verified
import copy
import random
from dataclasses import dataclass
import torch
from .base import AugmentedDynamicItemDataset, DataPipe
@dataclass
class NoiseAugmentation(DataPipe):
    """Data pipe that randomly adds zero-mean Gaussian noise to a feature.

    Args:
        noise_proportion (float): probability (0.0 - 1.0) with which Gaussian
            noise is added to every element of an item during MAM training,
            set to 0 for no noise
        input_feat_name (str): handle for the `takes` (input)
        output_feat_name (str): handle for the `provides` (output)
        noise_std (float): standard deviation of the Gaussian noise, must be
            positive
    """

    noise_proportion: float = 0.0
    input_feat_name: str = "input_feat"
    output_feat_name: str = "output_feat"
    noise_std: float = 0.2

    def apply_noise_on_data(self, input_feat):
        """Return `input_feat`, noised with probability `noise_proportion`.

        Args:
            input_feat (torch.Tensor): feature to augment; it is never
                modified in place.

        Returns:
            - `input_feat` itself when `noise_proportion` is not positive;
            - an un-noised copy of `input_feat` when the random draw does not
              select this item (original dtype is kept);
            - a new float32 tensor `input_feat + noise` otherwise.
        """
        if not self.noise_proportion > 0:
            return input_feat

        with torch.no_grad():
            # One draw per item decides whether the whole item gets noised.
            if random.random() >= self.noise_proportion:
                return copy.deepcopy(input_feat)

            # Build the sampler on the feature's device so the addition below
            # never mixes devices.
            device = input_feat.device
            param_dtype = torch.get_default_dtype()
            noise_sampler = torch.distributions.Normal(
                torch.tensor(0.0, dtype=param_dtype, device=device),
                torch.tensor(self.noise_std, dtype=param_dtype, device=device),
            )
            noise = noise_sampler.sample(input_feat.shape)

            # Out-of-place add: type promotion lets integer and half-precision
            # features be noised (an in-place `+=` raises for integer tensors
            # and rounds to half precision for float16), and it yields a fresh
            # tensor so no defensive copy of the input is needed.
            return (input_feat + noise).to(dtype=torch.float32)

    def __call__(self, dataset: AugmentedDynamicItemDataset):
        """Register the noise augmentation as a dynamic item on `dataset`.

        The item reads `input_feat_name` and provides `output_feat_name`.
        The same `dataset` object is returned after registration.
        """
        dataset.add_dynamic_item(
            self.apply_noise_on_data,
            takes=self.input_feat_name,
            provides=self.output_feat_name,
        )
        return dataset