File size: 4,394 Bytes
a256709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
"""
* Copyright (c) 2021, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, precision_recall_curve, accuracy_score
from models.resnet import ModelRes_ft
from dataset.dataset_siim_acr import SIIM_ACR_Dataset
def compute_AUCs(gt, pred, n_class):
    """Compute the per-class area under the ROC curve (AUROC).

    Args:
        gt: PyTorch tensor of true binary labels, indexed as
            [n_samples, n_classes].
        pred: PyTorch tensor of prediction scores laid out like ``gt``;
            may hold probability estimates of the positive class,
            confidence values, or binary decisions.
        n_class: Number of leading class columns to score.

    Returns:
        List with one AUROC per class, in column order.
    """
    # scikit-learn works on NumPy arrays, so bring both tensors to host memory.
    labels = gt.cpu().numpy()
    scores = pred.cpu().numpy()
    return [
        roc_auc_score(labels[:, col], scores[:, col]) for col in range(n_class)
    ]
def test(args, config):
    """Evaluate a fine-tuned ResNet-50 classifier on the SIIM-ACR test set.

    Loads the checkpoint at ``args.model_path``, runs inference over the
    dataset named by ``config["test_file"]``, and prints the best F1 score,
    the accuracy at the F1-optimal threshold, and the AUROC values.

    Args:
        args: Parsed command-line namespace; ``args.ddp`` and
            ``args.model_path`` are read here.
        config: Mapping providing ``"test_file"``, ``"test_batch_size"`` and
            ``"num_classes"``.

    Returns:
        The mean AUROC over the first ``config["num_classes"]`` classes.

    Raises:
        ValueError: If the test dataloader yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    test_dataset = SIIM_ACR_Dataset(config["test_file"], is_train=False)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config["test_batch_size"],
        num_workers=8,
        pin_memory=True,
        sampler=None,
        shuffle=False,
        collate_fn=None,
        drop_last=False,
    )

    model = ModelRes_ft(res_base_model="resnet50", out_size=1)
    if args.ddp:
        # Despite the flag name this wraps the model in nn.DataParallel
        # across every visible CUDA device.
        model = nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count()))
        )
    model = model.to(device)

    print("Load model from checkpoint:", args.model_path)
    # Load onto CPU first; load_state_dict copies into the model's device.
    checkpoint = torch.load(args.model_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    print("Start testing")
    model.eval()

    # Collect per-batch results on the host and concatenate once at the end.
    # This avoids the hard-coded .cuda() accumulators (which crashed on
    # CPU-only machines) and the quadratic cost of torch.cat inside the loop.
    gt_batches = []
    pred_batches = []
    with torch.no_grad():
        for sample in test_dataloader:
            gt_batches.append(sample["label"].float())
            input_image = sample["image"].to(device, non_blocking=True)
            pred_class = torch.sigmoid(model(input_image))
            pred_batches.append(pred_class.cpu())

    if not gt_batches:
        raise ValueError("Test dataloader produced no batches; nothing to evaluate")

    gt = torch.cat(gt_batches, 0)
    pred = torch.cat(pred_batches, 0)

    AUROCs = compute_AUCs(gt, pred, config["num_classes"])
    AUROC_avg = np.array(AUROCs).mean()

    # F1 / accuracy are reported for the first (only) output column.
    gt_np = gt[:, 0].cpu().numpy()
    pred_np = pred[:, 0].cpu().numpy()
    precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
    numerator = 2 * recall * precision
    denom = recall + precision
    # Guard against 0/0 where both precision and recall are zero.
    f1_scores = np.divide(
        numerator, denom, out=np.zeros_like(denom), where=(denom != 0)
    )
    # precision/recall carry one trailing point (recall == 0) that has no
    # matching threshold; exclude it so the index is always valid.
    best_idx = int(np.argmax(f1_scores[:-1]))
    max_f1 = f1_scores[best_idx]
    max_f1_thresh = thresholds[best_idx]
    print("The max f1 is", max_f1)
    # precision_recall_curve defines a positive as score >= threshold, so use
    # the same comparison to report accuracy at the F1-optimal operating point.
    print("The accuracy is", accuracy_score(gt_np, pred_np >= max_f1_thresh))
    print("The average AUROC is {AUROC_avg:.3f}".format(AUROC_avg=AUROC_avg))
    for i in range(config["num_classes"]):
        # NOTE(review): label says "Pneumonia" although this script evaluates
        # the SIIM-ACR dataset; string kept as-is, confirm intended wording.
        print("The AUROC of Pneumonia is {}".format(AUROCs[i]))
    return AUROC_avg
if __name__ == "__main__":
    # Script entry point: parse CLI options, load the YAML config, select the
    # visible GPU(s), and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml",
    )
    parser.add_argument("--checkpoint", default="")
    parser.add_argument("--model_path", default="best_valid.pth")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    parser.add_argument("--ddp", action="store_true", help="whether to use ddp")
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed promptly.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # load config files from trusted sources (or switch to a safe loader).
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # Must be set before CUDA is first initialised for the restriction to
    # take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if torch.cuda.is_available():
        # Force CUDA initialisation up front; skipped on CPU-only hosts, where
        # current_device() would raise even though test() supports CPU.
        torch.cuda.current_device()
        torch.cuda._initialized = True

    test(args, config)
|