|
import os |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
from data_transforms.btcv_transform import BTCV_Transform |
|
|
|
class ArcadeDataset(Dataset):
    """Binary-segmentation dataset reading ``images/`` and ``masks/`` folders.

    Both folders are expected directly under ``config['data']['root_path']``.
    The mask for an image whose file stem is ``<stem>`` is looked up as
    ``masks/<stem>_mask.png``. Any non-zero mask pixel is treated as
    foreground.
    """

    def __init__(self, config, file_list=None, is_train=False, shuffle_list=True, apply_norm=True, no_text_mode=False) -> None:
        """Collect the image file names and build the transform.

        Args:
            config: Mapping with at least ``config['data']['root_path']``,
                ``config['data']['label_names']`` and
                ``config['data']['label_list']``; it is also forwarded to
                ``BTCV_Transform``.
            file_list: Optional iterable of image file names. When ``None``,
                every ``.png``/``.jpg`` file in the image directory is used.
            is_train: Forwarded to the transform on every ``__getitem__``.
            shuffle_list: Shuffle the file names once at construction time
                (uses NumPy's global RNG).
            apply_norm: Forwarded to the transform on every ``__getitem__``.
            no_text_mode: When true, ``__getitem__`` returns an empty string
                instead of a label name.

        Raises:
            ValueError: If the image or mask directory does not exist.
        """
        super().__init__()
        data_cfg = config['data']
        self.root_dir = data_cfg['root_path']
        self.is_train = is_train
        self.config = config
        self.apply_norm = apply_norm
        self.no_text_mode = no_text_mode
        self.label_names = data_cfg['label_names']
        self.label_list = data_cfg['label_list']

        self.image_dir = os.path.join(self.root_dir, 'images')
        self.mask_dir = os.path.join(self.root_dir, 'masks')

        if not os.path.exists(self.image_dir) or not os.path.exists(self.mask_dir):
            raise ValueError(f"Image or mask directory not found in {self.root_dir}")

        if file_list is not None:
            # Copy so the in-place shuffle below never reorders the caller's
            # list (and so tuples or other iterables are accepted as well).
            self.images = list(file_list)
        else:
            self.images = sorted(
                f for f in os.listdir(self.image_dir) if f.endswith(('.png', '.jpg'))
            )

        if shuffle_list:
            np.random.shuffle(self.images)

        self.data_transform = BTCV_Transform(config=config)

    def _image_path(self, img_name):
        """Return the path inside ``image_dir`` to load for ``img_name``.

        The file named by ``img_name`` is preferred when it exists, so
        ``.jpg`` images are loadable; otherwise the historical
        ``<stem>.png`` path is returned.
        """
        base = os.path.basename(img_name)
        stem, ext = os.path.splitext(base)
        if ext:
            candidate = os.path.join(self.image_dir, base)
            if os.path.exists(candidate):
                return candidate
        return os.path.join(self.image_dir, stem + '.png')

    def _mask_path(self, img_name):
        """Return ``mask_dir/<stem>_mask.png`` for the image ``img_name``."""
        stem = os.path.splitext(os.path.basename(img_name))[0]
        return os.path.join(self.mask_dir, stem + '_mask.png')

    def __len__(self):
        """Return the number of image entries."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(img, mask, img_path, text)``. ``img`` and ``mask`` are
            whatever the transform returns. ``text`` is ``""`` in no-text
            mode and ``self.label_names[1]`` otherwise.

        A missing mask file is not an error: a message is printed and an
        all-zero mask of the image's size is used instead.
        """
        img_name = self.images[index]
        img_path = self._image_path(img_name)
        mask_path = self._mask_path(img_name)

        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded image.
        with Image.open(img_path) as raw_img:
            rgb = raw_img.convert("RGB")
        img = torch.as_tensor(np.array(rgb)).permute(2, 0, 1)  # HWC -> CHW
        height, width = img.shape[1], img.shape[2]

        if os.path.exists(mask_path):
            with Image.open(mask_path) as raw_mask:
                mask_img = raw_mask.convert("L")
            if mask_img.size != (width, height):
                # Nearest-neighbour keeps label values intact; an
                # interpolating filter would create new non-zero pixels
                # around the edges that the ``> 0`` test below would then
                # count as foreground.
                mask_img = mask_img.resize((width, height), resample=Image.NEAREST)
            mask = torch.as_tensor(np.array(mask_img))
        else:
            mask_name = os.path.basename(mask_path)
            print(f"Mask not found for {mask_name}, using blank mask")
            mask = torch.zeros((height, width), dtype=torch.uint8)

        # Binarise: any non-zero pixel is foreground.
        mask = (mask > 0).float()

        img, mask = self.data_transform(img, mask.unsqueeze(0), is_train=self.is_train, apply_norm=self.apply_norm)

        if self.no_text_mode:
            # NOTE(review): this branch returns the mask with its leading
            # channel dimension while the branch below drops it; kept as-is
            # because callers may rely on it -- confirm.
            return img, mask, img_path, ""
        # presumably index 1 is the foreground class name; verify against config
        return img, mask[0], img_path, self.label_names[1]

    def get_category_ids(self, image_id):
        """Return the ``label_list`` entries for pixel values in a mask.

        Args:
            image_id: Index into ``self.images``.

        Returns:
            ``self.label_list[v]`` for every distinct grey value ``v`` of the
            mask for which ``v in self.label_list`` holds.

        Unlike ``__getitem__``, a missing mask file is not replaced by a
        blank mask here; the error from ``Image.open`` propagates.
        """
        mask_path = self._mask_path(self.images[image_id])
        with Image.open(mask_path) as raw_mask:
            mask = np.array(raw_mask.convert("L"))
        unique_values = np.unique(mask)
        # NOTE(review): membership test and indexing both go through
        # label_list, so its container type decides the meaning -- confirm.
        return [self.label_list[v] for v in unique_values if v in self.label_list]
|
|