Spaces:
Running
Running
import os | |
import os.path | |
from typing import Any, Callable, cast, Dict, List, Optional, Tuple | |
from typing import Union | |
from PIL import Image | |
import pandas as pd | |
from torchvision.datasets import VisionDataset | |
import torch | |
def pil_loader(path: str) -> Image.Image:
    """Read the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly in a ``with`` block, and the RGB conversion
    happens while the handle is still open. This means the pixel data is read
    before the handle is closed, avoiding the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image
class BinaryWaterbirds(VisionDataset):
    """Image dataset indexed by a ``metadata.csv`` file located in ``root``.

    Rows of ``metadata.csv`` are filtered by their integer ``split`` column.
    Each remaining row becomes a sample ``(root/<img_filename>, y)``.
    """

    # Split name -> integer code stored in the CSV's ``split`` column.
    _SPLIT_CODES: Dict[str, int] = {'test': 2, 'valid': 1, 'train': 0}

    def __init__(
        self,
        root: str,
        split: str,
        loader: Callable[[str], Any] = pil_loader,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        Args:
            root: Directory containing ``metadata.csv`` and the image files
                referenced (relative to ``root``) by its ``img_filename`` column.
            split: One of ``'train'``, ``'valid'`` or ``'test'``.
            loader: Callable that turns an image path into a sample.
            transform: Optional callable applied to each loaded sample.
            target_transform: Optional callable applied to each target.

        Raises:
            KeyError: If ``split`` is not one of the known split names.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.loader = loader

        # Validate the split before touching the filesystem. The exception type
        # stays KeyError (as before), but the message now lists the valid choices.
        try:
            split_code = self._SPLIT_CODES[split]
        except KeyError:
            raise KeyError(
                f"unknown split {split!r}; expected one of {sorted(self._SPLIT_CODES)}"
            ) from None

        metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
        metadata = metadata[metadata['split'] == split_code]
        # Walk the two needed columns together instead of materialising a row
        # Series per sample with ``iloc`` (which is much slower on large CSVs).
        self.samples: List[Tuple[str, Any]] = [
            (os.path.join(root, filename), label)
            for filename, label in zip(metadata['img_filename'], metadata['y'])
        ]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is the ``y`` value of the row,
            after ``target_transform`` if one was given.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)