=
adding the best model to hugging face
b63fd37
from fake_face_detection.utils.compute_weights import compute_weights
from torch.utils.data import Dataset
from PIL import Image
from glob import glob
import numpy as np
import torch
import os
class LionCheetahDataset(Dataset):
    """Binary image-classification dataset built from a lion and a cheetah directory.

    Every entry matched by ``<dir>/*`` is opened once at construction time and,
    when a transformer is given, passed through it. Files for which either step
    raises are skipped, and their paths are recorded in ``skipped_images``.
    Samples are ordered lions first, then cheetahs, each group in sorted path
    order so that indices are reproducible across machines.

    Args:
        lion_path: Directory containing lion images.
        cheetah_path: Directory containing cheetah images.
        id_map: Mapping with keys ``'lion'`` and ``'cheetah'``; each value is
            converted with ``int`` to obtain that class's label.
        transformer: Callable invoked as ``transformer(img, **transformer_kwargs)``
            on an opened PIL image. Its return value must support item
            assignment, because ``'labels'`` is stored into it. It may be
            ``None``/falsy, in which case no transformation is applied.
        **transformer_kwargs: Extra keyword arguments forwarded to ``transformer``.

    Attributes:
        images: Paths of the kept images (lions, then cheetahs).
        labels: Integer label for each entry of ``images``.
        weights: Tensor built with ``torch.from_numpy(compute_weights(labels))``.
            Presumably these are class-balancing weights; see ``compute_weights``.
        length: Number of samples.
        skipped_images: Paths that could not be opened or transformed.
    """

    def __init__(self, lion_path: str, cheetah_path: str, id_map: dict,
                 transformer, **transformer_kwargs):
        self.transformer = transformer
        self.transformer_kwargs = transformer_kwargs

        # Keep only files that open (and transform) successfully.
        self.skipped_images = []
        self.lion_images = self._collect_valid_images(lion_path)
        self.cheetah_images = self._collect_valid_images(cheetah_path)
        self.images = self.lion_images + self.cheetah_images

        # One label per kept image, in the same order as self.images.
        self.lion_labels = [int(id_map['lion'])] * len(self.lion_images)
        self.cheetah_labels = [int(id_map['cheetah'])] * len(self.cheetah_images)
        self.labels = self.lion_labels + self.cheetah_labels

        # torch.from_numpy requires compute_weights to return a NumPy array.
        self.weights = torch.from_numpy(compute_weights(self.labels))

        self.length = len(self.labels)

    def _collect_valid_images(self, directory: str) -> list:
        """Return the sorted paths in ``directory`` that open (and transform) without error."""
        valid = []
        # glob's result order depends on the filesystem; sort for reproducible indexing.
        for path in sorted(glob(os.path.join(directory, "*"))):
            try:
                with Image.open(path) as img:
                    if self.transformer:
                        # The result is discarded: this call only checks that the
                        # transformer accepts the image (the original intent was to
                        # keep RGB images only).
                        self.transformer(img, **self.transformer_kwargs)
            except Exception:
                # Deliberate best-effort filter: unreadable files and images the
                # transformer rejects are skipped instead of aborting construction,
                # but they are recorded so callers can inspect them.
                self.skipped_images.append(path)
            else:
                valid.append(path)
        return valid

    def __getitem__(self, index):
        """Return the sample at ``index`` with its integer label stored under ``'labels'``.

        With a transformer, the transformer's output is returned with
        ``'labels'`` added. Without one, ``{"image": <PIL image>, "labels": label}``
        is returned.
        """
        path = self.images[index]
        label = self.labels[index]
        with Image.open(path) as img:
            if self.transformer:
                sample = self.transformer(img, **self.transformer_kwargs)
            else:
                # copy() loads the pixel data, so the image stays usable after
                # the underlying file is closed by the `with` block.
                sample = {"image": img.copy()}
        sample['labels'] = label
        return sample

    def __len__(self):
        """Return the number of kept samples."""
        return self.length