|
import argparse |
|
import os |
|
import ruamel_yaml as yaml |
|
import numpy as np |
|
import json |
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader |
|
from dataset.dataset_RSNA import RSNA2018_Dataset |
|
from models.model_MeDSLIP import MeDSLIP |
|
from models.tokenization_bert import BertTokenizer |
|
from tqdm import tqdm |
|
|
|
# Class vocabulary for the model's per-class outputs. Order is significant:
# a name's position in this list is the class index used to slice the model
# outputs (``main`` selects ``original_class.index("pneumonia")``), so do not
# reorder, insert or remove entries without matching the trained model.
original_class = [
    "normal", "clear", "sharp", "sharply", "unremarkable",
    "intact", "stable", "free", "effusion", "opacity",
    "pneumothorax", "edema", "atelectasis", "tube", "consolidation",
    "process", "abnormality", "enlarge", "tip", "low",
    "pneumonia", "line", "congestion", "catheter", "cardiomegaly",
    "fracture", "air", "tortuous", "lead", "disease",
    "calcification", "prominence", "device", "engorgement", "picc",
    "clip", "elevation", "expand", "nodule", "wire",
    "fluid", "degenerative", "pacemaker", "thicken", "marking",
    "scar", "hyperinflate", "blunt", "loss", "widen",
    "collapse", "density", "emphysema", "aerate", "mass",
    "crowd", "infiltrate", "obscure", "deformity", "hernia",
    "drainage", "distention", "shift", "stent", "pressure",
    "lesion", "finding", "borderline", "hardware", "dilation",
    "chf", "redistribution", "aspiration", "tail_abnorm_obs", "excluded_obs",
]
|
|
|
|
|
def get_tokenizer(tokenizer, target_text, max_length=64):
    """Tokenize a collection of strings into fixed-length padded tensors.

    Args:
        tokenizer: Tokenizer callable (``main`` passes a ``BertTokenizer``).
            It is called with the Hugging Face style keyword arguments
            ``padding``, ``truncation``, ``max_length`` and ``return_tensors``.
        target_text: Iterable of strings to encode. It is materialised with
            ``list`` so tuples, generators and similar iterables are accepted.
        max_length: Length every sequence is padded/truncated to. Defaults to
            64, the value this function previously hard-coded.

    Returns:
        Whatever ``tokenizer`` returns for the batch. ``main`` calls
        ``.to(device)`` on it, so presumably a PyTorch ``BatchEncoding``
        (assumption, confirm for non-HF tokenizers).
    """
    return tokenizer(
        list(target_text),
        padding="max_length",  # pad every entry to exactly max_length
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
|
|
|
|
|
def score_cal(labels, seg_map, pred_map, threshold=0.005):
    """Compute grounding metrics for the positively labelled samples of a batch.

    Only samples whose label equals 1 are scored; all others are ignored.

    Args:
        labels: Label tensor of shape ``(B, 1)`` (``(B,)`` also works) whose
            entries are compared against 1.
        seg_map: Ground-truth mask of shape ``(B, H, W)``.
        pred_map: Predicted map of shape ``(B, H, W)``.
        threshold: A pixel counts as predicted-positive when
            ``pred_map > threshold``.

    Returns:
        Tuple ``(total_num, point_score, mass_score, dice_score)``:
        - ``total_num``: 0-dim tensor, number of samples with label 1.
        - ``point_score``: int, number of those samples where some pixel equal
          to the sample's maximum prediction falls on the mask (pointing-game
          hits; ties all count).
        - ``mass_score``: 1-D tensor, one IoU per positive sample.
        - ``dice_score``: 1-D tensor, one Dice coefficient per positive sample.
        IoU/Dice are NaN for a sample whose mask and thresholded prediction
        are both empty (0 / 0). With no positive samples, both score tensors
        are empty and ``point_score`` is 0.
    """
    device = labels.device
    # Flatten rather than squeeze: squeezing a (1, 1) label tensor yields a
    # 0-dim mask, which adds a dimension when indexing instead of selecting.
    mask = labels.reshape(-1) == 1
    # Count exactly the rows the mask selects, so the row count always matches.
    total_num = mask.sum()

    # Positive samples only, each map flattened to one row of H * W pixels.
    seg_flat = seg_map[mask].flatten(1)
    pred_flat = pred_map[mask].flatten(1)
    if seg_flat.shape[0] == 0:
        # The old reshape(0, -1) raised here; report "nothing scored" instead.
        empty = torch.zeros(0, device=device)
        return total_num, 0, empty, empty.clone()

    pred_binary = pred_flat > threshold
    overlap = (seg_flat * pred_binary).sum(dim=-1)
    area_sum = seg_flat.sum(dim=-1) + pred_binary.sum(dim=-1)
    mass_score = overlap / (area_sum - overlap)  # IoU = |A∩B| / |A∪B|
    dice_score = 2 * overlap / area_sum  # Dice = 2|A∩B| / (|A| + |B|)

    # Pointing game, vectorised over samples: mark every pixel equal to the
    # per-sample maximum and check whether any of them lies on the mask.
    peak = pred_flat == pred_flat.max(dim=-1, keepdim=True).values
    hits = (peak.to(seg_flat.dtype) * seg_flat).sum(dim=-1) > 0
    point_score = int(hits.sum().item())

    return total_num, point_score, mass_score.to(device), dice_score.to(device)
|
|
|
|
|
def main(args, config):
    """Evaluate zero-shot pneumonia grounding of MeDSLIP on the RSNA test set.

    Builds the test DataLoader and tokenizes the anatomy prompts and the
    disease descriptions from ``config["disease_book"]``. It then loads
    ``args.checkpoint`` into MeDSLIP. For every batch, the ``"pneumonia"``
    column of the averaged ``ws_e`` map (optionally weighted by ``ws_p``) is
    turned into a 224x224 map (14x14 upsampled by 16). ``score_cal`` compares
    that map against the ground-truth ``seg_map``. Finally it prints the mean
    Dice, mean IoU and pointing-game score over all positive samples.

    Args:
        args: Parsed CLI namespace. Reads ``checkpoint``, ``ddp`` and, when
            present, ``use_ws_p`` (treated as False if the attribute is absent).
        config: Mapping loaded from the YAML config. Reads ``test_file``,
            ``test_batch_size``, ``disease_book``, ``text_encoder`` and
            optionally ``num_workers`` (default 30).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    print("Creating dataset")
    test_dataset = RSNA2018_Dataset(config["test_file"])
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config["test_batch_size"],
        num_workers=config.get("num_workers", 30),
        pin_memory=True,
        sampler=None,
        shuffle=False,
        collate_fn=None,
        drop_last=False,
    )
    # Use a context manager so the file handle is closed.
    with open(config["disease_book"], "r") as book_file:
        json_book = json.load(book_file)
    disease_book = list(json_book.values())

    ana_list = [
        "trachea", "left_hilar", "right_hilar", "hilar_unspec",
        "left_pleural", "right_pleural", "pleural_unspec",
        "heart_size", "heart_border",
        "left_diaphragm", "right_diaphragm", "diaphragm_unspec",
        "retrocardiac",
        "lower_left_lobe", "upper_left_lobe",
        "lower_right_lobe", "middle_right_lobe", "upper_right_lobe",
        "left_lower_lung", "left_mid_lung", "left_upper_lung",
        "left_apical_lung", "left_lung_unspec",
        "right_lower_lung", "right_mid_lung", "right_upper_lung",
        "right_apical_lung", "right_lung_unspec",
        "lung_apices", "lung_bases",
        "left_costophrenic", "right_costophrenic", "costophrenic_unspec",
        "cardiophrenic_sulcus", "mediastinal", "spine", "clavicle", "rib",
        "stomach", "right_atrium", "right_ventricle", "aorta", "svc",
        "interstitium", "parenchymal", "cavoatrial_junction",
        "cardiopulmonary", "pulmonary", "lung_volumes", "unspecified",
        "other",
    ]
    ana_book = ["It is located at " + name + ". " for name in ana_list]
    tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
    ana_book_tokenizer = get_tokenizer(tokenizer, ana_book).to(device)
    disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device)

    print("Creating model")
    model = MeDSLIP(config, ana_book_tokenizer, disease_book_tokenizer, mode="train")
    if args.ddp:
        # NOTE: despite the flag name this is single-process nn.DataParallel,
        # not DistributedDataParallel.
        model = nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count()))
        )
    model = model.to(device)

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    state_dict = checkpoint["model"]
    # strict=False silently skips mismatched keys (e.g. a "module." prefix
    # introduced by DataParallel). Report them so that evaluating a partially
    # initialised model does not go unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(
            "Missing %d keys in checkpoint, e.g. %s"
            % (len(load_result.missing_keys), load_result.missing_keys[:10])
        )
    if load_result.unexpected_keys:
        print(
            "Ignored %d unexpected checkpoint keys, e.g. %s"
            % (len(load_result.unexpected_keys), load_result.unexpected_keys[:10])
        )
    print("load checkpoint from %s" % args.checkpoint)

    print("Start testing")
    model.eval()

    # The CLI may not define --use_ws_p; default to the plain ws_e map.
    use_ws_p = getattr(args, "use_ws_p", False)
    pneumonia_idx = original_class.index("pneumonia")
    dice_scores = []
    mass_scores = []
    total_num_A = 0
    point_num_A = 0
    loop = tqdm(test_dataloader)
    for i, sample in enumerate(loop):
        loop.set_description(f"Testing: {i+1}/{len(test_dataloader)}")
        images = sample["image"].to(device)
        batch_size = images.shape[0]
        labels = sample["label"].to(device)
        seg_map = sample["seg_map"][:, 0, :, :].to(device)

        with torch.no_grad():
            _, _, ws_e, ws_p, _, _ = model(images, labels, is_train=False)
            # Average the last four entries of ws_e / ws_p (presumably the
            # attention maps of the last four decoder layers; confirm).
            ws_e = (ws_e[-4] + ws_e[-3] + ws_e[-2] + ws_e[-1]) / 4
            ws_p = (ws_p[-4] + ws_p[-3] + ws_p[-2] + ws_p[-1]) / 4
            pred_map = ws_e[:, pneumonia_idx, :]

            threshold = 0
            if use_ws_p:
                # Broadcasting the (B, 1, N) map against ws_p is equivalent to
                # the explicit repeat along dim 1 and avoids the copy.
                pred_map = (pred_map.unsqueeze(1) * ws_p).mean(dim=1)
                threshold = 0.01

            # NOTE(review): normalises by the maximum over the whole batch,
            # not per image, so thresholded scores can depend on batch content.
            pred_map = pred_map / torch.max(pred_map)

            # 14x14 patch grid -> 224x224 by nearest-neighbour repetition,
            # done on-device (same values as the former numpy round trip).
            pred_map = pred_map.reshape(batch_size, 14, 14)
            pred_map = (
                pred_map.repeat_interleave(16, dim=1)
                .repeat_interleave(16, dim=2)
                .to(device)
            )

            total_num, point_num, mass_score, dice_score = score_cal(
                labels, seg_map, pred_map, threshold=threshold
            )
        total_num_A = total_num_A + total_num
        point_num_A = point_num_A + point_num
        dice_scores.append(dice_score)
        mass_scores.append(mass_score)

    # Concatenate once at the end instead of re-allocating on every batch.
    dice_score_A = torch.cat(dice_scores) if dice_scores else torch.empty(0, device=device)
    mass_score_A = torch.cat(mass_scores) if mass_scores else torch.empty(0, device=device)
    dice_score_avg = torch.mean(dice_score_A).item()
    mass_score_avg = torch.mean(mass_score_A).item()
    print(
        "The average dice_score is {dice_score_avg:.5f}".format(
            dice_score_avg=dice_score_avg
        )
    )
    print(
        "The average iou_score is {mass_score_avg:.5f}".format(
            mass_score_avg=mass_score_avg
        )
    )
    total_positive = float(total_num_A)
    # Avoid dividing by zero when the test set has no positive samples.
    point_score = point_num_A / total_positive if total_positive else float("nan")
    print(
        "The average point_score is {point_score:.5f}".format(point_score=point_score)
    )
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_Zero-Shot_Grounding_RSNA/configs/MeDSLIP_config.yaml",
    )
    parser.add_argument("--checkpoint", default="MeDSLIP_resnet50.pth")
    # NOTE(review): --device is parsed but main() selects the device itself.
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    parser.add_argument("--ddp", action="store_true", help="whether to use ddp")
    # main() reads args.use_ws_p. Without this option it raised AttributeError.
    parser.add_argument(
        "--use_ws_p",
        action="store_true",
        help="weight the pneumonia map by ws_p and score with threshold 0.01",
    )

    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # use trusted config files (or switch to a safe loader).
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # Only takes effect if CUDA has not yet been initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # Guard so CPU-only hosts do not crash in current_device().
    if args.gpu != "-1" and torch.cuda.is_available():
        torch.cuda.current_device()
    torch.cuda._initialized = True

    main(args, config)
|
|