import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter

import utils
from models.resnet import ModelRes_ft
from test_res_ft import test
from dataset.dataset_siim_acr import SIIM_ACR_Dataset
from scheduler import create_scheduler
from optim import create_optimizer

import warnings

warnings.filterwarnings("ignore")


def train(
    model,
    data_loader,
    optimizer,
    criterion,
    epoch,
    warmup_steps,
    device,
    scheduler,
    args,
    config,
    writer,
):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        "lr", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
    )
    metric_logger.add_meter(
        "loss", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
    )
    # Seed the meters so the first log line has values to show.
    metric_logger.update(loss=1.0)
    metric_logger.update(lr=scheduler._get_lr(epoch)[0])

    header = "Train Epoch: [{}]".format(epoch)
    print_freq = 50
    step_size = 100
    warmup_iterations = warmup_steps * step_size
    scalar_step = epoch * len(data_loader)

    for i, sample in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        image = sample["image"]
        label = sample["label"].float().to(device)  # batch_size, num_class
        input_image = image.to(device, non_blocking=True)

        optimizer.zero_grad()
        pred_class = model(input_image)  # batch_size, num_class
        loss = criterion(pred_class, label)
        loss.backward()
        optimizer.step()

        writer.add_scalar("loss/loss", loss, scalar_step)
        scalar_step += 1
        metric_logger.update(loss=loss.item())

        # Warmup: only the first epoch steps the scheduler inside the loop.
        if epoch == 0 and i % step_size == 0 and i <= warmup_iterations:
            scheduler.step(i // step_size)
        metric_logger.update(lr=scheduler._get_lr(epoch)[0])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {
        k: "{:.6f}".format(meter.global_avg)
        for k, meter in metric_logger.meters.items()
    }


def valid(model, data_loader, criterion, epoch, device, config, writer):
    model.eval()
    val_scalar_step = epoch * len(data_loader)
    val_losses = []
    for i, sample in enumerate(data_loader):
        image = sample["image"]
        label = sample["label"].float().to(device)
        input_image = image.to(device, non_blocking=True)

        with torch.no_grad():
            pred_class = model(input_image)
            val_loss = criterion(pred_class, label)
            val_losses.append(val_loss.item())
            writer.add_scalar("val_loss/loss", val_loss, val_scalar_step)
            val_scalar_step += 1
    avg_val_loss = np.array(val_losses).mean()
    return avg_val_loss


def main(args, config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    start_epoch = 0
    max_epoch = config["schedular"]["epochs"]
    warmup_steps = config["schedular"]["warmup_epochs"]

    #### Dataset ####
    print("Creating dataset")
    train_dataset = SIIM_ACR_Dataset(
        config["train_file"], percentage=config["percentage"]
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=None,
        shuffle=True,
        collate_fn=None,
        drop_last=True,
    )

    val_dataset = SIIM_ACR_Dataset(config["valid_file"], is_train=False)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=None,
        shuffle=False,
        collate_fn=None,
        drop_last=False,
    )
    print(len(train_dataset), len(val_dataset))
    model = ModelRes_ft(res_base_model="resnet50", out_size=1, use_base=args.use_base)
    if args.ddp:
        # Note: this wraps the model in nn.DataParallel (single-process
        # multi-GPU), not DistributedDataParallel, despite the flag name.
        model = nn.DataParallel(
            model, device_ids=[i for i in range(torch.cuda.device_count())]
        )
    model = model.to(device)

    arg_opt = utils.AttrDict(config["optimizer"])
    optimizer = create_optimizer(arg_opt, model)
    arg_sche = utils.AttrDict(config["schedular"])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
    criterion = nn.BCEWithLogitsLoss()

    if args.checkpoint:
        # Resume from a full training checkpoint (model + optimizer + scheduler).
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        state_dict = checkpoint["model"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(state_dict)
        print("load checkpoint from %s" % args.checkpoint)
    elif args.pretrain_path:
        # Initialize from pretrained weights, keeping only keys that match
        # the current model's state dict.
        checkpoint = torch.load(args.pretrain_path, map_location="cpu")
        state_dict = checkpoint["model"]
        model_dict = model.state_dict()
        model_checkpoint = {k: v for k, v in state_dict.items() if k in model_dict}
        model_dict.update(model_checkpoint)
        model.load_state_dict(model_dict)
        print("load pretrain_path from %s" % args.pretrain_path)

    print("Start training")
    start_time = time.time()
    best_test_auc = 0.0

    writer = SummaryWriter(os.path.join(args.output_dir, "log"))
    for epoch in range(start_epoch, max_epoch):
        if epoch > 0:
            lr_scheduler.step(epoch + warmup_steps)
        train_stats = train(
            model,
            train_dataloader,
            optimizer,
            criterion,
            epoch,
            warmup_steps,
            device,
            lr_scheduler,
            args,
            config,
            writer,
        )

        train_loss_epoch = train_stats["loss"]  # per-epoch average training loss
        writer.add_scalar("loss/train_loss_epoch", float(train_loss_epoch), epoch)
        writer.add_scalar("loss/learning_rate", lr_scheduler._get_lr(epoch)[0], epoch)

        val_loss = valid(
            model, val_dataloader, criterion, epoch, device, config, writer
        )
        writer.add_scalar("loss/val_loss_epoch", val_loss, epoch)

        if utils.is_main_process():
            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
                "val_loss": val_loss.item(),
            }
            save_obj = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "config": config,
                "epoch": epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, "checkpoint_state.pth"))

            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # Evaluate on the test split and keep the best-AUC checkpoint.
            test_auc = test(args, config)
            print(best_test_auc, test_auc)
            if test_auc > best_test_auc:
                save_obj = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                }
                torch.save(save_obj, os.path.join(args.output_dir, "best_test.pth"))
                best_test_auc = test_auc
                args.model_path = os.path.join(args.output_dir, "checkpoint_state.pth")
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(
                    "The average AUROC is {AUROC_avg:.4f}".format(AUROC_avg=test_auc)
                    + "\n"
                )

        # Periodic snapshot every 20 epochs (at epochs 21, 41, ...).
        if epoch % 20 == 1 and epoch > 1:
            save_obj = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "config": config,
                "epoch": epoch,
            }
            torch.save(
                save_obj,
                os.path.join(args.output_dir, "checkpoint_" + str(epoch) + ".pth"),
            )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml",
    )
parser.add_argument("--checkpoint", default="") parser.add_argument("--model_path", default="") parser.add_argument("--pretrain_path", default="MeDSLIP_resnet50.pth") parser.add_argument( "--output_dir", default="Sample_Finetuning_SIIMACR/I1_classification/runs/" ) parser.add_argument("--device", default="cuda") parser.add_argument("--gpu", type=str, default="0", help="gpu") parser.add_argument("--use_base", type=bool, default=True) parser.add_argument("--ddp", action="store_true", help="use ddp") args = parser.parse_args() config = yaml.load(open(args.config, "r"), Loader=yaml.Loader) args.output_dir = os.path.join(args.output_dir, str(config["percentage"])) from datetime import datetime args.output_dir = os.path.join( args.output_dir, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) args.model_path = os.path.join(args.output_dir, "checkpoint_state.pth") Path(args.output_dir).mkdir(parents=True, exist_ok=True) yaml.dump(config, open(os.path.join(args.output_dir, "config.yaml"), "w")) os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu torch.cuda.current_device() torch.cuda._initialized = True main(args, config)