File size: 400 Bytes
affcd23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import random
from typing import Callable
from torch import Tensor
class RandomApply:
    """Apply an augmentation to its input with a fixed probability.

    Each call draws ``random.random()`` from Python's module-level ``random``
    functions. When the draw is below ``p``, the wrapped augmentation is
    applied. Otherwise the input is returned unchanged (the very same object).

    Args:
        augmentation: Callable invoked as ``augmentation(data)`` when the draw
            succeeds; its return value is passed through as-is.
        p: Probability of applying ``augmentation``; must lie in ``[0, 1]``.
            ``random.random()`` returns values in ``[0.0, 1.0)``, so ``p=1``
            always applies the augmentation and ``p=0`` never does.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]`` (NaN is rejected too).
    """

    def __init__(self, augmentation: Callable, p: float):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and would then silently accept an invalid probability.
        # Written as ``not (...)`` so that NaN (which fails every comparison)
        # is rejected as well.
        if not 0 <= p <= 1:
            raise ValueError(f"p must be in [0, 1], got {p!r}")
        self.augmentation = augmentation
        self.p = p

    def __call__(self, data: Tensor) -> Tensor:
        """Return ``augmentation(data)`` with probability ``p``, else ``data``.

        ``data`` is annotated as a torch ``Tensor``. This method does not
        inspect it; it is only forwarded to ``augmentation``.
        """
        # A strict ``<`` against a draw in [0, 1) yields exactly probability p.
        if random.random() < self.p:
            return self.augmentation(data)
        return data

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}"
            f"(augmentation={self.augmentation!r}, p={self.p!r})"
        )
|