from typing import List, Tuple, Callable from pathlib import Path import datasets import torch from torch.utils.data import Dataset class SegmentationDataset(Dataset): def __init__( self, dataset: datasets.Dataset, train: bool = True, transform: Callable = None, target_transform: Callable = None, test_size: float = 0.25, ) -> None: super().__init__() self.dataset = dataset self.train = train self.transform = transform self.target_transform = target_transform self.test_size = test_size total_size = len(dataset) indices = list(range(total_size)) split = int(self.test_size * total_size) if train: self.indices = indices[split:] else: self.indices = indices[:split] def __len__(self) -> int: return len(self.indices) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: item = self.dataset[self.indices[idx]] image = item["image"] mask = item["mask"] if self.transform: image = self.transform(image) if self.target_transform: mask = self.target_transform(mask) return image, mask def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]: images = torch.stack([item[0] for item in items]) masks = torch.stack([item[1] for item in items]) return images, masks