|
import random |
|
import os |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
|
|
from data_transforms.btcv_transform import BTCV_Transform |
|
|
|
|
|
class BTCV_Dataset(Dataset):
    """2D slice dataset for BTCV abdominal organ segmentation.

    Every ``.npz`` file under ``<root_path>/train_npz`` is read in
    ``__getitem__`` for its ``image`` and ``label`` arrays. Case numbers below
    ``VAL_START_CASE`` form the training split; the remaining cases form the
    evaluation split.

    Unless ``no_text_mode`` is set, each slice is listed once per organ name in
    ``config['data']['label_names']``, and ``__getitem__`` returns the binary
    mask of that organ.
    """

    # First case number of the evaluation split; lower case numbers are training cases.
    VAL_START_CASE = 34

    def __init__(self, config, is_train=False, shuffle_list=True, apply_norm=True, no_text_mode=False) -> None:
        """Build the sample lists for one split.

        Args:
            config: mapping with ``config['data']['root_path']`` and
                ``config['data']['label_names']``; also passed to ``BTCV_Transform``.
            is_train: select the training split (else the evaluation split);
                forwarded to the transform.
            shuffle_list: shuffle the samples once, using the ``random`` module.
            apply_norm: forwarded to the transform.
            no_text_mode: list each slice once with an empty label name instead
                of once per organ.
        """
        super().__init__()
        self.root_path = config['data']['root_path']

        # Parallel per-sample lists: entry i of each list describes sample i.
        self.img_names = []
        self.img_path_list = []
        self.label_path_list = []
        self.label_list = []

        self.is_train = is_train
        self.label_names = config['data']['label_names']
        self.num_classes = len(self.label_names)
        self.config = config
        self.apply_norm = apply_norm
        self.no_text_mode = no_text_mode

        # Organ name -> value that marks this organ in the stored label map.
        self.label_dict = {
            "Spleen": 1,
            "Right Kidney": 2,
            "Left Kidney": 3,
            "Gall Bladder": 4,
            "Liver": 5,
            "Stomach": 6,
            "Aorta": 7,
            "Pancreas": 8,
        }

        self.populate_lists()

        if shuffle_list:
            # One shared permutation keeps the four parallel lists aligned.
            perm = list(range(len(self.img_path_list)))
            random.shuffle(perm)
            self.img_path_list = [self.img_path_list[i] for i in perm]
            self.img_names = [self.img_names[i] for i in perm]
            self.label_path_list = [self.label_path_list[i] for i in perm]
            self.label_list = [self.label_list[i] for i in perm]

        self.data_transform = BTCV_Transform(config=config)

    def __len__(self):
        """Return the number of (slice, label name) samples."""
        return len(self.img_path_list)

    def populate_lists(self):
        """Append the samples of the selected split to the per-sample lists.

        Files are visited in sorted order so the sample order is reproducible
        (``os.listdir`` returns entries in arbitrary order). Entries that do not
        end in ``.npz`` are skipped. The case number is parsed from the file
        name between the 4-character prefix and the first ``_`` (presumably
        names like ``case0001_slice000.npz`` -- TODO confirm).
        """
        imgs_labels_path = os.path.join(self.root_path, 'train_npz')
        want_train = bool(self.is_train)

        for npz in sorted(os.listdir(imgs_labels_path)):
            if not npz.endswith('.npz'):
                # Stray files (e.g. hidden/system files) would break case-number parsing.
                continue

            case_no = int(npz[4:npz.find("_")])
            if (case_no < self.VAL_START_CASE) != want_train:
                continue

            path = os.path.join(imgs_labels_path, npz)
            # No-text mode: one sample per slice with an empty label name;
            # otherwise one sample per organ name.
            names = [''] if self.no_text_mode else self.label_names
            for label_name in names:
                self.img_names.append(npz)
                self.img_path_list.append(path)
                self.label_path_list.append(path)
                self.label_list.append(label_name)

    def __getitem__(self, index):
        """Return ``(img, label, path, label_name)`` for sample ``index``.

        ``img`` is the stored slice repeated to 3 channels and passed through
        ``self.data_transform``. ``label`` is the mask of the pixels whose
        stored value equals ``label_dict[label_name]``. It is thresholded at
        0.5 after the transform, returned as an integer tensor, and stripped of
        its leading channel dimension.

        Raises:
            KeyError: if the sample's label name is not a key of ``label_dict``.
                This includes every sample built in ``no_text_mode``, whose label
                name is ``''``.
        """
        path = self.img_path_list[index]
        label_of_interest = self.label_list[index]

        # The context manager closes the NpzFile handle; indexing it reads each
        # array fully into memory, so the arrays remain valid afterwards.
        with np.load(path) as data:
            img, all_class_labels = data['image'], data['label']

        # Single-channel slice -> 3 identical channels.
        img = torch.Tensor(img).unsqueeze(0).repeat(3, 1, 1)

        try:
            class_value = self.label_dict[label_of_interest]
        except KeyError as err:
            raise KeyError(
                f"Unknown label name {label_of_interest!r} for sample {path}; "
                f"expected one of {sorted(self.label_dict)}"
            ) from err

        # Boolean mask of the requested organ.
        label = torch.Tensor(all_class_labels) == class_value
        if len(label.shape) == 3:
            # NOTE(review): a 3-D label map is reduced to index 0 of its last axis -- confirm layout.
            label = label[:, :, 0]
        label = label.unsqueeze(0)

        img, label = self.data_transform(img, label, is_train=self.is_train, apply_norm=self.apply_norm)

        # Re-binarise (the transform presumably may interpolate) and drop the channel dim.
        label = (label >= 0.5) + 0
        label = label[0]

        return img, label, path, label_of_interest
|
|