from typing import List, Callable | |
from torch import Tensor | |
import random | |
from hw_asr.augmentations.base import AugmentationBase | |
class RandomChoice(AugmentationBase):
    """Apply one randomly selected augmentation with probability ``p``.

    On each call a single draw from ``random.random()`` decides whether to
    augment at all; if so, one callable is picked uniformly from
    ``augmentation_list`` with ``random.choice`` and applied to the input.
    Otherwise the input is returned untouched.
    """

    def __init__(self, augmentation_list: List[Callable], p: float):
        """Store the candidate augmentations and the application probability.

        Args:
            augmentation_list: Callables to choose from; each is invoked with
                the input data as its only argument. The list is stored by
                reference, not copied.
            p: Probability of applying an augmentation on a given call.
                Compared against ``random.random()``, which yields values in
                ``[0.0, 1.0)``, so ``p <= 0`` never augments and ``p >= 1``
                always does.
        """
        self.augmentation_list = augmentation_list
        self.p = p

    def __call__(self, data: Tensor) -> Tensor:
        """Return ``data``, possibly transformed by one random augmentation.

        Args:
            data: Input passed to the selected augmentation.

        Returns:
            The output of the chosen augmentation, or ``data`` itself when the
            probability gate does not fire or no augmentations are configured.
        """
        # The probability gate is evaluated first so the random stream is
        # consumed exactly as before. The emptiness check guards
        # random.choice, which raises IndexError on an empty sequence.
        if random.random() < self.p and self.augmentation_list:
            augmentation = random.choice(self.augmentation_list)
            return augmentation(data)
        return data