import csv
import pathlib
from typing import Any, Callable, Optional, Tuple

import torch
from PIL import Image

from .utils import check_integrity, verify_str_arg
from .vision import VisionDataset


class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    The requested split is read from a CSV file located in ``root/fer2013``
    (``train.csv`` or ``test.csv``). The file is verified against a known MD5
    checksum and then parsed completely into memory at construction time.

    Args:
        root (string): Root directory of dataset where directory
            ``root/fer2013`` exists.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    # split name -> (CSV file name, expected MD5 checksum of that file)
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        # Validate the split before anything else so a bad value fails fast.
        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
        super().__init__(root, transform=transform, target_transform=target_transform)

        file_name, md5 = self._RESOURCES[self._split]
        base_folder = pathlib.Path(self.root) / "fer2013"
        data_file = base_folder / file_name

        # The dataset is not downloaded automatically: a missing or altered
        # file is reported together with the place to obtain it.
        file_is_valid = check_integrity(str(data_file), md5=md5)
        if not file_is_valid:
            raise RuntimeError(
                f"{file_name} not found in {base_folder} or corrupted. "
                f"You can download it from "
                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )

        # newline="" lets the csv module handle line endings itself.
        with open(data_file, "r", newline="") as file:
            self._samples = [self._parse_row(row) for row in csv.DictReader(file)]

    @staticmethod
    def _parse_row(row: dict) -> Tuple[torch.Tensor, Optional[int]]:
        """Convert one CSV record into an ``(image_tensor, target)`` pair.

        The ``"pixels"`` field is a whitespace-separated list of integers that
        is reshaped into a 48x48 ``uint8`` tensor. The target is the integer in
        the ``"emotion"`` field, or ``None`` when the record has no such field.
        """
        pixel_values = [int(value) for value in row["pixels"].split()]
        image_tensor = torch.tensor(pixel_values, dtype=torch.uint8).reshape(48, 48)

        target = None
        if "emotion" in row:
            target = int(row["emotion"])
        return image_tensor, target

    def __len__(self) -> int:
        """Return the number of samples loaded for the selected split."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair stored at position ``idx``.

        The stored tensor is converted to a PIL image before ``transform`` is
        applied; ``target_transform`` is applied to the target. Each transform
        is skipped when it is ``None``.
        """
        pixels, label = self._samples[idx]
        picture = Image.fromarray(pixels.numpy())

        if self.transform is not None:
            picture = self.transform(picture)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return picture, label

    def extra_repr(self) -> str:
        """Return the split description used in the dataset's repr."""
        return f"split={self._split}"