import argparse
import json
import os

import numpy as np
import ruamel_yaml as yaml
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset.dataset_RSNA import RSNA2018_Dataset
from models.model_MeDSLIP import MeDSLIP
from models.tokenization_bert import BertTokenizer

# Pathology/observation vocabulary; the index of "pneumonia" selects the
# grounding heatmap in main().
original_class = [
    "normal", "clear", "sharp", "sharply", "unremarkable", "intact", "stable",
    "free", "effusion", "opacity", "pneumothorax", "edema", "atelectasis",
    "tube", "consolidation", "process", "abnormality", "enlarge", "tip", "low",
    "pneumonia", "line", "congestion", "catheter", "cardiomegaly", "fracture",
    "air", "tortuous", "lead", "disease", "calcification", "prominence",
    "device", "engorgement", "picc", "clip", "elevation", "expand", "nodule",
    "wire", "fluid", "degenerative", "pacemaker", "thicken", "marking", "scar",
    "hyperinflate", "blunt", "loss", "widen", "collapse", "density",
    "emphysema", "aerate", "mass", "crowd", "infiltrate", "obscure",
    "deformity", "hernia", "drainage", "distention", "shift", "stent",
    "pressure", "lesion", "finding", "borderline", "hardware", "dilation",
    "chf", "redistribution", "aspiration", "tail_abnorm_obs", "excluded_obs",
]


def get_tokenizer(tokenizer, target_text):
    target_tokenizer = tokenizer(
        list(target_text),
        padding="max_length",
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    return target_tokenizer


def score_cal(labels, seg_map, pred_map, threshold=0.005):
    """
    Compute grounding metrics over the positive (label == 1) samples.

    labels:   B x 1      image-level ground-truth labels
    seg_map:  B x H x W  ground-truth segmentation masks
    pred_map: B x H x W  predicted grounding heatmaps

    Returns the number of positive samples, the number of point hits
    (heatmap maximum falling inside the mask), and the per-sample IoU and
    Dice scores of the heatmap thresholded at `threshold`.
    """
    device = labels.device
    total_num = torch.sum(labels)
    mask = (labels == 1).squeeze()
    seg_map = seg_map[mask, :, :].reshape(total_num, -1)
    pred_map = pred_map[mask, :, :].reshape(total_num, -1)
    one_hot_map = pred_map > threshold
    dot_product = (seg_map * one_hot_map).reshape(total_num, -1)

    # Point score: does the heatmap maximum fall inside the ground-truth mask?
    max_number = torch.max(pred_map, dim=-1)[0]
    point_score = 0
    for i, number in enumerate(max_number):
        temp_pred = (pred_map[i] == number).type(torch.int)
        flag = int((torch.sum(temp_pred * seg_map[i])) > 0)
        point_score = point_score + flag

    # IoU (reported as "iou_score") and Dice of the thresholded heatmap.
    mass_score = torch.sum(dot_product, dim=-1) / (
        (torch.sum(seg_map, dim=-1) + torch.sum(one_hot_map, dim=-1))
        - torch.sum(dot_product, dim=-1)
    )
    dice_score = (
        2
        * torch.sum(dot_product, dim=-1)
        / (torch.sum(seg_map, dim=-1) + torch.sum(one_hot_map, dim=-1))
    )
    return total_num, point_score, mass_score.to(device), dice_score.to(device)


def main(args, config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    #### Dataset ####
    print("Creating dataset")
    test_dataset = RSNA2018_Dataset(config["test_file"])
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config["test_batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=None,
        shuffle=False,
        collate_fn=None,
        drop_last=False,
    )

    # Text prompts for the two streams: disease descriptions and anatomy locations.
    json_book = json.load(open(config["disease_book"], "r"))
    disease_book = [json_book[i] for i in json_book]
    ana_list = [
        "trachea", "left_hilar", "right_hilar", "hilar_unspec", "left_pleural",
        "right_pleural", "pleural_unspec", "heart_size", "heart_border",
        "left_diaphragm", "right_diaphragm", "diaphragm_unspec", "retrocardiac",
        "lower_left_lobe", "upper_left_lobe", "lower_right_lobe",
        "middle_right_lobe", "upper_right_lobe", "left_lower_lung",
        "left_mid_lung", "left_upper_lung", "left_apical_lung",
        "left_lung_unspec", "right_lower_lung", "right_mid_lung",
        "right_upper_lung", "right_apical_lung", "right_lung_unspec",
        "lung_apices", "lung_bases", "left_costophrenic", "right_costophrenic",
        "costophrenic_unspec", "cardiophrenic_sulcus", "mediastinal", "spine",
        "clavicle", "rib", "stomach", "right_atrium", "right_ventricle",
        "aorta", "svc", "interstitium", "parenchymal", "cavoatrial_junction",
        "cardiopulmonary", "pulmonary", "lung_volumes", "unspecified", "other",
    ]
    ana_book = ["It is located at " + i + ". " for i in ana_list]

    tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
    ana_book_tokenizer = get_tokenizer(tokenizer, ana_book).to(device)
    disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device)

    print("Creating model")
    model = MeDSLIP(config, ana_book_tokenizer, disease_book_tokenizer, mode="train")
    if args.ddp:
        model = nn.DataParallel(
            model, device_ids=[i for i in range(torch.cuda.device_count())]
        )
    model = model.to(device)

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    state_dict = checkpoint["model"]
    model.load_state_dict(state_dict, strict=False)
    print("Loaded checkpoint from %s" % args.checkpoint)

    print("Start testing")
    model.eval()

    dice_score_A = torch.FloatTensor().to(device)
    mass_score_A = torch.FloatTensor().to(device)
    total_num_A = 0
    point_num_A = 0

    loop = tqdm(test_dataloader)
    for i, sample in enumerate(loop):
        loop.set_description(f"Testing: {i + 1}/{len(test_dataloader)}")
        images = sample["image"].to(device)
        image_path = sample["image_path"]
        batch_size = images.shape[0]
        labels = sample["label"].to(device)
        seg_map = sample["seg_map"][:, 0, :, :].to(device)  # take channel 0: B x H x W

        with torch.no_grad():
            _, _, ws_e, ws_p, features_e, features_p = model(
                images, labels, is_train=False
            )
            features_e = features_e.transpose(0, 1)
            features_p = features_p.transpose(0, 1)

            # Average the last four attention maps of each stream.
            ws_e = (ws_e[-4] + ws_e[-3] + ws_e[-2] + ws_e[-1]) / 4
            ws_p = (ws_p[-4] + ws_p[-3] + ws_p[-2] + ws_p[-1]) / 4

            # Use the map of the "pneumonia" query as the grounding heatmap.
            pred_map = ws_e[:, original_class.index("pneumonia"), :]
            threshold = 0
            if args.use_ws_p:
                # Optionally modulate the heatmap with the anatomy-stream maps.
                pred_map = pred_map.unsqueeze(1)
                pred_map = pred_map.repeat(1, ws_p.shape[1], 1)
                pred_map = (pred_map * ws_p).mean(axis=1)
                threshold = 0.01
            pred_map = pred_map / torch.max(pred_map)

            # Upsample the 14 x 14 map to 224 x 224 by nearest-neighbour repetition.
            pred_map = pred_map.reshape(batch_size, 14, 14).detach().cpu().numpy()
            pred_map = torch.from_numpy(
                pred_map.repeat(16, axis=1).repeat(16, axis=2)
            ).to(device)  # final grounding heatmap

            total_num, point_num, mass_score, dice_score = score_cal(
                labels, seg_map, pred_map, threshold=threshold
            )
            total_num_A = total_num_A + total_num
            point_num_A = point_num_A + point_num
            dice_score_A = torch.cat((dice_score_A, dice_score), dim=0)
            mass_score_A = torch.cat((mass_score_A, mass_score), dim=0)

    dice_score_avg = torch.mean(dice_score_A)
    mass_score_avg = torch.mean(mass_score_A)
    point_score = point_num_A / total_num_A
    print("The average dice_score is {:.5f}".format(dice_score_avg))
    print("The average iou_score is {:.5f}".format(mass_score_avg))
    print("The average point_score is {:.5f}".format(point_score))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_Zero-Shot_Grounding_RSNA/configs/MeDSLIP_config.yaml",
    )
    parser.add_argument("--checkpoint", default="MeDSLIP_resnet50.pth")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    parser.add_argument("--ddp", action="store_true", help="whether to use ddp")
    # main() reads args.use_ws_p, so the flag must be declared here as well.
    parser.add_argument(
        "--use_ws_p",
        action="store_true",
        help="modulate the pathology heatmap with the anatomy-stream maps (ws_p)",
    )
    args = parser.parse_args()

    config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if args.gpu != "-1":
        torch.cuda.current_device()
        torch.cuda._initialized = True

    main(args, config)
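# Example invocation (a sketch only: "test_grounding.py" is a placeholder for this
# file's actual name, and the config/checkpoint paths are simply the argparse
# defaults declared above):
#
#   python test_grounding.py \
#       --config Sample_Zero-Shot_Grounding_RSNA/configs/MeDSLIP_config.yaml \
#       --checkpoint MeDSLIP_resnet50.pth \
#       --gpu 0
#
# Passing --use_ws_p additionally weights the pneumonia heatmap by the
# anatomy-stream attention maps before thresholding.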