S-MultiMAE / s_multimae /data_augmentation.py
thinh-researcher's picture
Init
6e9c433
raw
history blame contribute delete
530 Bytes
from torch import nn
from .configs.base_config import base_cfg
from .da.dav6 import DataAugmentationV6
from .da.base_da import BaseDataAugmentation
def get_data_augmentation(
cfg: base_cfg,
image_size: int,
is_padding: bool,
) -> BaseDataAugmentation:
if cfg.data_augmentation_version == 6:
print("Using DataAugmentationV6")
return DataAugmentationV6(cfg)
else:
raise NotImplementedError(
f"Unsupported DataAugmentation version {cfg.data_augmentation_version}"
)