Spaces:
Build error
Build error
File size: 3,116 Bytes
b63fd37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from glob import glob
import os
import warnings

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from fake_face_detection.utils.compute_weights import compute_weights
class LionCheetahDataset(Dataset):
    """Image-classification dataset of lion and cheetah pictures.

    Every file directly inside ``lion_path`` and ``cheetah_path`` (the search
    is not recursive) is opened with PIL. When a transformer is given, the
    image is also passed through it once, at construction time. Files that
    fail either step are skipped, reported with a warning and listed in
    ``skipped_images``. Without a transformer, PIL only identifies the file
    format, because ``Image.open`` is lazy.

    Samples are ordered lion images first, then cheetah images. Each group is
    sorted by path, so indices are reproducible across file systems.

    Args:
        lion_path: Directory containing the lion images.
        cheetah_path: Directory containing the cheetah images.
        id_map: Mapping with ``'lion'`` and ``'cheetah'`` keys. Their values
            are converted with ``int`` to obtain the class ids.
        transformer: Callable invoked as
            ``transformer(img, **transformer_kwargs)`` on a PIL image. Its
            result must support item assignment, because ``__getitem__``
            stores the label under ``'labels'`` in it. (Presumably a Hugging
            Face image processor returning a ``BatchFeature`` — confirm.)
            May be falsy to disable transformation.
        **transformer_kwargs: Extra keyword arguments forwarded to
            ``transformer`` on every call.

    Attributes:
        images: Paths of the retained images (lions, then cheetahs).
        labels: Integer class id for each entry of ``images``.
        weights: Tensor built from ``compute_weights(self.labels)``.
            (Presumably per-class loss weights — confirm against
            ``compute_weights``.)
        skipped_images: Paths that could not be opened or transformed.
        length: Number of retained samples.
    """

    def __init__(self, lion_path: str, cheetah_path: str, id_map: dict,
                 transformer, **transformer_kwargs):
        self.transformer = transformer
        self.transformer_kwargs = transformer_kwargs
        self.skipped_images = []

        # Keep only files that PIL (and the transformer, if any) can actually
        # process, so that __getitem__ does not fail in the middle of an epoch.
        self.lion_images = self._load_valid_images(lion_path)
        self.cheetah_images = self._load_valid_images(cheetah_path)
        self.images = self.lion_images + self.cheetah_images

        # One integer label per retained image, in the same order as self.images.
        self.lion_labels = [int(id_map['lion'])] * len(self.lion_images)
        self.cheetah_labels = [int(id_map['cheetah'])] * len(self.cheetah_images)
        self.labels = self.lion_labels + self.cheetah_labels

        # torch.from_numpy requires compute_weights to return a NumPy array.
        self.weights = torch.from_numpy(compute_weights(self.labels))

        self.length = len(self.labels)

    def _load_valid_images(self, directory: str) -> list:
        """Return the sorted file paths in ``directory`` that open (and transform) cleanly."""
        # glob's result order depends on the file system; sort it so that
        # sample indices are reproducible.
        candidates = sorted(
            path for path in glob(os.path.join(directory, "*"))
            if os.path.isfile(path)
        )
        valid, skipped = [], []
        for path in candidates:
            try:
                with Image.open(path) as img:
                    if self.transformer:
                        # The result is discarded: this call only checks that
                        # the transformer accepts the image.
                        self.transformer(img, **self.transformer_kwargs)
            except Exception:  # deliberate best-effort: any failing file is skipped
                skipped.append(path)
            else:
                valid.append(path)
        if skipped:
            self.skipped_images.extend(skipped)
            warnings.warn(
                f"Skipped {len(skipped)} file(s) in {directory!r} that could "
                "not be opened or transformed",
                stacklevel=3,
            )
        return valid

    def __getitem__(self, index):
        """Return the sample at ``index`` with its label stored under ``'labels'``.

        With a transformer, this returns the transformer's output.
        Without one, it returns ``{'image': <PIL image>, 'labels': label}``.
        """
        path = self.images[index]
        label = self.labels[index]
        with Image.open(path) as img:
            if self.transformer:
                sample = self.transformer(img, **self.transformer_kwargs)
            else:
                # copy() forces the pixel data to load, so the image stays
                # usable after the file is closed by the with-block.
                sample = {'image': img.copy()}
        sample['labels'] = label
        return sample

    def __len__(self):
        """Return the number of retained samples."""
        return self.length
|