tgritsaev's picture
Upload 198 files
affcd23 verified
raw
history blame contribute delete
523 Bytes
from typing import List, Callable
from torch import Tensor
import random
from hw_asr.augmentations.base import AugmentationBase
class RandomChoice(AugmentationBase):
    """Apply one randomly selected augmentation with probability ``p``.

    On each call a single uniform draw decides whether to augment at all.
    When it does, one callable is picked uniformly at random from
    ``augmentation_list`` and applied to the input; otherwise the input is
    returned unchanged.
    """

    def __init__(self, augmentation_list: List[Callable], p: float):
        """Store the candidate augmentations and the application probability.

        Args:
            augmentation_list: Callables to choose from. Each one is invoked
                with the input as its only argument. The list is stored by
                reference, not copied.
            p: Probability of applying an augmentation on a given call. It is
                compared against ``random.random()``, which yields values in
                ``[0.0, 1.0)``, so ``p >= 1`` always applies and ``p <= 0``
                never applies.
        """
        self.augmentation_list = augmentation_list
        self.p = p

    def __call__(self, data: Tensor) -> Tensor:
        """Return ``data``, possibly transformed by one random augmentation.

        Args:
            data: Input passed to the chosen augmentation.

        Returns:
            The result of the chosen augmentation, or ``data`` itself when the
            random draw does not fall below ``p`` or when there are no
            augmentations to choose from.
        """
        # The draw comes first so RNG consumption matches the original code.
        # The emptiness check then guards ``random.choice``, which raises
        # IndexError on an empty sequence; an empty list is treated as
        # "nothing to apply".
        if random.random() < self.p and self.augmentation_list:
            augmentation = random.choice(self.augmentation_list)
            return augmentation(data)
        return data