File size: 3,116 Bytes
b63fd37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

from fake_face_detection.utils.compute_weights import compute_weights
from torch.utils.data import Dataset
from PIL import Image
from glob import glob
import numpy as np
import torch
import os

class LionCheetahDataset(Dataset):
    """Binary image dataset of lion and cheetah pictures.

    Image paths that cannot be opened by PIL (or that the transformer
    rejects) are silently skipped at construction time, so ``__getitem__``
    only ever sees files that were readable when the dataset was built.
    """

    def __init__(self, lion_path: str, cheetah_path: str, id_map: dict, transformer, **transformer_kwargs):
        """Build the dataset from two image directories.

        Args:
            lion_path: directory containing the lion images.
            cheetah_path: directory containing the cheetah images.
            id_map: maps the class names ``'lion'`` and ``'cheetah'`` to
                integer label ids.
            transformer: callable applied to each opened PIL image;
                presumably returns a mapping (e.g. a feature-extractor
                output dict) — see the note in ``__getitem__``.
            **transformer_kwargs: extra keyword arguments forwarded to the
                transformer on every call.
        """
        self.transformer = transformer
        self.transformer_kwargs = transformer_kwargs

        # Keep only the paths that can actually be opened and transformed;
        # unreadable or corrupt files are skipped (best effort, no error).
        self.lion_images = self._readable_images(glob(os.path.join(lion_path, "*")))
        self.cheetah_images = self._readable_images(glob(os.path.join(cheetah_path, "*")))

        self.images = self.lion_images + self.cheetah_images

        # One integer label per retained image, aligned with self.images.
        self.lion_labels = [int(id_map['lion'])] * len(self.lion_images)
        self.cheetah_labels = [int(id_map['cheetah'])] * len(self.cheetah_images)
        self.labels = self.lion_labels + self.cheetah_labels

        # Per-class weights (e.g. to counter class imbalance in the loss).
        self.weights = torch.from_numpy(compute_weights(self.labels))

        self.length = len(self.labels)

    def _readable_images(self, paths):
        """Return the subset of *paths* that open (and transform) without error.

        Runs the transformer once per file purely as a validity check; the
        result is discarded here and recomputed lazily in ``__getitem__``.
        """
        kept = []
        for path in paths:
            try:
                with Image.open(path) as img:
                    if self.transformer:
                        self.transformer(img, **self.transformer_kwargs)
                kept.append(path)
            except Exception:
                # Deliberate best-effort skip of files PIL or the
                # transformer cannot handle; do not abort construction.
                pass
        return kept

    def __getitem__(self, index):
        """Return the transformed image at *index* with a ``'labels'`` key added.

        NOTE(review): assumes the transformer returns a mutable mapping.
        If ``transformer`` is falsy, ``image`` remains the path string and
        the item assignment below raises ``TypeError`` — confirm a
        transformer is always supplied by callers.
        """
        path = self.images[index]
        label = self.labels[index]

        image = path
        with Image.open(path) as img:
            if self.transformer:
                image = self.transformer(img, **self.transformer_kwargs)

        # Attach the label to the transformer's output dictionary.
        image['labels'] = label

        return image

    def __len__(self):
        """Number of (image, label) pairs in the dataset."""
        return self.length