|
import argparse |
|
import os |
|
import ruamel_yaml as yaml |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
|
|
from models.resunet import ModelResUNet_ft |
|
|
|
from dataset.dataset_siim_acr import SIIM_ACR_Dataset |
|
from metric import mIoU, dice |
|
|
|
|
|
def test(args, config):
    """Evaluate a segmentation checkpoint on the test split and report metrics.

    Builds the test dataset/dataloader, restores the model weights from
    ``args.model_path``, runs inference over every test batch and computes
    the Dice and IoU scores over the concatenated predictions.

    Args:
        args: Parsed command-line namespace; only ``args.model_path`` is
            read here.
        config: Mapping providing ``"test_file"`` and ``"test_batch_size"``.

    Returns:
        tuple: ``(dice_score, IoU_score)`` exactly as returned by ``dice``
        and ``mIoU``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    test_dataset = SIIM_ACR_Dataset(config["test_file"], is_train=False)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config["test_batch_size"],
        num_workers=4,
        pin_memory=True,
        sampler=None,
        # Evaluation must see every sample exactly once: no shuffling, and
        # the final (possibly smaller) batch must not be discarded.
        shuffle=False,
        collate_fn=None,
        drop_last=False,
    )

    model = ModelResUNet_ft(
        res_base_model="resnet50", out_size=1, imagenet_pretrain=False
    )
    model = nn.DataParallel(
        model, device_ids=list(range(torch.cuda.device_count()))
    )
    model = model.to(device)

    print("Load model from checkpoint:", args.model_path)
    # Load onto CPU first; load_state_dict copies into the model's own device.
    checkpoint = torch.load(args.model_path, map_location="cpu")
    state_dict = checkpoint["model"]
    model.load_state_dict(state_dict)

    # Collect per-batch tensors and concatenate once at the end instead of
    # re-copying the growing result on every iteration.
    gt_batches = []
    pred_batches = []

    print("Start testing")
    model.eval()
    with torch.no_grad():
        for sample in test_dataloader:
            mask = sample["seg"].float().to(device)
            input_image = sample["image"].to(device, non_blocking=True)
            pred_mask = torch.sigmoid(model(input_image))
            gt_batches.append(mask)
            pred_batches.append(pred_mask)

    # Fall back to empty tensors on the active device (not hard-coded CUDA)
    # when the loader yields nothing.
    if gt_batches:
        gt = torch.cat(gt_batches, 0)
        pred = torch.cat(pred_batches, 0)
    else:
        gt = torch.empty(0, device=device)
        pred = torch.empty(0, device=device)

    # dice() returns five values; only the overall score is reported here.
    dice_score, _dice_neg, _dice_pos, _num_neg, _num_pos = dice(pred, gt)
    IoU_score = mIoU(pred, gt)
    print("Dice score is", dice_score)
    print("IoU score is", IoU_score)
    return dice_score, IoU_score
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, load the YAML config, select
    # the visible GPU(s) and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="Path/To/Res_train.yaml")
    parser.add_argument("--checkpoint", default="")
    parser.add_argument("--model_path", default="Path/To/best_valid.pth")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    args = parser.parse_args()

    # NOTE(review): yaml.Loader is not a safe loader - presumably the config
    # is always a trusted local file; confirm before pointing --config at
    # anything externally supplied.
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # Set before the first CUDA call below so the device restriction applies.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # Only touch the CUDA runtime when it exists; test() itself falls back
    # to CPU, so a CUDA-less host should not fail here.
    if torch.cuda.is_available():
        torch.cuda.current_device()
        # NOTE(review): legacy workaround writing a private flag; looks
        # redundant after current_device() - confirm before removing.
        torch.cuda._initialized = True

    test(args, config)
|
|