Spaces:
Sleeping
Sleeping
File size: 717 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
from timm.data import Mixup
from timm.data.mixup import mixup_target
class TimmMixup(Mixup):
    """Variant of ``timm.data.Mixup`` that only demands an even batch in 'pair' mode.

    The even-batch-size assertion is checked solely on the 'pair' branch, so
    the 'elem' branch and the fallback batch branch accept batches of any
    length.
    """

    def __call__(self, x, target):
        """Apply the configured mixing mode to ``x`` and build matching targets.

        Args:
            x: Input batch. ``len(x)`` must be even when ``self.mode`` is
                'pair'. Its ``device`` attribute is forwarded to
                ``mixup_target``.
            target: Labels for the batch, handed to ``mixup_target`` unchanged.

        Returns:
            A ``(x, target)`` tuple: the same ``x`` object that was passed in
            and the value returned by ``mixup_target``.

        Raises:
            AssertionError: If the mode is 'pair' and the batch length is odd.
        """
        # NOTE(review): x is returned without being rebound, so the parent's
        # _mix_* helpers presumably modify it in place -- confirm against timm.
        mode = self.mode
        if mode == 'pair':
            # The assert was moved here from the beginning of the function,
            # so only this mode is restricted to even batch sizes.
            assert len(x) % 2 == 0, 'Batch size should be even when using this'
            lam = self._mix_pair(x)
        elif mode == 'elem':
            lam = self._mix_elem(x)
        else:
            # Every other mode value falls back to batch-level mixing.
            lam = self._mix_batch(x)

        mixed_target = mixup_target(
            target,
            self.num_classes,
            lam,
            self.label_smoothing,
            x.device,
        )
        return x, mixed_target
|