import torch

from timm.data import Mixup
from timm.data.mixup import mixup_target


class TimmMixup(Mixup):
    """ Wrap timm.data.Mixup that avoids the assert that batch size must be even.

    """
    def __call__(self, x, target):
        if self.mode == 'elem':
            lam = self._mix_elem(x)
        elif self.mode == 'pair':
            # We move the assert from the beginning of the function to here
            assert len(x) % 2 == 0, 'Batch size should be even when using this'
            lam = self._mix_pair(x)
        else:
            lam = self._mix_batch(x)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
        return x, target
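

# Minimal usage sketch (not part of the original file): the hyperparameter values
# and tensor shapes below are hypothetical, and the call assumes a timm version
# whose mixup_target accepts a device argument, matching the call in __call__ above.
# With mode='batch' (or 'elem') an odd batch size is accepted, since the even-batch
# assert is only applied in the 'pair' branch.
if __name__ == '__main__':
    mixup_fn = TimmMixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1,
                         num_classes=10, mode='batch')
    images = torch.randn(7, 3, 32, 32)   # odd batch size works in 'batch' mode
    labels = torch.randint(0, 10, (7,))
    mixed_images, soft_targets = mixup_fn(images, labels)
    print(mixed_images.shape, soft_targets.shape)  # (7, 3, 32, 32) and (7, 10)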