import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torchvision import models

import utils
from utils import DiceBCELoss
from models.resunet import ModelResUNet_ft
from test_res_ft import test
from dataset.dataset_siim_acr import SIIM_ACR_Dataset
from scheduler import create_scheduler
from optim import create_optimizer

import warnings

warnings.filterwarnings("ignore")


def train(
    model,
    data_loader,
    optimizer,
    criterion,
    epoch,
    warmup_steps,
    device,
    scheduler,
    args,
    config,
    writer,
):
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        "lr", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
    )
    metric_logger.add_meter(
        "loss", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
    )
    metric_logger.update(loss=1.0)
    metric_logger.update(lr=scheduler._get_lr(epoch)[0])

    header = "Train Epoch: [{}]".format(epoch)
    print_freq = 50
    step_size = 100
    warmup_iterations = warmup_steps * step_size
    scalar_step = epoch * len(data_loader)

    for i, sample in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        image = sample["image"]
        mask = sample["seg"].float().to(device)  # batch_size, num_class
        input_image = image.to(device, non_blocking=True)

        optimizer.zero_grad()
        pred_map = model(input_image)  # batch_size, num_class
        loss = criterion(pred_map, mask)
        loss.backward()
        optimizer.step()

        writer.add_scalar("loss/loss", loss, scalar_step)
        scalar_step += 1
        metric_logger.update(loss=loss.item())

        if epoch == 0 and i % step_size == 0 and i <= warmup_iterations:
            scheduler.step(i // step_size)
        metric_logger.update(lr=scheduler._get_lr(epoch)[0])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {
        k: "{:.6f}".format(meter.global_avg)
        for k, meter in metric_logger.meters.items()
    }


def valid(model, data_loader, criterion, epoch, device, config, writer):
    model.eval()
    val_scalar_step = epoch * len(data_loader)
    val_losses = []
    for i, sample in enumerate(data_loader):
        image = sample["image"]
        mask = sample["seg"].float().to(device)  # batch_size, num_class
        input_image = image.to(device, non_blocking=True)

        with torch.no_grad():
            pred_map = model(input_image)
            val_loss = criterion(pred_map, mask)
        val_losses.append(val_loss.item())
        writer.add_scalar("val_loss/loss", val_loss, val_scalar_step)
        val_scalar_step += 1
    avg_val_loss = np.array(val_losses).mean()
    return avg_val_loss


def main(args, config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    start_epoch = 0
    max_epoch = config["schedular"]["epochs"]
    warmup_steps = config["schedular"]["warmup_epochs"]

    #### Dataset ####
    print("Creating dataset")
    train_dataset = SIIM_ACR_Dataset(
        config["train_file"], percentage=config["percentage"]
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=None,
        shuffle=True,
        collate_fn=None,
        drop_last=True,
    )

    val_dataset = SIIM_ACR_Dataset(config["valid_file"], is_train=False)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=None,
        shuffle=True,
        collate_fn=None,
        drop_last=True,
    )

    model = ModelResUNet_ft(
        res_base_model="resnet50",
        out_size=1,
        imagenet_pretrain=models.ResNet50_Weights.DEFAULT,
    )
    if args.ddp:
        # Note: this wraps the model with nn.DataParallel (single-process multi-GPU),
        # not DistributedDataParallel.
        model = nn.DataParallel(
            model, device_ids=[i for i in range(torch.cuda.device_count())]
        )
    model = model.to(device)

    arg_opt = utils.AttrDict(config["optimizer"])
    optimizer = create_optimizer(arg_opt, model)
    arg_sche = utils.AttrDict(config["schedular"])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
    criterion = DiceBCELoss()

    if args.checkpoint:
        # Resume training from a full checkpoint (model, optimizer, scheduler, epoch).
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        state_dict = checkpoint["model"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(state_dict)
        print("load checkpoint from %s" % args.checkpoint)
    elif args.pretrain_path:
        # Initialize from pretrained weights, keeping only keys present in the current model.
        checkpoint = torch.load(args.pretrain_path, map_location="cpu")
        state_dict = checkpoint["model"]
        model_dict = model.state_dict()
        model_checkpoint = {k: v for k, v in state_dict.items() if k in model_dict}
        model_dict.update(model_checkpoint)
        model.load_state_dict(model_dict)
        print("load pretrain_path from %s" % args.pretrain_path)

    print("Start training")
    start_time = time.time()
    best_test_IoU_score = 0
    best_dice_score = 0
    writer = SummaryWriter(os.path.join(args.output_dir, "log"))

    for epoch in range(start_epoch, max_epoch):
        if epoch > 0:
            lr_scheduler.step(epoch + warmup_steps)
        train_stats = train(
            model,
            train_dataloader,
            optimizer,
            criterion,
            epoch,
            warmup_steps,
            device,
            lr_scheduler,
            args,
            config,
            writer,
        )

        train_loss_epoch = train_stats["loss"]
        writer.add_scalar("loss/train_loss_epoch", float(train_loss_epoch), epoch)
        writer.add_scalar("loss/learning_rate", lr_scheduler._get_lr(epoch)[0], epoch)

        val_loss = valid(
            model, val_dataloader, criterion, epoch, device, config, writer
        )
        writer.add_scalar("loss/val_loss_epoch", val_loss, epoch)

        if utils.is_main_process():
            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
                "val_loss": val_loss.item(),
            }
            save_obj = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "config": config,
                "epoch": epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, "checkpoint_state.pth"))
            args.model_path = os.path.join(args.output_dir, "checkpoint_state.pth")
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

        # Evaluate the latest checkpoint and keep the model with the best Dice score.
        dice_score, IoU_score = test(args, config)
        print(IoU_score, best_test_IoU_score, dice_score, best_dice_score)
        if dice_score > best_dice_score:
            save_obj = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "config": config,
                "epoch": epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, "best_valid.pth"))
            best_dice_score = dice_score
            best_test_IoU_score = IoU_score
        with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
            f.write("The dice score is {dice:.4f}".format(dice=dice_score) + "\n")
            f.write("The iou score is {iou:.4f}".format(iou=IoU_score) + "\n")

        # Periodic checkpoint every 20 epochs.
        if epoch % 20 == 1 and epoch > 1:
            save_obj = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "config": config,
                "epoch": epoch,
            }
            torch.save(
                save_obj,
                os.path.join(args.output_dir, "checkpoint_" + str(epoch) + ".pth"),
            )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_Finetuning_SIIMACR/I2_segmentation/configs/Res_train.yaml",
    )
    parser.add_argument("--checkpoint", default="")
    parser.add_argument("--model_path", default="")
    parser.add_argument("--pretrain_path", default="MeDSLIP_resnet50.pth")
    parser.add_argument(
        "--output_dir", default="Sample_Finetuning_SIIMACR/I2_segmentation/runs"
    )
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    parser.add_argument("--ddp", action="store_true", help="whether to use ddp")
    args = parser.parse_args()

    config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)

    # Write outputs to a timestamped, percentage-specific subdirectory.
    # (Use the module-level `import datetime`; re-importing the `datetime` class here
    # would shadow the module and break `datetime.timedelta` in main().)
    args.output_dir = os.path.join(
        args.output_dir,
        str(config["percentage"]),
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    )
    args.model_path = (
        args.model_path
        if args.model_path
        else os.path.join(args.output_dir, "best_valid.pth")
    )
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    yaml.dump(config, open(os.path.join(args.output_dir, "config.yaml"), "w"))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.cuda.current_device()
    torch.cuda._initialized = True

    main(args, config)
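
# Example invocation (illustrative sketch only: the script filename below is an
# assumption, while the paths mirror the argparse defaults defined above):
#
#   python train_res_ft_siim.py \
#       --config Sample_Finetuning_SIIMACR/I2_segmentation/configs/Res_train.yaml \
#       --pretrain_path MeDSLIP_resnet50.pth \
#       --output_dir Sample_Finetuning_SIIMACR/I2_segmentation/runs \
#       --gpu 0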