Spaces:
Build error
Build error
from fake_face_detection.utils.compute_weights import compute_weights | |
from torch.utils.data import Dataset | |
from PIL import Image | |
from glob import glob | |
import numpy as np | |
import torch | |
import os | |
class LionCheetahDataset(Dataset):
    """Image-classification dataset built from a lion folder and a cheetah folder.

    Every entry directly inside ``lion_path`` and ``cheetah_path`` is a candidate
    sample. During construction, each candidate is opened with PIL. When a
    transformer is set, the image is also passed through it once. Candidates for
    which either step raises are skipped, and their paths are kept in
    ``skipped_images``.

    Args:
        lion_path: directory whose entries are the lion images.
        cheetah_path: directory whose entries are the cheetah images.
        id_map: mapping with keys ``'lion'`` and ``'cheetah'``. The values are
            converted with ``int`` and used as the class labels.
        transformer: callable invoked as ``transformer(img, **transformer_kwargs)``
            on each PIL image. It may be falsy (e.g. ``None``) to disable
            transformation. When set, its return value must support item
            assignment, because ``__getitem__`` adds a ``'labels'`` entry to it.
        **transformer_kwargs: extra keyword arguments forwarded to ``transformer``.

    Attributes:
        lion_images, cheetah_images: accepted file paths per class, sorted.
        images: ``lion_images + cheetah_images``.
        lion_labels, cheetah_labels, labels: integer labels parallel to the
            path lists above.
        skipped_images: candidate paths rejected by the validation pass.
        weights: tensor built from ``compute_weights(self.labels)``.
        length: number of samples.
    """

    def __init__(self, lion_path: str, cheetah_path: str, id_map: dict, transformer, **transformer_kwargs):
        self.transformer = transformer
        self.transformer_kwargs = transformer_kwargs

        # glob() order depends on the file system; sort so sample indices are reproducible
        lion_candidates = sorted(glob(os.path.join(lion_path, "*")))
        cheetah_candidates = sorted(glob(os.path.join(cheetah_path, "*")))

        # keep only the files that open (and, if a transformer is set, transform) without error
        self.lion_images, skipped_lions = self._filter_valid_images(lion_candidates)
        self.cheetah_images, skipped_cheetahs = self._filter_valid_images(cheetah_candidates)
        # rejected paths are recorded instead of being dropped silently
        self.skipped_images = skipped_lions + skipped_cheetahs
        self.images = self.lion_images + self.cheetah_images

        # labels are parallel to self.images: lions first, then cheetahs
        self.lion_labels = [int(id_map['lion'])] * len(self.lion_images)
        self.cheetah_labels = [int(id_map['cheetah'])] * len(self.cheetah_images)
        self.labels = self.lion_labels + self.cheetah_labels

        # NOTE(review): torch.from_numpy requires compute_weights to return a numpy array — confirm
        self.weights = torch.from_numpy(compute_weights(self.labels))

        self.length = len(self.labels)

    def _filter_valid_images(self, paths):
        """Split ``paths`` into (accepted, rejected) lists, preserving order.

        A path is accepted when ``Image.open`` succeeds and, if a transformer is
        set, ``transformer(img, **transformer_kwargs)`` also succeeds. The
        transformer's output is discarded; this is only a validation pass.
        """
        accepted, rejected = [], []
        for path in paths:
            try:
                with Image.open(path) as img:
                    if self.transformer:
                        self.transformer(img, **self.transformer_kwargs)
            except Exception:
                # deliberate best-effort filtering: any failure just excludes the file
                rejected.append(path)
            else:
                accepted.append(path)
        return accepted, rejected

    def __getitem__(self, index):
        """Return the sample at ``index`` with its integer label under ``'labels'``.

        With a transformer, the transformer's output for the image is returned,
        with ``'labels'`` added. Without one, ``{'image': PIL.Image,
        'labels': int}`` is returned.
        """
        path = self.images[index]
        label = self.labels[index]
        with Image.open(path) as img:
            if self.transformer:
                sample = self.transformer(img, **self.transformer_kwargs)
            else:
                # copy() loads the pixel data, so the image stays usable after the file is closed
                sample = {'image': img.copy()}
        sample['labels'] = label
        return sample

    def __len__(self):
        """Return the number of accepted samples."""
        return self.length