# Dabs's picture
# first commit
# cb8043e
import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
class ImageDataset(Dataset):
    """Multi-label image dataset backed by a CSV table and a folder of JPEGs.

    The table must expose an ``Id`` column (image file stem) and a ``Genre``
    column; every remaining column is treated as one label value. Rows are
    split by position:

    * training   -- the first ``int(0.85 * len(csv))`` rows,
    * validation -- the rows after the training rows, excluding the last 10,
    * test       -- the last 10 rows (used by a separate inference script).

    Note that when the table is small enough that fewer than 10 rows remain
    after the training rows, the validation split is empty and the test
    split overlaps the training rows.

    Args:
        csv: DataFrame-like table supporting ``csv[:]['Id']``,
            ``csv.drop([...], axis=1)`` and ``len(csv)``.
        train: Truthy to select the training split (takes precedence over
            ``test``).
        test: When ``train`` is falsy, truthy selects the test split and
            falsy selects the validation split.
        image_dir: Directory containing ``<Id>.jpg`` files. Defaults to the
            location the dataset originally hard-coded.
    """

    # Fraction of rows (counted from the top of the table) used for training.
    TRAIN_FRACTION = 0.85
    # Number of trailing rows held out as the test split.
    TEST_SIZE = 10
    DEFAULT_IMAGE_DIR = "../input/movie-classifier/Multi_Label_dataset/Images"

    def __init__(self, csv, train, test, image_dir=DEFAULT_IMAGE_DIR):
        self.csv = csv
        self.train = train
        self.test = test
        self.image_dir = image_dir
        self.all_image_names = self.csv[:]['Id']
        self.all_labels = np.array(self.csv.drop(['Id', 'Genre'], axis=1))

        total = len(self.csv)
        self.train_ratio = int(self.TRAIN_FRACTION * total)
        self.valid_ratio = total - self.train_ratio

        # Use ONE pair of bounds for both names and labels so the two lists
        # can never drift apart (previously the validation labels kept the
        # 10 test rows that the validation names had dropped).
        start, stop = self._split_bounds(total, train, test)
        self.image_names = list(self.all_image_names)[start:stop]
        self.labels = list(self.all_labels[start:stop])
        self.transform = self._build_transform(train, test)

        if train:
            print(f"Number of training images: {self.train_ratio}")
        elif not test:
            # Report the real split size, i.e. without the held-out test rows.
            print(f"Number of validation images: {len(self.image_names)}")

    @classmethod
    def _split_bounds(cls, total, train, test):
        """Return the ``(start, stop)`` row positions of the requested split."""
        train_count = int(cls.TRAIN_FRACTION * total)
        test_start = max(total - cls.TEST_SIZE, 0)
        if train:
            return 0, train_count
        if not test:
            # Clamp so a tiny table yields an empty split, not a reversed one.
            return train_count, max(test_start, train_count)
        return test_start, total

    @staticmethod
    def _build_transform(train, test):
        """Return the torchvision transform pipeline for the requested split."""
        if train:
            # Training: resize plus random flip/rotation augmentation.
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=45),
                transforms.ToTensor(),
            ])
        if not test:
            # Validation: deterministic resize only.
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.ToTensor(),
            ])
        # Test: no resize, images keep their native resolution.
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Load, transform and return the sample at ``index``.

        Returns:
            dict with ``'image'`` (float32 tensor produced by the split's
            transform) and ``'label'`` (float32 tensor of the row's labels).

        Raises:
            FileNotFoundError: If the image file is missing or unreadable.
        """
        path = f"{self.image_dir}/{self.image_names[index]}.jpg"
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        # convert the image from BGR to RGB color format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # apply image transforms
        image = self.transform(image)
        targets = self.labels[index]
        return {
            # The transform already yields a tensor; as_tensor avoids the
            # copy-construct warning that torch.tensor(tensor) emits.
            'image': torch.as_tensor(image, dtype=torch.float32),
            'label': torch.tensor(targets, dtype=torch.float32),
        }