import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class ImageDataset(Dataset):
    """Multi-label image dataset with a positional train / validation / test split.

    Rows of ``csv`` are split by position:

    * training   -- the first ``int(0.85 * len(csv))`` rows (with augmentation);
    * validation -- the remaining rows, excluding the final 10;
    * test       -- the final 10 rows (intended for a separate inference script).

    Args:
        csv: Table (used via pandas-style ``.drop`` / column indexing) with an
            ``'Id'`` column naming each image file (without ``.jpg``) and a
            ``'Genre'`` column; every other column is taken as part of the label
            vector (presumably 0/1 genre indicators -- confirm against the data).
        train: If truthy, build the training split.
        test: Only consulted when ``train`` is falsy; if truthy build the
            10-image test split, otherwise the validation split.
        image_dir: Directory holding the ``<Id>.jpg`` image files.
    """

    def __init__(self, csv, train, test,
                 image_dir="../input/movie-classifier/Multi_Label_dataset/Images"):
        self.csv = csv
        self.train = train
        self.test = test
        self.image_dir = image_dir
        self.all_image_names = self.csv[:]['Id']
        # Everything except the id and the raw genre text is a label column.
        self.all_labels = np.array(self.csv.drop(['Id', 'Genre'], axis=1))
        # NOTE: despite the names these are row counts, not ratios.
        self.train_ratio = int(0.85 * len(self.csv))
        self.valid_ratio = len(self.csv) - self.train_ratio

        if self.train:
            # Training split: first train_ratio rows.
            self.image_names = list(self.all_image_names.iloc[:self.train_ratio])
            self.labels = list(self.all_labels[:self.train_ratio])
            print(f"Number of training images: {len(self.image_names)}")
            # Augmentation (random flip / rotation) only for training.
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=45),
                transforms.ToTensor(),
            ])
        elif not self.test:
            # Validation split: rows after the training split, minus the last
            # 10 rows reserved for the test split. Names and labels must use
            # the same bounds so they stay aligned and equal in length.
            self.image_names = list(self.all_image_names.iloc[self.train_ratio:-10])
            self.labels = list(self.all_labels[self.train_ratio:-10])
            print(f"Number of validation images: {len(self.image_names)}")
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.ToTensor(),
            ])
        else:
            # Test split: only the last 10 images; used by a separate
            # inference script.
            self.image_names = list(self.all_image_names.iloc[-10:])
            self.labels = list(self.all_labels[-10:])
            # No resize here: images keep their original size.
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
            ])

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Load, transform and return one sample.

        Returns:
            dict with ``'image'`` (float32 tensor produced by the split's
            transform) and ``'label'`` (float32 tensor of the label row).

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_path = f"{self.image_dir}/{self.image_names[index]}.jpg"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads images as BGR; the transforms expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The transform already ends in ToTensor(), so only a dtype cast is
        # needed (re-wrapping with torch.tensor() would copy and warn).
        image = self.transform(image)
        targets = self.labels[index]
        return {
            'image': image.to(torch.float32),
            'label': torch.tensor(targets, dtype=torch.float32),
        }