Spaces:
Sleeping
Sleeping
import csv | |
import pathlib | |
from typing import Any, Callable, Optional, Tuple | |
import torch | |
from PIL import Image | |
from .utils import check_integrity, verify_str_arg | |
from .vision import VisionDataset | |
class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    All samples of the requested split are read from ``root/fer2013/<split file>`` when the
    dataset is constructed and kept in memory. Each image is a 48x48 ``uint8`` array built
    from the CSV's ``"pixels"`` column (space-separated integers). The target is the integer
    in the ``"emotion"`` column, or ``None`` if the CSV has no such column.

    .. note::
        The data file is not downloaded automatically. It must already be present in
        ``root/fer2013``, and it is verified against a fixed MD5 checksum.

    Args:
        root (string): Root directory of dataset where directory
            ``root/fer2013`` exists.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
            It is also called when the target is ``None``.

    Raises:
        RuntimeError: If the split's CSV file is missing or fails the MD5 check.
    """

    # split name -> (CSV file name inside ``root/fer2013``, expected MD5 of that file)
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        # Validate ``split`` first so an unsupported value fails before any file access.
        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
        super().__init__(root, transform=transform, target_transform=target_transform)

        base_folder = pathlib.Path(self.root) / "fer2013"
        file_name, md5 = self._RESOURCES[self._split]
        data_file = base_folder / file_name

        if not check_integrity(str(data_file), md5=md5):
            raise RuntimeError(
                f"{file_name} not found in {base_folder} or corrupted. "
                "You can download it from "
                "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )

        # ``newline=""`` is what the csv module requires for correct row parsing.
        # An explicit encoding keeps decoding independent of the platform's locale default.
        with open(data_file, "r", encoding="utf-8", newline="") as file:
            self._samples = [self._parse_row(row) for row in csv.DictReader(file)]

    @staticmethod
    def _parse_row(row: dict) -> Tuple[torch.Tensor, Optional[int]]:
        """Convert one CSV row into an ``(image tensor, target)`` pair.

        The ``"pixels"`` field is split on whitespace, converted to ``uint8`` and reshaped
        to 48x48. The target is ``int(row["emotion"])`` when that column exists, else ``None``.
        """
        pixels = torch.tensor([int(value) for value in row["pixels"].split()], dtype=torch.uint8)
        target = int(row["emotion"]) if "emotion" in row else None
        return pixels.reshape(48, 48), target

    def __len__(self) -> int:
        """Return the number of samples loaded from the split's CSV file."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the sample at position ``idx``.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: ``(image, target)``. ``image`` is the 48x48 PIL image, passed through
            ``transform`` if one was given. ``target`` is the integer from the ``"emotion"``
            column or ``None``, passed through ``target_transform`` if one was given.
        """
        image_tensor, target = self._samples[idx]
        # A 2-D uint8 array becomes a single-channel PIL image.
        image = Image.fromarray(image_tensor.numpy())

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def extra_repr(self) -> str:
        """Return the extra text shown in this dataset's ``repr`` (the selected split)."""
        return f"split={self._split}"