import random | |
from typing import Callable | |
from torch import Tensor | |
class RandomApply: | |
def __init__(self, augmentation: Callable, p: float): | |
assert 0 <= p <= 1 | |
self.augmentation = augmentation | |
self.p = p | |
def __call__(self, data: Tensor) -> Tensor: | |
if random.random() < self.p: | |
return self.augmentation(data) | |
else: | |
return data | |