Spaces:
Build error
Build error
File size: 7,664 Bytes
4a3ad95 8f243be 4a3ad95 80bcdd4 4a3ad95 8f243be 4a3ad95 8f243be 4a3ad95 8f243be 4a3ad95 8f243be 4a3ad95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
from timm.data.transforms import str_to_interp_mode
from .cached_image_folder import CachedImageFolder
from .samplers import SubsetRandomSampler
from .dataset_fg import DatasetMeta
def build_loader(config):
    """Build the train/val datasets, their data loaders and the mixup function.

    The number of classes reported by the training dataset is written back
    into ``config.MODEL.NUM_CLASSES``; the config is defrosted only for that
    assignment and frozen again right after.

    Args:
        config: Config node. Reads ``LOCAL_RANK``, ``DATA.ZIP_MODE``,
            ``DATA.CACHE_MODE``, ``DATA.BATCH_SIZE``, ``DATA.NUM_WORKERS``,
            ``DATA.PIN_MEMORY``, the ``AUG.MIXUP*`` / ``AUG.CUTMIX*`` fields,
            ``MODEL.LABEL_SMOOTHING`` and ``MODEL.NUM_CLASSES``.

    Returns:
        Tuple ``(dataset_train, dataset_val, data_loader_train,
        data_loader_val, mixup_fn)``; ``mixup_fn`` is ``None`` when neither
        mixup nor cutmix is enabled.
    """
    config.defrost()
    train_set, num_classes = build_dataset(is_train=True, config=config)
    config.MODEL.NUM_CLASSES = num_classes
    config.freeze()
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")

    val_set, _ = build_dataset(is_train=False, config=config)
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Training sampler: with zip archives cached in 'part' mode each rank
    # draws from its own strided slice of indices; otherwise a
    # DistributedSampler does the sharding.
    use_part_cache = config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part'
    if use_part_cache:
        train_indices = np.arange(rank, len(train_set), world_size)
        train_sampler = SubsetRandomSampler(train_indices)
    else:
        train_sampler = torch.utils.data.DistributedSampler(
            train_set, num_replicas=world_size, rank=rank, shuffle=True
        )

    # Validation sampler: every rank takes a strided slice of the val set.
    val_indices = np.arange(rank, len(val_set), world_size)
    val_sampler = SubsetRandomSampler(val_indices)

    # Keyword arguments common to both loaders.
    shared_kwargs = dict(
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        sampler=train_sampler,
        drop_last=True,
        **shared_kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        sampler=val_sampler,
        shuffle=False,
        drop_last=False,
        **shared_kwargs,
    )

    # setup mixup / cutmix
    aug = config.AUG
    mixup_fn = None
    if aug.MIXUP > 0 or aug.CUTMIX > 0. or aug.CUTMIX_MINMAX is not None:
        mixup_fn = Mixup(
            mixup_alpha=aug.MIXUP,
            cutmix_alpha=aug.CUTMIX,
            cutmix_minmax=aug.CUTMIX_MINMAX,
            prob=aug.MIXUP_PROB,
            switch_prob=aug.MIXUP_SWITCH_PROB,
            mode=aug.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING,
            num_classes=config.MODEL.NUM_CLASSES,
        )

    return train_set, val_set, train_loader, val_loader, mixup_fn
def build_dataset(is_train, config):
    """Create the dataset selected by ``config.DATA.DATASET``.

    Args:
        is_train: Forwarded to ``build_transform``; also chooses the
            train/val archive in zip mode and is passed as ``train`` to
            ``DatasetMeta``.
        config: Config node. Reads ``DATA.DATASET``, ``DATA.ZIP_MODE``,
            ``DATA.DATA_PATH``, ``DATA.CACHE_MODE``, ``DATA.ADD_META`` and,
            for dataset names not listed below, ``DATA.DATASET_ROOT``.

    Returns:
        Tuple ``(dataset, nb_classes)`` where ``nb_classes`` is
        ``len(dataset.class_to_idx)``.
    """
    # Dataset names whose files live under './datasets/<name>'.
    bundled_names = (
        'inaturelist2021',
        'inaturelist2021_mini',
        'inaturelist2017',
        'inaturelist2018',
        'cub-200',
        'stanfordcars',
        'oxfordflower',
        'stanforddogs',
        'nabirds',
        'aircraft',
    )

    transform = build_transform(is_train, config)
    name = config.DATA.DATASET

    if name == 'imagenet':
        if config.DATA.ZIP_MODE:
            split = 'train' if is_train else 'val'
            dataset = CachedImageFolder(
                config.DATA.DATA_PATH,
                split + "_map.txt",
                split + ".zip@/",
                transform,
                # the validation set is always cached in 'part' mode
                cache_mode=config.DATA.CACHE_MODE if is_train else 'part',
            )
        else:
            # NOTE(review): the train/val split is not used on this path, so
            # both splits read the same fixed folder; DATA.DATA_PATH is
            # ignored here. Confirm this is intended.
            dataset = datasets.ImageFolder('./datasets/imagenet', transform=transform)
    else:
        if name in bundled_names:
            root = './datasets/' + name
        else:
            # Any other dataset name falls back to the configured root.
            root = config.DATA.DATASET_ROOT
        dataset = DatasetMeta(
            root=root,
            transform=transform,
            train=is_train,
            aux_info=config.DATA.ADD_META,
            dataset=name,
        )

    return dataset, len(dataset.class_to_idx)
def build_transform(is_train, config):
    """Build the image transform pipeline for training or evaluation.

    Args:
        is_train: When true, return the timm ``create_transform`` training
            pipeline; otherwise return a resize/crop + tensor + normalize
            ``Compose``.
        config: Config node. Reads ``DATA.IMG_SIZE``,
            ``DATA.TRAIN_INTERPOLATION``, ``DATA.INTERPOLATION``,
            ``TEST.CROP`` and the ``AUG.*`` augmentation fields.

    Returns:
        The transform object to hand to the dataset.
    """
    img_size = config.DATA.IMG_SIZE
    # Images of 32px or less are not resized.
    needs_resize = img_size > 32

    if is_train:
        jitter = config.AUG.COLOR_JITTER
        policy = config.AUG.AUTO_AUGMENT
        # this should always dispatch to transforms_imagenet_train
        train_tf = create_transform(
            input_size=img_size,
            is_training=True,
            color_jitter=jitter if jitter > 0 else None,
            auto_augment=policy if policy != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.TRAIN_INTERPOLATION,
        )
        if not needs_resize:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            train_tf.transforms[0] = transforms.RandomCrop(img_size, padding=4)
        return train_tf

    steps = []
    if needs_resize:
        center_crop = config.TEST.CROP
        interp = str_to_interp_mode(config.DATA.INTERPOLATION)
        if center_crop:
            # to maintain same ratio w.r.t. 224 images
            scaled = int((256 / 224) * img_size)
            steps.append(transforms.Resize(scaled, interpolation=interp))
            steps.append(transforms.CenterCrop(img_size))
        else:
            steps.append(transforms.Resize((img_size, img_size), interpolation=interp))

    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
    return transforms.Compose(steps)
|