|
import json |
|
from torch.utils.data import DataLoader |
|
import PIL |
|
from torch.utils.data import Dataset |
|
import numpy as np |
|
import pandas as pd |
|
from torchvision import transforms |
|
from PIL import Image |
|
import random |
|
from dataset.randaugment import RandomAugment |
|
|
|
|
|
class MeDSLIP_Dataset(Dataset):
    """Dataset yielding a transformed image, its label matrix and ProtoCL indices.

    Each item pairs one image with a 2-D slice of the array stored at
    ``np_path``. From that slice the dataset derives, for both axes, a
    per-entry existence label and a fixed-length list of sampled indices
    along the other axis (see ``triplet_extraction_pathology`` and
    ``triplet_extraction_anatomy``).
    """

    def __init__(self, csv_path, np_path, mode="train", num_neg_samples=7):
        """Load the annotations and label array and build the image transform.

        Args:
            csv_path: Path to the annotation file. Despite the name it is
                parsed as JSON: a mapping whose keys are image paths and whose
                values contain a ``"labels_id"`` entry.
            np_path: Path passed to ``np.load``; the loaded array is indexed
                as ``array[labels_id, :, :]``.
            mode: ``"train"`` (random crop/flip/augmentation) or ``"test"``
                (plain resize to 224x224).
            num_neg_samples: Number of indices sampled per entry in addition
                to the leading ``-1`` placeholder.

        Raises:
            ValueError: If ``mode`` is neither ``"train"`` nor ``"test"``.
        """
        # Fail fast: previously an unknown mode left ``self.transform`` unset
        # and only surfaced as an AttributeError on the first __getitem__.
        if mode not in ("train", "test"):
            raise ValueError(
                f"mode must be 'train' or 'test', got {mode!r}"
            )

        self.num_neg_samples = num_neg_samples

        # Context manager so the annotation file handle is closed promptly.
        with open(csv_path, "r") as ann_file:
            self.ann = json.load(ann_file)
        self.img_path_list = list(self.ann)

        # NOTE: the attribute name is misspelled ("anaomy"); it is kept as-is
        # because code outside this class may read it.
        self.anaomy_list = [
            "trachea",
            "left_hilar",
            "right_hilar",
            "hilar_unspec",
            "left_pleural",
            "right_pleural",
            "pleural_unspec",
            "heart_size",
            "heart_border",
            "left_diaphragm",
            "right_diaphragm",
            "diaphragm_unspec",
            "retrocardiac",
            "lower_left_lobe",
            "upper_left_lobe",
            "lower_right_lobe",
            "middle_right_lobe",
            "upper_right_lobe",
            "left_lower_lung",
            "left_mid_lung",
            "left_upper_lung",
            "left_apical_lung",
            "left_lung_unspec",
            "right_lower_lung",
            "right_mid_lung",
            "right_upper_lung",
            "right_apical_lung",
            "right_lung_unspec",
            "lung_apices",
            "lung_bases",
            "left_costophrenic",
            "right_costophrenic",
            "costophrenic_unspec",
            "cardiophrenic_sulcus",
            "mediastinal",
            "spine",
            "clavicle",
            "rib",
            "stomach",
            "right_atrium",
            "right_ventricle",
            "aorta",
            "svc",
            "interstitium",
            "parenchymal",
            "cavoatrial_junction",
            "cardiopulmonary",
            "pulmonary",
            "lung_volumes",
            "unspecified",
            "other",
        ]
        self.obs_list = [
            "normal",
            "clear",
            "sharp",
            "sharply",
            "unremarkable",
            "intact",
            "stable",
            "free",
            "effusion",
            "opacity",
            "pneumothorax",
            "edema",
            "atelectasis",
            "tube",
            "consolidation",
            "process",
            "abnormality",
            "enlarge",
            "tip",
            "low",
            "pneumonia",
            "line",
            "congestion",
            "catheter",
            "cardiomegaly",
            "fracture",
            "air",
            "tortuous",
            "lead",
            "disease",
            "calcification",
            "prominence",
            "device",
            "engorgement",
            "picc",
            "clip",
            "elevation",
            "expand",
            "nodule",
            "wire",
            "fluid",
            "degenerative",
            "pacemaker",
            "thicken",
            "marking",
            "scar",
            "hyperinflate",
            "blunt",
            "loss",
            "widen",
            "collapse",
            "density",
            "emphysema",
            "aerate",
            "mass",
            "crowd",
            "infiltrate",
            "obscure",
            "deformity",
            "hernia",
            "drainage",
            "distention",
            "shift",
            "stent",
            "pressure",
            "lesion",
            "finding",
            "borderline",
            "hardware",
            "dilation",
            "chf",
            "redistribution",
            "aspiration",
            "tail_abnorm_obs",
            "excluded_obs",
        ]

        self.rad_graph_results = np.load(np_path)

        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        if mode == "train":
            self.transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        224, scale=(0.2, 1.0), interpolation=Image.BICUBIC
                    ),
                    transforms.RandomHorizontalFlip(),
                    RandomAugment(
                        2,
                        7,
                        isPIL=True,
                        augs=[
                            "Identity",
                            "AutoContrast",
                            "Equalize",
                            "Brightness",
                            "Sharpness",
                            "ShearX",
                            "ShearY",
                            "TranslateX",
                            "TranslateY",
                            "Rotate",
                        ],
                    ),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:  # mode == "test" (validated above)
            self.transform = transforms.Compose(
                [
                    transforms.Resize([224, 224]),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys: ``"image"`` (transformed image), ``"label_pathology"`` /
        ``"index_pathology"`` and ``"label_anatomy"`` / ``"index_anatomy"``
        (existence labels and sampled index arrays for each axis), and
        ``"matrix"`` (the raw 2-D label slice).
        """
        img_path = self.img_path_list[index]
        class_label = self.rad_graph_results[
            self.ann[img_path]["labels_id"], :, :
        ]

        # Sampling happens before the image transform, as in the original
        # implementation, so the order of RNG consumption is unchanged.
        labels_pathology, index_list_pathology = self.triplet_extraction_pathology(
            class_label
        )
        labels_anatomy, index_list_anatomy = self.triplet_extraction_anatomy(
            class_label
        )
        index_list_pathology = np.array(index_list_pathology)
        index_list_anatomy = np.array(index_list_anatomy)

        # ``convert`` returns an independent, fully loaded image, so the file
        # can be closed as soon as the ``with`` block exits.
        with PIL.Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        image = self.transform(img)

        return {
            "image": image,
            "label_pathology": labels_pathology,
            "index_pathology": index_list_pathology,
            "label_anatomy": labels_anatomy,
            "index_anatomy": index_list_anatomy,
            "matrix": class_label,
        }

    def _sample_negative_indices(self, candidates, count):
        """Return exactly ``count`` indices drawn from ``candidates``.

        Samples without replacement when enough candidates exist. With too
        few candidates, every candidate is used once (in random order) and the
        remainder is drawn with replacement, so the result length is always
        ``count``.

        Raises:
            ValueError: If ``candidates`` is empty while ``count`` is positive.
        """
        if len(candidates) >= count:
            return random.sample(candidates, count)
        if not candidates:
            raise ValueError(
                "cannot sample negative indices: no entry differs from 1"
            )
        print(
            f"MeDSLIP_Dataset: only {len(candidates)} negative candidates for "
            f"{count} requested; padding by sampling with replacement"
        )
        shuffled = random.sample(candidates, len(candidates))
        return shuffled + random.choices(candidates, k=count - len(candidates))

    def _extract_triplets(self, matrix):
        """Compute per-row existence labels and sampled column indices.

        For each row of the 2-D ``matrix`` the label is 1 if any entry equals
        1, otherwise 0 if any entry equals 0, otherwise -1. The index list has
        ``num_neg_samples + 1`` entries taken from the columns whose value is
        not 1; when the row contains a 1, the first entry is the placeholder
        ``-1`` followed by ``num_neg_samples`` sampled columns.
        """
        exist_labels = np.zeros(matrix.shape[0]) - 1
        index_lists = []
        for i, row in enumerate(matrix):
            candidates = np.where(row != 1)[0].tolist()
            if 0 in row:
                exist_labels[i] = 0
            if 1 in row:
                exist_labels[i] = 1
                sampled = [-1] + self._sample_negative_indices(
                    candidates, self.num_neg_samples
                )
            else:
                sampled = self._sample_negative_indices(
                    candidates, self.num_neg_samples + 1
                )
            index_lists.append(sampled)
        return exist_labels, index_lists

    def triplet_extraction_pathology(self, class_label):
        """
        This is for ProtoCL. Therefore, we need to extract anatomies to use in pathology stream.

        Args:
            class_label: 2-D label matrix; each column (axis 1) is one
                pathology and the sampled indices refer to rows (axis 0).

        Returns:
            ``(exist_labels, anatomy_list)``: a float array with one label per
            column, and one list of ``num_neg_samples + 1`` row indices per
            column.
        """
        # Columns of ``class_label`` are the rows of its transpose.
        return self._extract_triplets(class_label.T)

    def triplet_extraction_anatomy(self, class_label):
        """
        This is for ProtoCL. Therefore, we need to extract pathological labels to use in anatomy stream.

        Args:
            class_label: 2-D label matrix; each row (axis 0) is one anatomy
                and the sampled indices refer to columns (axis 1).

        Returns:
            ``(exist_labels, pathology_list)``: a float array with one label
            per row, and one list of ``num_neg_samples + 1`` column indices
            per row.
        """
        return self._extract_triplets(class_label)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.ann)
|
|
|
|
|
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from parallel configuration lists.

    The six arguments are zipped together, so the i-th loader is built from
    the i-th element of each; extra elements in longer arguments are ignored.

    Training loaders drop the last incomplete batch and shuffle only when no
    sampler is supplied. Non-training loaders never shuffle and keep every
    batch. All loaders are created with ``pin_memory=True``.

    Returns:
        A list of ``DataLoader`` objects in the same order as ``datasets``.
    """
    per_loader_args = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    built = []
    for data, smp, size, workers, training, collate in per_loader_args:
        training = bool(training)
        built.append(
            DataLoader(
                data,
                batch_size=size,
                num_workers=workers,
                pin_memory=True,
                sampler=smp,
                # Shuffling and an explicit sampler are mutually exclusive.
                shuffle=training and smp is None,
                collate_fn=collate,
                drop_last=training,
            )
        )
    return built
|
|