"""Evaluate a pretrained FLIP checkpoint on the FaceCaption test split.

Builds the FaceCaption datasets, loads a FLIP model from ``--pretrained``,
computes image-to-text and text-to-image scores with ``evaluation`` and prints
the metrics returned by ``itm_eval`` on the main process.
"""
import argparse
import random
from pathlib import Path  # noqa: F401  (not used by this script)

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn  # noqa: F401  (not used by this script)
import torch.nn.functional as F  # noqa: F401  (not used by this script)
from torch.cuda.amp import GradScaler, autocast  # noqa: F401  (not used by this script)

from models.FFLIP import FLIP
from models import utils
from eval.pretrain_eval import evaluation, itm_eval
from data import create_dataset, create_sampler, create_loader


def main(args):
    """Run FLIP evaluation on the FaceCaption test set and print the metrics.

    Args:
        args: Parsed command-line namespace. This function reads ``device``,
            ``seed``, ``pretrained`` and ``distributed``; ``gpu`` is read when
            distributed (presumably populated by ``utils.init_distributed_mode``
            -- TODO confirm). ``batch_size`` and ``num_workers`` are optional and
            default to 80 and 8. ``args`` is also passed to ``create_dataset``,
            which may read further fields (not visible here).
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # Offset the seed by rank so each process is reproducible yet distinct.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### The reference code for creating the dataset ####
    print("Creating dataset")
    train_dataset, test_dataset = create_dataset(args, 'facecaption')

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        # A sampler is built only for the training set; the test loader gets None.
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None]
    else:
        samplers = [None, None]

    # Fall back to the historical defaults when called with a namespace that
    # predates these options.
    batch_size = getattr(args, 'batch_size', 80)
    num_workers = getattr(args, 'num_workers', 8)
    # The train loader is not used below; it is still built as in the
    # reference code.
    train_loader, test_loader = create_loader(
        [train_dataset, test_dataset],
        samplers,
        batch_size=[batch_size, batch_size],
        num_workers=[num_workers, num_workers],
        is_trains=[True, False],
        collate_fns=[None, None],
    )

    #### Model ####
    print("Creating model")
    model = FLIP(pretrained=args.pretrained, vit='base', queue_size=61440)
    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    print("Start evaluation")
    score_test_i2t, score_test_t2i = evaluation(args, model_without_ddp, test_loader, device)

    # Only the main process computes and prints the metrics.
    if utils.is_main_process():
        test_result = itm_eval(
            score_test_i2t,
            score_test_t2i,
            test_loader.dataset.txt2img,
            test_loader.dataset.img2txt,
        )
        print(test_result)

    if args.distributed:
        dist.barrier()


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used for this: ``bool('False')`` is ``True``, so
    any non-empty string would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', default='./outputs')
    parser.add_argument('--img_root', default='./FaceCaption/images')
    parser.add_argument('--ann_root', default='./FaceCaption/caption')
    parser.add_argument('--pretrained', default='./FaceCaption-15M-base.pth')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # NOTE(review): utils.init_distributed_mode(args) may overwrite this value
    # from the launch environment -- confirm.
    parser.add_argument('--distributed', default=False, type=_str2bool,
                        help='whether to use distributed mode to training')
    parser.add_argument('--batch_size', default=80, type=int,
                        help='per-process batch size for the train and test loaders')
    parser.add_argument('--num_workers', default=8, type=int,
                        help='number of data-loading workers per loader')
    args = parser.parse_args()

    main(args)