File size: 1,551 Bytes
af720c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import datasets
import torch
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    """Train/test view over an indexable collection of image/mask records.

    The split is contiguous and deterministic: the first
    ``int(test_size * len(dataset))`` records form the test split and the
    remaining records form the train split. No shuffling is performed, so if
    ``dataset`` happens to be ordered, the two splits may be biased; shuffle
    upstream if that matters.

    Args:
        dataset: Sized, integer-indexable collection whose items support
            ``item["image"]`` and ``item["mask"]`` (e.g. a ``datasets.Dataset``).
        train: If True, expose the train split; otherwise the test split.
        transform: Optional callable applied to each image.
        target_transform: Optional callable applied to each mask.
        test_size: Fraction of records assigned to the test split, in ``[0, 1]``.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]`` (or is NaN).
    """

    def __init__(
        self,
        dataset: "datasets.Dataset",
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        test_size: float = 0.25,
    ) -> None:
        super().__init__()
        # Out-of-range fractions would make the split point negative or larger
        # than the dataset, and slicing would then silently produce nonsense
        # (a negative split point slices from the end of the index list).
        if not 0.0 <= test_size <= 1.0:
            raise ValueError(f"test_size must be in [0, 1], got {test_size!r}")

        self.dataset = dataset
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.test_size = test_size

        total_size = len(dataset)
        split = int(test_size * total_size)
        # Materialize only the indices belonging to the selected split.
        if train:
            self.indices = list(range(split, total_size))
        else:
            self.indices = list(range(split))

    def __len__(self) -> int:
        """Return the number of records in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(image, mask)`` pair at position ``idx`` of this split.

        Without transforms, the raw ``item["image"]`` / ``item["mask"]`` values
        are returned unchanged. The tensor return annotation therefore assumes
        the configured transforms produce tensors.
        """
        item = self.dataset[self.indices[idx]]
        image = item["image"]
        mask = item["mask"]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask


def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge a list of ``(image, mask)`` samples into a batched pair.

    Each field is stacked with ``torch.stack`` along a new leading dimension,
    so all images (and, separately, all masks) must share the same shape.

    Args:
        items: Samples, each a pair whose element 0 is the image tensor and
            element 1 is the mask tensor.

    Returns:
        ``(images, masks)``, each with the batch as dimension 0.
    """

    def _stack_field(position: int) -> torch.Tensor:
        # Pull the tensor at ``position`` out of every sample, then batch them.
        return torch.stack([sample[position] for sample in items])

    batched_images = _stack_field(0)
    batched_masks = _stack_field(1)
    return batched_images, batched_masks