File size: 3,551 Bytes
4a1f918 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import random
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from data_transforms.btcv_transform import BTCV_Transform
class BTCV_Dataset(Dataset):
    """Dataset of 2D images and organ masks read from ``<root_path>/train_npz``.

    Each ``.npz`` file must provide an ``'image'`` array (2D; it is repeated to
    3 channels) and a ``'label'`` array of integer class ids (see
    ``label_dict``). Files whose case number is below ``_FIRST_EVAL_CASE``
    form the training split (``is_train=True``). The remaining cases form the
    evaluation split.

    In text mode every file is expanded into one sample per name in
    ``config['data']['label_names']``, and the target is the binary mask of
    that organ. In ``no_text_mode`` every file is a single sample whose label
    name is ``''`` and whose target is an all-zero mask.

    ``__getitem__`` returns ``(img, label, path, label_name)``. ``img`` and
    ``label`` are the outputs of ``BTCV_Transform``; ``label`` is then
    re-binarised (``>= 0.5``) and its first channel selected.
    """

    # Case numbers below this value belong to the training split; cases at or
    # above it are held out for evaluation.
    _FIRST_EVAL_CASE = 34

    def __init__(self, config, is_train=False, shuffle_list=True, apply_norm=True, no_text_mode=False) -> None:
        """Index the ``.npz`` files for the requested split.

        Args:
            config: nested dict. Reads ``config['data']['root_path']`` and
                ``config['data']['label_names']``, and is passed on to
                ``BTCV_Transform``.
            is_train: select the training split (else the evaluation split).
                Also forwarded to the transform on every item.
            shuffle_list: shuffle the sample order once, using the global
                ``random`` module.
            apply_norm: forwarded to the transform on every item.
            no_text_mode: one sample per file with an empty label name,
                instead of one sample per configured organ name.
        """
        super().__init__()
        self.root_path = config['data']['root_path']
        # Parallel per-sample lists; index i in each describes the same sample.
        self.img_names = []
        self.img_path_list = []
        self.label_path_list = []
        self.label_list = []
        self.is_train = is_train
        self.label_names = config['data']['label_names']
        self.num_classes = len(self.label_names)
        self.config = config
        self.apply_norm = apply_norm
        self.no_text_mode = no_text_mode
        # Organ name -> integer id compared against the 'label' array.
        self.label_dict = {
            "Spleen": 1,
            "Right Kidney": 2,
            "Left Kidney": 3,
            "Gall Bladder": 4,
            "Liver": 5,
            "Stomach": 6,
            "Aorta": 7,
            "Pancreas": 8,
        }
        self.populate_lists()
        if shuffle_list:
            # Apply one permutation to all parallel lists so they stay aligned.
            order = list(range(len(self.img_path_list)))
            random.shuffle(order)
            self.img_path_list = [self.img_path_list[i] for i in order]
            self.img_names = [self.img_names[i] for i in order]
            self.label_path_list = [self.label_path_list[i] for i in order]
            self.label_list = [self.label_list[i] for i in order]
        # Define the data transform.
        self.data_transform = BTCV_Transform(config=config)

    def __len__(self):
        """Return the number of (file, label name) samples."""
        return len(self.img_path_list)

    def populate_lists(self):
        """Fill the per-sample lists from ``<root_path>/train_npz``.

        Only entries ending in ``.npz`` are used. They are visited in sorted
        order, so the sample order does not depend on the filesystem.

        The case number is parsed from the file name: the characters from
        index 4 up to the first ``'_'``. For example, ``case0001_x.npz``
        gives 1; the 4-char prefix is presumably ``'case'`` — TODO confirm.
        """
        npz_dir = os.path.join(self.root_path, 'train_npz')
        # In no-text mode every file is one sample with an empty label name;
        # otherwise each file yields one sample per configured organ name.
        names_per_file = [''] if self.no_text_mode else list(self.label_names)
        for npz in sorted(os.listdir(npz_dir)):
            if not npz.endswith('.npz'):
                continue
            case_no = int(npz[4:npz.find("_")])
            is_train_case = case_no < self._FIRST_EVAL_CASE
            if is_train_case != bool(self.is_train):
                continue
            path = os.path.join(npz_dir, npz)
            for label_name in names_per_file:
                self.img_names.append(npz)
                self.img_path_list.append(path)
                # Image and label are stored in the same .npz file.
                self.label_path_list.append(path)
                self.label_list.append(label_name)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, label, path, label_name)``.

        Raises:
            KeyError: if the sample's label name is neither ``''`` (no-text
                mode) nor a key of ``label_dict``.
        """
        path = self.img_path_list[index]
        label_name = self.label_list[index]
        # NpzFile keeps its file open until closed; read both arrays inside
        # the context manager so the handle is released immediately.
        with np.load(path) as data:
            img = data['image']
            all_class_labels = data['label']
        img = torch.Tensor(img).unsqueeze(0).repeat(3, 1, 1)

        class_id = self.label_dict.get(label_name)
        if class_id is not None:
            # Boolean mask of the requested organ.
            label = torch.Tensor(all_class_labels) == class_id
            if label.dim() == 3:
                # Some label arrays carry a trailing channel axis; use channel 0.
                label = label[:, :, 0]
        elif label_name == '':
            # No-text mode: no organ selected, so the target is empty. Same
            # dtype as the mask branch above.
            label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.bool)
        else:
            raise KeyError(
                f"Unknown label name {label_name!r}; expected one of {sorted(self.label_dict)}"
            )
        label = label.unsqueeze(0)

        img, label = self.data_transform(img, label, is_train=self.is_train, apply_norm=self.apply_norm)
        # Resizing in the transform can introduce intermediate values;
        # re-binarise to 0/1.
        label = (label >= 0.5) + 0
        label = label[0]
        return img, label, path, label_name
|