faces-segmentation / src /dataset.py
eksemyashkina's picture
Upload 9 files
b108d0c verified
raw
history blame
1.47 kB
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import datasets
import numpy as np
import PIL.Image
import torch
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from a directory tree.

    Images are collected recursively from ``<root>/images/<subset>`` (``*.jpg``)
    and masks from ``<root>/annotations/<subset>`` (``*.png``). Both listings
    are sorted, and sample ``i`` is the ``i``-th image paired with the ``i``-th
    mask, so the two trees must sort into the same order.
    """

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Index the image and mask files of one subset.

        Args:
            root: Dataset root containing ``images`` and ``annotations``.
            subset: Sub-directory name under both ``images`` and ``annotations``.
            transform: Optional callable applied to the RGB image.
            target_transform: Optional callable applied to the grayscale mask.

        Raises:
            ValueError: If the number of images and masks differ.
        """
        super().__init__()
        self.images_dir = Path(root) / "images" / subset
        self.masks_dir = Path(root) / "annotations" / subset
        self.transform = transform
        self.target_transform = target_transform
        self.images = sorted(self.images_dir.glob("**/*.jpg"))
        self.masks = sorted(self.masks_dir.glob("**/*.png"))

        # Pairing is purely positional, so a count mismatch means samples
        # would be mispaired or indexing would fail later; fail fast instead.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {self.images_dir} but "
                f"{len(self.masks)} masks in {self.masks_dir}; counts must match."
            )

    def __len__(self) -> int:
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the ``idx``-th image (as RGB) and mask (as 8-bit grayscale).

        The configured transforms, if any, are applied before returning. When
        a transform is not set, the corresponding item is returned as the
        converted ``PIL.Image.Image`` rather than a tensor.
        """
        # ``convert`` yields an independent, fully loaded image, so the
        # underlying file can be closed as soon as the ``with`` block exits.
        with PIL.Image.open(self.images[idx]) as image_file:
            image = image_file.convert("RGB")
        with PIL.Image.open(self.masks[idx]) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask
def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batch a list of ``(image, mask)`` samples.

    The first element of every sample is stacked into one tensor and the
    second element into another, each along a new leading dimension.
    """
    image_list = []
    mask_list = []
    for sample in items:
        image_list.append(sample[0])
        mask_list.append(sample[1])
    return torch.stack(image_list), torch.stack(mask_list)