""" Testing script of PuzzleTuning Visualization Script ver: Feb 11th 14:00 Paper: https://arxiv.org/abs/2311.06712 Code: https://github.com/sagizty/PuzzleTuning Ref: MAE https://github.com/facebookresearch/mae Step 1: PreTraining on the ImagetNet-1k style dataset (others) Step 2: Domain Prompt Tuning (PuzzleTuning) on Pathological Images (in ImageFolder) Step 3: FineTuning on the Downstream Tasks This is the independent testing for step 2 update: Use "--seg_decoder" parameter to introduce segmentation networks swin_unet for Swin-Unet """ import argparse import datetime import numpy as np import os import time from pathlib import Path import torch import torch.backends.cudnn as cudnn from tensorboardX import SummaryWriter import torchvision.transforms as transforms import torchvision.datasets as datasets import timm # assert timm.__version__ == "0.3.2" # version check from SSL_structures import models_mae, SAE from utils.visual_usage import patchify, unpatchify, Draw_tri_fig from torchvision.transforms import ToPILImage def Puzzle_test(model, data_loader_test, test_dataset_size, mask_ratio, fix_position_ratio, fix_patch_size, check_minibatch=100, enable_visualize_check=True, combined_pred_illustration=False, check_samples=1, device=None, output_dir=None, writer=None, args=None): # start testing print(f"Start testing for {args.model_idx} \n with checkpoint: {args.checkpoint_path}") start_time = time.time() index = 0 model_time = time.time() # criterias, initially empty running_loss = 0.0 log_running_loss = 0.0 model.eval() # Iterate over data. for inputs, labels in data_loader_test: # use different dataloder in different phase inputs = inputs.to(device, non_blocking=True) labels = labels.to(device, non_blocking=True) # for tracking fixme if args.model[0:3] == 'sae': loss, pred, imgs_puzzled_patches = model(inputs, fix_position_ratio=fix_position_ratio, puzzle_patch_size=fix_patch_size, combined_pred_illustration=combined_pred_illustration) # SAE else: # args.model[0:3] == 'mae' loss, pred, mask_patch_indicators = model(inputs, mask_ratio=mask_ratio) # MAE loss_value = float(loss.cpu().detach().numpy()) if args.gpu == 1 else sum(loss.cpu().detach().numpy()) # log criterias: update log_running_loss += loss_value running_loss += loss_value * inputs.size(0) # attach the records to the tensorboard backend if writer is not None: # ...log the running loss writer.add_scalar('Test minibatch loss', float(loss_value), index) # at the checking time now if index % check_minibatch == check_minibatch - 1: model_time = time.time() - model_time check_index = index // check_minibatch + 1 print('Test index ' + str(check_index) + ' of ' + str(check_minibatch) + ' minibatch with batch_size of ' + str(inputs.size(0)) + ' time used:', model_time) print('minibatch AVG loss:', float(log_running_loss) / check_minibatch) model_time = time.time() log_running_loss = 0.0 # paint pic if enable_visualize_check: if args.model[0:3] == 'sae': imgs_puzzled_batch = unpatchify(imgs_puzzled_patches, patch_size=16) # Reconstructed img recons_img_batch = unpatchify(pred, patch_size=16) else: # MAE sample_img_patches = patchify(inputs, patch_size=16) # on GPU masked_img_patches = sample_img_patches * mask_patch_indicators.unsqueeze(-1).expand(-1, -1, sample_img_patches.shape[-1]) masked_img_batch = unpatchify(masked_img_patches, patch_size=16) if combined_pred_illustration: anti_mask_patch_indicators = 1 - mask_patch_indicators pred_img_patches = pred * anti_mask_patch_indicators.unsqueeze(-1).\ expand(-1, -1, sample_img_patches.shape[-1]) # Reconstructed img recons_img_batch = unpatchify(masked_img_patches + pred_img_patches, patch_size=16) else: # Reconstructed img recons_img_batch = unpatchify(pred, patch_size=16) for sampleIDX in range(check_samples): # Ori img sample_img = inputs.cpu()[sampleIDX] sample_img = ToPILImage()(sample_img) sample_img.save(os.path.join(output_dir, 'Test_sample_idx_' + str(check_index) + '_sampleIDX_' + str(sampleIDX) + '.jpg')) recons_img = recons_img_batch.cpu()[sampleIDX] recons_img = ToPILImage()(recons_img) recons_img.save(os.path.join(output_dir, 'Test_recons_idx_' + str(check_index) + '_sampleIDX_' + str(sampleIDX) + '.jpg')) # mask_img or puzzled_img if args.model[0:3] == 'sae': puzzled_img = imgs_puzzled_batch.cpu()[sampleIDX] puzzled_img = ToPILImage()(puzzled_img) puzzled_img.save(os.path.join(output_dir, 'Test_puzzled_idx_' + str(check_index) + '.jpg')) picpath = os.path.join(output_dir, 'Test_minibatchIDX_' + str(check_index) + '_sampleIDX_' + str(sampleIDX) + '.jpg') Draw_tri_fig(sample_img, puzzled_img, recons_img, picpath) else: # MAE masked_img = masked_img_batch.cpu()[sampleIDX] masked_img = ToPILImage()(masked_img) masked_img.save(os.path.join(output_dir, 'Test_masked_idx_' + str(check_index) + '.jpg')) picpath = os.path.join(output_dir, 'Test_minibatchIDX_' + str(check_index) + '_sampleIDX_' + str(sampleIDX) + '.jpg') Draw_tri_fig(sample_img, masked_img, recons_img, picpath) index += 1 # time stamp total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) # log criterias: print epoch_loss = running_loss / test_dataset_size print('\nTest_dataset_size: {} \nAvg Loss: {:.4f}'.format(test_dataset_size, epoch_loss)) print('Testing time {}'.format(total_time_str)) def main(args): # choose decoder version args.model = args.model + '_decoder' if args.seg_decoder is not None else args.model # note decoder args.model_idx = args.model_idx + args.model + '_' + args.seg_decoder if args.seg_decoder is not None \ else args.model_idx + args.model # note PromptTuning args.model_idx = args.model_idx + '_Prompt_' + args.PromptTuning + '_tokennum_' + str(args.Prompt_Token_num) \ if args.PromptTuning is not None else args.model_idx # Specify the Test settings if args.fix_position_ratio is not None and args.fix_patch_size is not None and args.mask_ratio is None: args.model_idx = 'Testing_' + args.model_idx + '_b_' + str(args.batch_size) \ + '_hint_ratio_' + str(args.fix_position_ratio) + '_patch_size_' + str(args.fix_patch_size) elif args.mask_ratio is not None and args.fix_position_ratio is None and args.fix_patch_size is None: args.model_idx = 'Testing_' + args.model_idx + '_b_' + str(args.batch_size) \ + '_mask_ratio_' + str(args.mask_ratio) else: print('not a correct test setting, should correctly specify fix_position_ratio/fix_patch_size/mask_ratio') print('\n\n' + args.model_idx + '\n\n') # setting k for: only card idx k is sighted for this code if args.gpu_idx != -1: # fixme: notice for test, we are going to use single gpu only print("Use", torch.cuda.device_count(), "GPUs of idx:", args.gpu_idx) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_idx) else: print("Use", torch.cuda.device_count(), "GPUs") args.gpu = torch.cuda.device_count() print('job AImageFolderDir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print("{}".format(args).replace(', ', ',\n')) device = torch.device(args.device) # cuda # fix the seed for reproducibility torch.manual_seed(args.seed) np.random.seed(args.seed) cudnn.benchmark = True # simple augmentation transform_test = transforms.Compose([ # transforms.RandomResizedCrop(args.input_size, scale=(0.8, 1.0), interpolation=3, ratio=(1. / 1., 1. / 1.)), # 3 is bicubic transforms.Resize(args.input_size), transforms.ToTensor(), # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) test_dataroot = os.path.join(args.data_path) # , 'test' dataset_test = datasets.ImageFolder(test_dataroot, transform=transform_test) test_dataset_size = len(dataset_test) class_names = [d.name for d in os.scandir(test_dataroot) if d.is_dir()] class_names.sort() print('dataset_test', dataset_test) # Test data # skip minibatch, none to draw 80 figs check_minibatch = args.check_minibatch if args.check_minibatch is not None else \ test_dataset_size // (80 * args.batch_size) check_minibatch = max(1, check_minibatch) # outputs if args.log_dir is not None: args.log_dir = os.path.join(args.log_dir, args.model_idx) os.makedirs(args.log_dir, exist_ok=True) log_writer = SummaryWriter(log_dir=args.log_dir) # Tensorboard else: log_writer = None # output_dir if args.output_dir is not None: args.output_dir = os.path.join(args.output_dir, args.model_idx) os.makedirs(args.output_dir, exist_ok=True) print('Testing output files will be at', args.output_dir) data_loader_test = torch.utils.data.DataLoader(dataset_test, shuffle=args.shuffle_dataloader, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, # 建议False drop_last=True) # define the model if args.model[0:3] == 'mae': model = models_mae.__dict__[args.model](img_size=args.input_size, norm_pix_loss=args.norm_pix_loss, prompt_mode=args.PromptTuning, Prompt_Token_num=args.Prompt_Token_num, dec_idx=args.seg_decoder) elif args.model[0:3] == 'sae': model = SAE.__dict__[args.model](img_size=args.input_size, group_shuffle_size=args.group_shuffle_size, norm_pix_loss=args.norm_pix_loss, prompt_mode=args.PromptTuning, Prompt_Token_num=args.Prompt_Token_num, dec_idx=args.seg_decoder) else: print('This MIM test script only support SAE or MAE') return -1 # take model out of checkpoint and load_model state_dict = torch.load(args.checkpoint_path)['model'] model.load_state_dict(state_dict, False) model.to(device) # loss backward and optimizer operations and no longer needed in testing # loss_scaler = NativeScaler() Puzzle_test(model, data_loader_test, test_dataset_size, args.mask_ratio, args.fix_position_ratio, args.fix_patch_size, check_minibatch, args.enable_visualize_check, args.combined_pred_illustration, args.check_samples, device=device, output_dir=args.output_dir, writer=log_writer, args=args) # os.system("shutdown") # AUTO-DL server shutdown currently moved to .sh script for nohup task queue. def get_args_parser(): parser = argparse.ArgumentParser('MIM visualization for PuzzleTuning', add_help=False) # Model Name or index parser.add_argument('--model_idx', default='PuzzleTuning_', type=str, help='Model Name or index') # testing batch size parser.add_argument('--batch_size', default=16, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') # Model parameters sae_vit_base_patch16 or mae_vit_base_patch16 parser.add_argument('--model', default='sae_vit_base_patch16', type=str, metavar='MODEL', help='Name of model to train') # ori mae_vit_large_patch16 parser.add_argument('--seg_decoder', default=None, type=str, metavar='segmentation decoder', help='Name of segmentation decoder') parser.add_argument('--input_size', default=224, type=int, # 原224 help='images input size') parser.add_argument('--num_classes', default=3, type=int, # decoder seg class set to channel help='the number of classes for segmentation') # MAE mask_ratio parser.add_argument('--mask_ratio', default=None, type=float, help='Masking ratio (percentage of removed patches).') # Hint tokens parser.add_argument('--fix_position_ratio', default=None, type=float, help='basic fix_position_ratio (percentage of position token patches).') parser.add_argument('--fix_patch_size', default=None, type=int, # 原224 help='images input size') parser.add_argument('--group_shuffle_size', default=-1, type=int, help='group_shuffle_size of group shuffling,' 'default -1 for the whole batch as a group') # shuffle_dataloader parser.add_argument('--shuffle_dataloader', action='store_true', help='shuffle Test dataset') # Tuning setting # PromptTuning parser.add_argument('--PromptTuning', default=None, type=str, help='Deep/Shallow to use Prompt Tuning model instead of Finetuning model, by default None') # Prompt_Token_num parser.add_argument('--Prompt_Token_num', default=20, type=int, help='Prompt_Token_num') # loss settings parser.add_argument('--norm_pix_loss', action='store_true', help='Use (per-patch) normalized pixels as targets for computing loss') parser.set_defaults(norm_pix_loss=False) # PATH settings # Dataset parameters /root/autodl-tmp/MARS_ALL /root/autodl-tmp/imagenet /root/autodl-tmp/datasets/All parser.add_argument('--data_path', default='/root/autodl-tmp/datasets/PuzzleTuning_demoset', type=str, help='dataset path') parser.add_argument('--output_dir', default='/root/autodl-tmp/runs', help='path where to save test log, empty for no saving') parser.add_argument('--log_dir', default='/root/tf-logs', help='path where to test tensorboard log') # Enviroment parameters parser.add_argument('--gpu_idx', default=0, type=int, help='use a single GPU with its index, -1 to use multiple GPU') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=42, type=int) # ori 0 不过应该无所谓? # checkpoint_state_dict_path parser.add_argument('--checkpoint_path', default='/root/autodl-tmp/runs/PuzzleTuning_SAE_vit_base_patch16_Prompt_Deep_tokennum_20_tr_timm_CPIAm/PuzzleTuning_sae_vit_base_patch16_Prompt_Deep_tokennum_20_checkpoint-199.pth', type=str, help='load state_dict for testing') # check settings parser.add_argument('--combined_pred_illustration', action='store_true', help='check combined_pred_illustration pics') parser.add_argument('--enable_visualize_check', action='store_true', help='check and save pics') parser.add_argument('--check_minibatch', default=None, type=int, help='check batch_size') parser.add_argument('--check_samples', default=1, type=int, help='check how many images in a checking batch') # dataloader setting parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') parser.set_defaults(pin_mem=True) return parser if __name__ == '__main__': args = get_args_parser() args = args.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) main(args)