|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import sys |
|
import datetime |
|
import time |
|
import math |
|
import json |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torch.distributed as dist |
|
import torch.backends.cudnn as cudnn |
|
import torch.nn.functional as F |
|
from torchvision import datasets, transforms |
|
from torchvision import models as torchvision_models |
|
|
|
import utils |
|
import vision_transformer as vits |
|
from vision_transformer import DINOHead |
|
|
|
# Names of all public, lowercase, callable entries in torchvision.models
# (e.g. "resnet50") — candidate architecture strings for --arch.
torchvision_archs = sorted(
    name
    for name, candidate in torchvision_models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(candidate)
)
|
|
|
def get_args_parser():
    """Build the command-line argument parser for DINO training.

    Returns an ``argparse.ArgumentParser`` created with ``add_help=False`` so
    it can be used as a parent parser by the ``__main__`` entry point.
    Arguments are grouped as: model, teacher temperature, optimization,
    multi-crop, and dataset / checkpointing / misc.
    """
    parser = argparse.ArgumentParser('DINO', add_help=False)

    # ---- Model parameters ----
    parser.add_argument('--arch', default='vit_base', type=str,
        help="""Name of architecture to train. For quick experiments with ViTs,
        we recommend using vit_tiny or vit_small.""")
    parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
        of input square patches - default 16 (for 16x16 patches). Using smaller
        values leads to better performance but requires more memory. Applies only
        for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
        mixed precision training (--use_fp16 false) to avoid unstabilities.""")
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
        the DINO head output. For complex and large datasets large values (like 65k) work well.""")
    parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
        help="""Whether or not to weight normalize the last layer of the DINO head.
        Not normalizing leads to better performance but can make the training unstable.
        In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""")
    parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
        parameter for teacher update. The value is increased to 1 during training with cosine schedule.
        We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
    parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
        help="Whether to use batch normalizations in projection head (Default: False)")

    # ---- Teacher temperature parameters ----
    parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
        help="""Initial value for the teacher temperature: 0.04 works well in most cases.
        Try decreasing it if the training loss does not decrease.""")
    parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
        of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
        starting with the default value of 0.04 and increase this slightly if needed.""")
    # NOTE: help text previously said "(Default: 30)" which contradicted the
    # actual default of 0 — fixed to match the real default.
    parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
        help='Number of warmup epochs for the teacher temperature (Default: 0).')

    # ---- Training / optimization parameters ----
    parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
        to use half precision for training. Improves training time and memory requirements,
        but can provoke instability and slight decay of performance. We recommend disabling
        mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
    parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
        weight decay. With ViT, a smaller value at the beginning of training works well.""")
    parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")
    parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
        gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
        help optimization for larger ViT architectures. 0 for disabling.""")
    parser.add_argument('--batch_size_per_gpu', default=512, type=int,
        help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
    parser.add_argument('--epochs', default=150, type=int, help='Number of epochs of training.')
    parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
        during which we keep the output layer fixed. Typically doing so during
        the first epoch helps training. Try increasing this value if the loss does not decrease.""")
    parser.add_argument("--lr", default=1.5e-4, type=float, help="""Learning rate at the end of
        linear warmup (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 512.""")
    parser.add_argument("--warmup_epochs", default=20, type=int,
        help="Number of epochs for the linear learning-rate warm up.")
    parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
        end of optimization. We use a cosine LR schedule with linear warmup.""")
    parser.add_argument('--optimizer', default='adamw', type=str,
        choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
    parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")

    # ---- Multi-crop parameters ----
    parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
    parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
        local views to generate. Set this parameter to 0 to disable multi-crop training.
        When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
    parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for small local view cropping of multi-crop.""")

    # ---- Dataset / checkpoint / misc parameters ----
    parser.add_argument('--data_path', default='/root/autodl-tmp/All', type=str,
        help='Please specify path to the ImageNet training data.')
    parser.add_argument('--basic_state_dict', default='/root/autodl-tmp/ViT_b16_224_Imagenet.pth', type=str,
        help='Load in pretrained or un-pretrained model pth')
    parser.add_argument('--output_dir', default="/home/CPIA/saved_models/DINO", type=str, help='Path to save logs and checkpoints.')
    parser.add_argument('--saveckp_freq', default=50, type=int, help='Save checkpoint every x epochs.')
    parser.add_argument('--seed', default=0, type=int, help='Random seed.')
    parser.add_argument('--num_workers', default=32, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    return parser
|
|
|
|
|
def _try_load_pretrained(model, ckpt_path, role):
    """Best-effort warm start of ``model`` from a checkpoint file.

    ``role`` ('Student' / 'Teacher') is used only for log messages. Loading is
    non-strict (``strict=False``) so extra/missing keys are tolerated, and any
    failure (unreadable file, incompatible tensors) is logged rather than
    fatal: the model then keeps its current initialization.
    """
    if ckpt_path is None:
        print(role + ' PreTuning Restart')
        return
    try:
        state_dict = torch.load(ckpt_path)
        # some checkpoints nest the actual weights under a 'model' key
        if 'model' in state_dict:
            state_dict = state_dict['model']
        # strict=False: ignore head/auxiliary keys not present in the backbone
        model.load_state_dict(state_dict, False)
    except Exception:
        # was a bare `except:` (would also swallow KeyboardInterrupt) — narrowed
        print('error in args.basic_state_dict:', ckpt_path)
        print(role + ' PreTuning Restart')
    else:
        print(role + ' PreTuning with Transfer-learning with:', ckpt_path)


def train_dino(args):
    """Train a DINO student/teacher ViT pair on an ImageFolder dataset.

    Pipeline: distributed init -> multi-crop data loading -> build student and
    teacher (optionally warm-started from ``args.basic_state_dict``) ->
    DINO loss, optimizer and per-iteration LR/WD/momentum schedules ->
    optional resume from ``checkpoint.pth`` -> epoch loop with checkpointing
    to ``args.output_dir`` and JSON-lines logging to ``log.txt``.
    """
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print('ok')
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    dataset = datasets.ImageFolder(args.data_path, transform=transform)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,  # keep per-GPU batch sizes identical across processes
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ============
    # "deit" names are historical aliases of the corresponding "vit" archs
    args.arch = args.arch.replace("deit", "vit")

    if args.arch in vits.__dict__:
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=args.drop_path_rate,  # stochastic depth, student only
        )
        _try_load_pretrained(student, args.basic_state_dict, 'Student')
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        _try_load_pretrained(teacher, args.basic_state_dict, 'Teacher')
        embed_dim = student.embed_dim
    else:
        # upstream DINO also supported xcit / torchvision archs here; this
        # variant only supports ViTs, so fail explicitly instead of with a
        # NameError further down.
        print(f"Unknown architecture: {args.arch}")
        sys.exit(1)

    # multi-crop wrapper handles forwarding crops of different resolutions
    student = utils.MultiCropWrapper(student, DINOHead(
        embed_dim,
        args.out_dim,
        use_bn=args.use_bn_in_head,
        norm_last_layer=args.norm_last_layer,
    ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()

    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
        # a DDP wrapper is needed for SyncBatchNorm to work on the teacher too
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # ViTs have no BN: teacher_without_ddp and teacher are the same module
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher starts as an exact copy of the student
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # the teacher is updated only by EMA, never by backprop
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number + 2,  # total crops = local crops + 2 global crops
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # lr is set by the schedule below
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)
    elif args.optimizer == "lars":
        optimizer = utils.LARS(params_groups)

    # gradient scaler, only when mixed-precision training is enabled
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ per-iteration schedules ============
    lr_schedule = utils.cosine_scheduler(
        # linear LR scaling with the global batch size (reference 256)
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,
        args.min_lr,
        args.epochs, len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, len(data_loader),
    )
    # teacher EMA momentum rises from args.momentum_teacher to 1 over training
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print("Loss, optimizer and schedulers ready.")

    # ============ optionally resume from a previous run ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        # make the distributed sampler reshuffle differently every epoch
        data_loader.sampler.set_epoch(epoch)

        # ============ one epoch of DINO training ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
            data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
            epoch, fp16_scaler, args)

        # ============ checkpointing ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,  # epoch to resume from
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and (epoch % args.saveckp_freq == 0 or epoch + 1 == args.epochs):
            # tag periodic snapshots with the warm-start name and dataset folder;
            # guard against basic_state_dict being None (os.path.split would crash)
            if args.basic_state_dict:
                initial_setting = os.path.split(args.basic_state_dict)[1].split('.')[0]
            else:
                initial_setting = 'scratch'
            dataset_using = os.path.split(args.data_path)[1]
            utils.save_on_master(save_dict, os.path.join(
                args.output_dir,
                f'dino_{initial_setting}_{dataset_using}_checkpoint{epoch:04}.pth'))

        # ============ logging (main process only) ============
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
|
|
|
|
def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch,
                    fp16_scaler, args):
    """Train the student for one epoch and EMA-update the teacher.

    Per batch: apply the scheduled lr/weight-decay, forward the 2 global crops
    through the teacher and all crops through the student, backprop the DINO
    loss (with optional fp16 scaling and gradient clipping), then update the
    teacher's parameters as an exponential moving average of the student's.
    Returns a dict of epoch-averaged metrics from the MetricLogger.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
    for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        # global iteration index: the schedules are defined per iteration
        it = len(data_loader) * epoch + it
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:
                # only the first param group is regularized
                # (presumably biases/norms are in group 1 — see utils.get_params_groups)
                param_group["weight_decay"] = wd_schedule[it]

        # move the list of crop tensors to GPU
        images = [im.cuda(non_blocking=True) for im in images]
        # forward passes; autocast active only when fp16 training is enabled
        with torch.cuda.amp.autocast(fp16_scaler is not None):
            teacher_output = teacher(images[:2])  # teacher sees only the 2 global crops
            student_output = student(images)      # student sees all crops
            loss = dino_loss(student_output, teacher_output, epoch)

        if not math.isfinite(loss.item()):
            # NOTE(review): `force=True` assumes the builtin print was patched
            # by utils (distributed-aware printing) — confirm in utils
            print("Loss is {}, stopping training".format(loss.item()), force=True)
            sys.exit(1)

        # student update: backprop, optional clipping, last-layer freeze early on
        optimizer.zero_grad()
        param_norms = None
        if fp16_scaler is None:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                # unscale first so clipping operates on true gradient values
                fp16_scaler.unscale_(optimizer)
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # teacher update: EMA of the student's weights with scheduled momentum
        with torch.no_grad():
            m = momentum_schedule[it]
            for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # logging
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])

    # gather epoch statistics from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
class DINOLoss(nn.Module):
    """DINO objective: cross-entropy between teacher and student softmax outputs.

    The teacher's output is centered (running EMA ``center`` buffer) and
    sharpened with an epoch-dependent temperature; the student's output is
    softened with a fixed temperature. The loss averages the cross-entropy
    over every (teacher global crop, student crop) pair except the pair where
    both saw the same view.
    """

    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        # running estimate of the teacher output mean, shape (1, out_dim)
        self.register_buffer("center", torch.zeros(1, out_dim))
        # teacher temperature: linear warmup, then constant at teacher_temp
        warmup_part = np.linspace(warmup_teacher_temp, teacher_temp,
                                  warmup_teacher_temp_epochs)
        constant_part = np.full(nepochs - warmup_teacher_temp_epochs, teacher_temp)
        self.teacher_temp_schedule = np.concatenate((warmup_part, constant_part))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        # split the student batch back into its per-crop chunks
        student_chunks = (student_output / self.student_temp).chunk(self.ncrops)

        # teacher: centered, temperature-sharpened softmax; no gradient flows back
        temp = self.teacher_temp_schedule[epoch]
        teacher_probs = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_chunks = teacher_probs.detach().chunk(2)  # the 2 global crops

        total_loss = 0
        n_loss_terms = 0
        for t_idx, t_probs in enumerate(teacher_chunks):
            for s_idx, s_logits in enumerate(student_chunks):
                if s_idx == t_idx:
                    # skip the case where teacher and student saw the same view
                    continue
                cross_entropy = torch.sum(-t_probs * F.log_softmax(s_logits, dim=-1), dim=-1)
                total_loss += cross_entropy.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        # mean of the teacher output over the global batch (all processes)
        batch_center = teacher_output.sum(dim=0, keepdim=True)
        dist.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * dist.get_world_size())

        # EMA update of the center buffer
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
|
|
|
|
|
class DataAugmentationDINO(object):
    """Multi-crop data augmentation for DINO.

    For each input image, produces 2 large "global" crops and
    ``local_crops_number`` small "local" crops. The two global pipelines are
    asymmetric (always-blur vs rare-blur + solarization); all crops share the
    flip/color-jitter/grayscale stage and ImageNet normalization.
    """

    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number,
                 global_size=224, local_size=96):
        # global_size / local_size generalize the previously hard-coded
        # 224 / 96 crop resolutions; the defaults preserve the old behavior.
        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        normalize = transforms.Compose([
            transforms.ToTensor(),
            # ImageNet channel mean / std
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # first global crop: blur is always applied
        self.global_transfo1 = transforms.Compose([
            transforms.RandomResizedCrop(global_size, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(1.0),
            normalize,
        ])
        # second global crop: blur is rare, solarization sometimes applied
        self.global_transfo2 = transforms.Compose([
            transforms.RandomResizedCrop(global_size, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(0.1),
            utils.Solarization(0.2),
            normalize,
        ])
        # small-resolution local crops
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(local_size, scale=local_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(p=0.5),
            normalize,
        ])

    def __call__(self, image):
        """Return [global_crop_1, global_crop_2, local_crop * local_crops_number]."""
        crops = []
        crops.append(self.global_transfo1(image))
        crops.append(self.global_transfo2(image))
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(image))
        return crops
|
|
|
|
|
if __name__ == '__main__':
    # Compose the CLI from the shared DINO argument definitions, make sure
    # the output directory exists, then launch training.
    cli = argparse.ArgumentParser('DINO', parents=[get_args_parser()])
    cli_args = cli.parse_args()
    Path(cli_args.output_dir).mkdir(parents=True, exist_ok=True)
    train_dino(cli_args)
|
|