File size: 3,561 Bytes
4a1f918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from data_transforms.btcv_transform import BTCV_Transform

class ArcadeDataset(Dataset):
    """Binary-segmentation dataset over ``<root>/images`` and ``<root>/masks``.

    Each image ``<stem><ext>`` in ``images/`` is paired with the mask
    ``<stem>_mask.png`` in ``masks/``. Masks are read as greyscale and
    binarised (every non-zero pixel becomes 1.0); a missing mask is replaced
    by an all-zero mask of the image's size.

    Config keys read: ``config['data']['root_path']``,
    ``config['data']['label_names']`` and ``config['data']['label_list']``.
    The whole ``config`` is also handed to ``BTCV_Transform``.
    """

    # Extensions accepted by the directory scan; also used to decide whether a
    # file-list entry already carries its own image extension.
    _IMAGE_EXTENSIONS = ('.png', '.jpg')

    def __init__(self, config, file_list=None, is_train=False, shuffle_list=True, apply_norm=True, no_text_mode=False) -> None:
        """Index the image files and build the transform.

        Args:
            config: nested config dict (see the class docstring for keys read).
            file_list: optional iterable of image file names to use instead of
                scanning ``images/``. It is copied, so shuffling never reorders
                the caller's list.
            is_train: forwarded to the transform as ``is_train``.
            shuffle_list: shuffle the image order once, with NumPy's global RNG.
            apply_norm: forwarded to the transform as ``apply_norm``.
            no_text_mode: if True, ``__getitem__`` returns an empty text label.

        Raises:
            ValueError: if ``images/`` or ``masks/`` is missing under the root.
        """
        super().__init__()
        self.root_dir = config['data']['root_path']
        self.is_train = is_train
        self.config = config
        self.apply_norm = apply_norm
        self.no_text_mode = no_text_mode
        self.label_names = config['data']['label_names']
        self.label_list = config['data']['label_list']

        self.image_dir = os.path.join(self.root_dir, 'images')
        self.mask_dir = os.path.join(self.root_dir, 'masks')

        if not os.path.exists(self.image_dir) or not os.path.exists(self.mask_dir):
            raise ValueError(f"Image or mask directory not found in {self.root_dir}")

        if file_list is not None:
            # Copy so the in-place shuffle below cannot mutate the caller's list.
            self.images = list(file_list)
        else:
            self.images = sorted(
                f for f in os.listdir(self.image_dir) if f.endswith(self._IMAGE_EXTENSIONS)
            )

        if shuffle_list:
            np.random.shuffle(self.images)

        self.data_transform = BTCV_Transform(config=config)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def _resolve_paths(self, img_name):
        """Map an entry of ``self.images`` to ``(img_path, mask_name, mask_path)``.

        Only the basename of ``img_name`` is used. A recognised image extension
        is kept as-is; otherwise ``.png`` is used. The mask is always
        ``<stem>_mask.png`` inside ``self.mask_dir``.
        """
        stem, ext = os.path.splitext(os.path.basename(img_name))
        if ext not in self._IMAGE_EXTENSIONS:
            ext = '.png'
        mask_name = stem + '_mask.png'
        return (
            os.path.join(self.image_dir, stem + ext),
            mask_name,
            os.path.join(self.mask_dir, mask_name),
        )

    def __getitem__(self, index):
        """Load, binarise and transform the sample at ``index``.

        Returns:
            ``(img, mask, img_path, text)``. In ``no_text_mode`` the mask is
            returned exactly as produced by the transform and ``text`` is
            ``""``; otherwise ``mask[0]`` is returned together with
            ``self.label_names[1]``.
        """
        img_path, mask_name, mask_path = self._resolve_paths(self.images[index])

        # Image as uint8 RGB, HWC -> CHW.
        with Image.open(img_path) as pil_img:
            img = torch.as_tensor(np.array(pil_img.convert("RGB"))).permute(2, 0, 1)
        height, width = img.shape[1], img.shape[2]

        if os.path.exists(mask_path):
            with Image.open(mask_path) as pil_mask:
                mask = torch.as_tensor(np.array(pil_mask.convert("L")))
        else:
            print(f"Mask not found for {mask_name}, using blank mask")
            mask = torch.zeros((height, width), dtype=torch.uint8)

        if mask.shape != (height, width):
            # Nearest-neighbour keeps label values intact. Pillow's default
            # filter for "L" images interpolates, and every blurred edge pixel
            # would then survive the `> 0` threshold and dilate the mask.
            resized = Image.fromarray(mask.numpy()).resize((width, height), Image.NEAREST)
            mask = torch.as_tensor(np.array(resized))

        mask = (mask > 0).float()

        # Transform receives the mask with a leading channel dim of size 1.
        img, mask = self.data_transform(img, mask.unsqueeze(0), is_train=self.is_train, apply_norm=self.apply_norm)

        if self.no_text_mode:
            return img, mask, img_path, ""
        # NOTE(review): label_names[1] is assumed to be the foreground class
        # (the original comment called it "Vein"); confirm against the config.
        return img, mask[0], img_path, self.label_names[1]

    def get_category_ids(self, image_id):
        """Return the ``label_list`` entries for the pixel values of a raw mask.

        Args:
            image_id: index into ``self.images``.

        Raises:
            FileNotFoundError: if the mask file does not exist (raised by PIL).
        """
        _, _, mask_path = self._resolve_paths(self.images[image_id])
        with Image.open(mask_path) as pil_mask:
            unique_values = np.unique(np.array(pil_mask.convert("L")))
        # NOTE(review): `v in label_list` tests membership while `label_list[v]`
        # indexes by v; these only agree if label_list is a mapping keyed by
        # pixel value — confirm its type in the config.
        return [self.label_list[v] for v in unique_values if v in self.label_list]