File size: 530 Bytes
6e9c433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch import nn

from .configs.base_config import base_cfg
from .da.dav6 import DataAugmentationV6
from .da.base_da import BaseDataAugmentation


def get_data_augmentation(
    cfg: base_cfg,
    image_size: int,
    is_padding: bool,
) -> BaseDataAugmentation:
    if cfg.data_augmentation_version == 6:
        print("Using DataAugmentationV6")
        return DataAugmentationV6(cfg)
    else:
        raise NotImplementedError(
            f"Unsupported DataAugmentation version {cfg.data_augmentation_version}"
        )