File size: 400 Bytes
affcd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import random
from typing import Callable
from torch import Tensor


class RandomApply:
    """Apply an augmentation to the input with probability ``p``.

    On each call, ``augmentation(data)`` is returned with probability ``p``.
    Otherwise ``data`` is returned unchanged.

    The coin flip uses Python's module-level ``random`` generator, so
    reproducibility is controlled by ``random.seed`` (not by torch's RNG).
    """

    def __init__(self, augmentation: Callable, p: float):
        """Store the augmentation and its application probability.

        Args:
            augmentation: Callable applied to the input when the coin flip
                succeeds.
            p: Probability of applying ``augmentation``; must lie in [0, 1].

        Raises:
            ValueError: If ``p`` is outside [0, 1] (including NaN).
        """
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``. The chained comparison is False for NaN, so NaN is
        # rejected too.
        if not 0 <= p <= 1:
            raise ValueError(f"p must be in [0, 1], got {p!r}")
        self.augmentation = augmentation
        self.p = p

    def __call__(self, data: Tensor) -> Tensor:
        """Return ``augmentation(data)`` with probability ``p``, else ``data``."""
        # random.random() is in [0, 1): p == 1 always applies, p == 0 never does.
        if random.random() < self.p:
            return self.augmentation(data)
        return data

    def __repr__(self) -> str:
        return f"{type(self).__name__}(augmentation={self.augmentation!r}, p={self.p!r})"