|
import argparse |
|
import os |
|
import ruamel_yaml as yaml |
|
import numpy as np |
|
import time |
|
import datetime |
|
import json |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader |
|
import torch.backends.cudnn as cudnn |
|
from test_res_ft import test |
|
from tensorboardX import SummaryWriter |
|
import utils |
|
from models.resnet import ModelRes_ft |
|
from test_res_ft import test |
|
from dataset.dataset_siim_acr import SIIM_ACR_Dataset |
|
from scheduler import create_scheduler |
|
from optim import create_optimizer |
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
def train( |
|
model, |
|
data_loader, |
|
optimizer, |
|
criterion, |
|
epoch, |
|
warmup_steps, |
|
device, |
|
scheduler, |
|
args, |
|
config, |
|
writer, |
|
): |
|
model.train() |
|
metric_logger = utils.MetricLogger(delimiter=" ") |
|
metric_logger.add_meter( |
|
"lr", utils.SmoothedValue(window_size=50, fmt="{value:.6f}") |
|
) |
|
metric_logger.add_meter( |
|
"loss", utils.SmoothedValue(window_size=50, fmt="{value:.6f}") |
|
) |
|
metric_logger.update(loss=1.0) |
|
metric_logger.update(lr=scheduler._get_lr(epoch)[0]) |
|
|
|
header = "Train Epoch: [{}]".format(epoch) |
|
print_freq = 50 |
|
step_size = 100 |
|
warmup_iterations = warmup_steps * step_size |
|
scalar_step = epoch * len(data_loader) |
|
|
|
for i, sample in enumerate( |
|
metric_logger.log_every(data_loader, print_freq, header) |
|
): |
|
image = sample["image"] |
|
label = sample["label"].float().to(device) |
|
input_image = image.to(device, non_blocking=True) |
|
|
|
optimizer.zero_grad() |
|
pred_class = model(input_image) |
|
|
|
loss = criterion(pred_class, label) |
|
loss.backward() |
|
optimizer.step() |
|
writer.add_scalar("loss/loss", loss, scalar_step) |
|
scalar_step += 1 |
|
|
|
metric_logger.update(loss=loss.item()) |
|
if epoch == 0 and i % step_size == 0 and i <= warmup_iterations: |
|
scheduler.step(i // step_size) |
|
metric_logger.update(lr=scheduler._get_lr(epoch)[0]) |
|
|
|
|
|
metric_logger.synchronize_between_processes() |
|
print("Averaged stats:", metric_logger.global_avg()) |
|
return { |
|
k: "{:.6f}".format(meter.global_avg) |
|
for k, meter in metric_logger.meters.items() |
|
} |
|
|
|
|
|
def valid(model, data_loader, criterion, epoch, device, config, writer): |
|
model.eval() |
|
val_scalar_step = epoch * len(data_loader) |
|
val_losses = [] |
|
for i, sample in enumerate(data_loader): |
|
image = sample["image"] |
|
label = sample["label"].float().to(device) |
|
input_image = image.to(device, non_blocking=True) |
|
with torch.no_grad(): |
|
pred_class = model(input_image) |
|
val_loss = criterion(pred_class, label) |
|
val_losses.append(val_loss.item()) |
|
writer.add_scalar("val_loss/loss", val_loss, val_scalar_step) |
|
val_scalar_step += 1 |
|
avg_val_loss = np.array(val_losses).mean() |
|
return avg_val_loss |
|
|
|
|
|
def main(args, config): |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print("Total CUDA devices: ", torch.cuda.device_count()) |
|
torch.set_default_tensor_type("torch.FloatTensor") |
|
|
|
start_epoch = 0 |
|
max_epoch = config["schedular"]["epochs"] |
|
warmup_steps = config["schedular"]["warmup_epochs"] |
|
|
|
|
|
print("Creating dataset") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = ModelRes_ft(res_base_model="resnet50", out_size=1, use_base=args.use_base) |
|
if args.ddp: |
|
model = nn.DataParallel( |
|
model, device_ids=[i for i in range(torch.cuda.device_count())] |
|
) |
|
model = model.to(device) |
|
|
|
arg_opt = utils.AttrDict(config["optimizer"]) |
|
optimizer = create_optimizer(arg_opt, model) |
|
arg_sche = utils.AttrDict(config["schedular"]) |
|
lr_scheduler, _ = create_scheduler(arg_sche, optimizer) |
|
|
|
criterion = nn.BCEWithLogitsLoss() |
|
|
|
if args.checkpoint: |
|
checkpoint = torch.load(args.checkpoint, map_location="cpu") |
|
state_dict = checkpoint["model"] |
|
optimizer.load_state_dict(checkpoint["optimizer"]) |
|
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) |
|
start_epoch = checkpoint["epoch"] + 1 |
|
model.load_state_dict(state_dict) |
|
print("load checkpoint from %s" % args.checkpoint) |
|
elif args.pretrain_path: |
|
checkpoint = torch.load(args.pretrain_path, map_location="cpu") |
|
state_dict = checkpoint["model"] |
|
model_dict = model.state_dict() |
|
model_checkpoint = {k: v for k, v in state_dict.items() if k in model_dict} |
|
model_dict.update(model_checkpoint) |
|
model.load_state_dict(model_dict) |
|
print("load pretrain_path from %s" % args.pretrain_path) |
|
|
|
print("Start training") |
|
start_time = time.time() |
|
|
|
best_test_auc = 0.0 |
|
writer = SummaryWriter(os.path.join(args.output_dir, "log")) |
|
for epoch in range(start_epoch, max_epoch): |
|
if epoch > 0: |
|
lr_scheduler.step(epoch + warmup_steps) |
|
train_stats = train( |
|
model, |
|
train_dataloader, |
|
optimizer, |
|
criterion, |
|
epoch, |
|
warmup_steps, |
|
device, |
|
lr_scheduler, |
|
args, |
|
config, |
|
writer, |
|
) |
|
|
|
for k, v in train_stats.items(): |
|
train_loss_epoch = v |
|
|
|
writer.add_scalar("loss/train_loss_epoch", float(train_loss_epoch), epoch) |
|
writer.add_scalar("loss/leaning_rate", lr_scheduler._get_lr(epoch)[0], epoch) |
|
|
|
val_loss = valid( |
|
model, val_dataloader, criterion, epoch, device, config, writer |
|
) |
|
writer.add_scalar("loss/val_loss_epoch", val_loss, epoch) |
|
|
|
if utils.is_main_process(): |
|
log_stats = { |
|
**{f"train_{k}": v for k, v in train_stats.items()}, |
|
"epoch": epoch, |
|
"val_loss": val_loss.item(), |
|
} |
|
save_obj = { |
|
"model": model.state_dict(), |
|
"optimizer": optimizer.state_dict(), |
|
"lr_scheduler": lr_scheduler.state_dict(), |
|
"config": config, |
|
"epoch": epoch, |
|
} |
|
torch.save(save_obj, os.path.join(args.output_dir, "checkpoint_state.pth")) |
|
|
|
with open(os.path.join(args.output_dir, "log.txt"), "a") as f: |
|
f.write(json.dumps(log_stats) + "\n") |
|
|
|
test_auc = test(args, config) |
|
print(best_test_auc, test_auc) |
|
if test_auc > best_test_auc: |
|
save_obj = { |
|
"model": model.state_dict(), |
|
"optimizer": optimizer.state_dict(), |
|
"lr_scheduler": lr_scheduler.state_dict(), |
|
"config": config, |
|
"epoch": epoch, |
|
} |
|
torch.save(save_obj, os.path.join(args.output_dir, "best_test.pth")) |
|
best_test_auc = test_auc |
|
args.model_path = os.path.join(args.output_dir, "checkpoint_state.pth") |
|
|
|
with open(os.path.join(args.output_dir, "log.txt"), "a") as f: |
|
f.write( |
|
"The average AUROC is {AUROC_avg:.4f}".format(AUROC_avg=test_auc) |
|
+ "\n" |
|
) |
|
|
|
if epoch % 20 == 1 and epoch > 1: |
|
save_obj = { |
|
"model": model.state_dict(), |
|
"optimizer": optimizer.state_dict(), |
|
"lr_scheduler": lr_scheduler.state_dict(), |
|
"config": config, |
|
"epoch": epoch, |
|
} |
|
torch.save( |
|
save_obj, |
|
os.path.join(args.output_dir, "checkpoint_" + str(epoch) + ".pth"), |
|
) |
|
|
|
total_time = time.time() - start_time |
|
total_time_str = str(datetime.timedelta(seconds=int(total_time))) |
|
print("Training time {}".format(total_time_str)) |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--config", |
|
default="Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml", |
|
) |
|
parser.add_argument("--checkpoint", default="") |
|
parser.add_argument("--model_path", default="") |
|
parser.add_argument("--pretrain_path", default="MeDSLIP_resnet50.pth") |
|
parser.add_argument( |
|
"--output_dir", default="Sample_Finetuning_SIIMACR/I1_classification/runs/" |
|
) |
|
parser.add_argument("--device", default="cuda") |
|
parser.add_argument("--gpu", type=str, default="0", help="gpu") |
|
parser.add_argument("--use_base", type=bool, default=True) |
|
parser.add_argument("--ddp", action="store_true", help="use ddp") |
|
args = parser.parse_args() |
|
|
|
config = yaml.load(open(args.config, "r"), Loader=yaml.Loader) |
|
args.output_dir = os.path.join(args.output_dir, str(config["percentage"])) |
|
from datetime import datetime |
|
|
|
args.output_dir = os.path.join( |
|
args.output_dir, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
|
) |
|
args.model_path = os.path.join(args.output_dir, "checkpoint_state.pth") |
|
|
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True) |
|
|
|
yaml.dump(config, open(os.path.join(args.output_dir, "config.yaml"), "w")) |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu |
|
torch.cuda.current_device() |
|
torch.cuda._initialized = True |
|
|
|
main(args, config) |
|
|