import argparse import os import yaml as yaml import numpy as np import random import time import datetime import json from pathlib import Path import warnings warnings.filterwarnings("ignore") import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from sklearn.metrics import roc_auc_score, precision_recall_curve, accuracy_score from models.model_MeDSLIP import MeDSLIP from dataset.dataset import Chestxray14_Dataset from models.tokenization_bert import BertTokenizer from tqdm import tqdm chexray14_cls = [ "atelectasis", "cardiomegaly", "effusion", "infiltrate", "mass", "nodule", "pneumonia", "pneumothorax", "consolidation", "edema", "emphysema", "tail_abnorm_obs", "thicken", "hernia", ] # Fibrosis seldom appears in MIMIC_CXR and is divided into the 'tail_abnorm_obs' entitiy. original_class = [ "normal", "clear", "sharp", "sharply", "unremarkable", "intact", "stable", "free", "effusion", "opacity", "pneumothorax", "edema", "atelectasis", "tube", "consolidation", "process", "abnormality", "enlarge", "tip", "low", "pneumonia", "line", "congestion", "catheter", "cardiomegaly", "fracture", "air", "tortuous", "lead", "disease", "calcification", "prominence", "device", "engorgement", "picc", "clip", "elevation", "expand", "nodule", "wire", "fluid", "degenerative", "pacemaker", "thicken", "marking", "scar", "hyperinflate", "blunt", "loss", "widen", "collapse", "density", "emphysema", "aerate", "mass", "crowd", "infiltrate", "obscure", "deformity", "hernia", "drainage", "distention", "shift", "stent", "pressure", "lesion", "finding", "borderline", "hardware", "dilation", "chf", "redistribution", "aspiration", "tail_abnorm_obs", "excluded_obs", ] mapping = [] for disease in chexray14_cls: if disease in original_class: mapping.append(original_class.index(disease)) else: mapping.append(-1) MIMIC_mapping = [_ for i, _ in enumerate(mapping) if _ != -1] chexray14_mapping = [i for i, _ in enumerate(mapping) if _ != -1] target_class = [chexray14_cls[i] for i in chexray14_mapping] def compute_AUCs(gt, pred, n_class): """Computes Area Under the Curve (AUC) from prediction scores. Args: gt: Pytorch tensor on GPU, shape = [n_samples, n_classes] true binary labels. pred: Pytorch tensor on GPU, shape = [n_samples, n_classes] can either be probability estimates of the positive class, confidence values, or binary decisions. Returns: List of AUROCs of all classes. """ AUROCs = [] gt_np = gt.cpu().numpy() pred_np = pred.cpu().numpy() for i in range(n_class): AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i])) return AUROCs def get_tokenizer(tokenizer, target_text): target_tokenizer = tokenizer( list(target_text), padding="max_length", truncation=True, max_length=64, return_tensors="pt", ) return target_tokenizer def test(args, config): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Total CUDA devices: ", torch.cuda.device_count()) torch.set_default_tensor_type("torch.FloatTensor") test_dataset = Chestxray14_Dataset(config["test_file"], is_train=False) test_dataloader = DataLoader( test_dataset, batch_size=config["test_batch_size"], num_workers=30, pin_memory=True, sampler=None, shuffle=True, collate_fn=None, drop_last=False, ) print("Creating book") json_book = json.load(open(config["disease_book"], "r")) disease_book = [json_book[i] for i in json_book] tokenizer = BertTokenizer.from_pretrained(config["text_encoder"]) disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device) print("Creating model") model = MeDSLIP(config, disease_book_tokenizer) if args.ddp: model = nn.DataParallel( model, device_ids=[i for i in range(torch.cuda.device_count())] ) model = model.to(device) print("Load model from checkpoint:", args.model_path) checkpoint = torch.load(args.model_path, map_location="cpu") state_dict = checkpoint["model"] model.load_state_dict(state_dict) # initialize the ground truth and output tensor gt = torch.FloatTensor() gt = gt.to(device) pred = torch.FloatTensor() pred = pred.to(device) print("Start testing") model.eval() loop = tqdm(test_dataloader) for i, sample in enumerate(loop): loop.set_description(f"Testing: {i+1}/{len(test_dataloader)}") image = sample["image"] label = sample["label"][:, chexray14_mapping].float().to(device) gt = torch.cat((gt, label), 0) input_image = image.to(device, non_blocking=True) with torch.no_grad(): pred_class = model(input_image) # batch_size,num_class,dim pred_class = F.softmax(pred_class.reshape(-1, 2)).reshape( -1, len(original_class), 2 ) pred_class = pred_class[:, MIMIC_mapping, 1] pred = torch.cat((pred, pred_class), 0) AUROCs = compute_AUCs(gt, pred, len(target_class)) AUROC_avg = np.array(AUROCs).mean() print("The average AUROC is {AUROC_avg:.4f}".format(AUROC_avg=AUROC_avg)) for i in range(len(target_class)): print("The AUROC of {} is {}".format(target_class[i], AUROCs[i])) max_f1s = [] accs = [] for i in range(len(target_class)): gt_np = gt[:, i].cpu().numpy() pred_np = pred[:, i].cpu().numpy() precision, recall, thresholds = precision_recall_curve(gt_np, pred_np) numerator = 2 * recall * precision denom = recall + precision f1_scores = np.divide( numerator, denom, out=np.zeros_like(denom), where=(denom != 0) ) max_f1 = np.max(f1_scores) max_f1_thresh = thresholds[np.argmax(f1_scores)] max_f1s.append(max_f1) accs.append(accuracy_score(gt_np, pred_np > max_f1_thresh)) f1_avg = np.array(max_f1s).mean() acc_avg = np.array(accs).mean() print("The average f1 is {F1_avg:.4f}".format(F1_avg=f1_avg)) print("The average ACC is {ACC_avg:.4f}".format(ACC_avg=acc_avg)) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--config", default="Sample_zero-shot_Classification_CXR14/configs/MeDSLIP_config.yaml", ) parser.add_argument("--model_path", default="MeDSLIP_resnet50.pth") parser.add_argument("--device", default="cuda") parser.add_argument("--gpu", type=str, default="0", help="gpu") parser.add_argument("--ddp", action="store_true", help="whether to use ddp") args = parser.parse_args() config = yaml.load(open(args.config, "r"), Loader=yaml.Loader) os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.gpu != "-1": torch.cuda.current_device() torch.cuda._initialized = True test(args, config)