Spaces:
Runtime error
Runtime error
import torch | |
import cv2 | |
import numpy as np | |
import torchvision.transforms as transforms | |
from torch.utils.data import Dataset | |
class ImageDataset(Dataset):
    """Image dataset that serves one of three positional splits of a table.

    The rows of ``csv`` are divided by position: the first 85% form the
    training split, the last ``TEST_SIZE`` rows form the test split, and the
    rows in between form the validation split.

    Args:
        csv: pandas-style DataFrame with an ``Id`` column (image file stem)
            and a ``Genre`` column; every remaining column is used as a
            label value for the row.
        train: truthy to select the training split (takes precedence over
            ``test``).
        test: when ``train`` is falsy, truthy selects the test split and
            falsy selects the validation split.
        image_dir: folder containing the ``<Id>.jpg`` files. Defaults to the
            location the original script used.
    """

    # Fraction of rows assigned to the training split.
    TRAIN_FRACTION = 0.85
    # Number of trailing rows reserved for the test (inference) split.
    TEST_SIZE = 10
    # Default folder the images are read from.
    IMAGE_DIR = "../input/movie-classifier/Multi_Label_dataset/Images"

    def __init__(self, csv, train, test, image_dir=IMAGE_DIR):
        self.csv = csv
        self.train = train
        self.test = test
        self.image_dir = image_dir

        self.all_image_names = self.csv[:]['Id']
        self.all_labels = np.array(self.csv.drop(['Id', 'Genre'], axis=1))

        # Despite the "ratio" names (kept for compatibility), these are row
        # counts, not fractions.
        self.train_ratio = int(self.TRAIN_FRACTION * len(self.csv))
        self.valid_ratio = len(self.csv) - self.train_ratio

        # One slice is applied to both names and labels so the two lists
        # always have the same length and stay aligned row for row.
        rows = self._split_rows(
            self.train_ratio, self.valid_ratio, train, test, self.TEST_SIZE
        )
        self.image_names = list(self.all_image_names.iloc[rows])
        self.labels = list(self.all_labels[rows])

        if train:
            print(f"Number of training images: {self.train_ratio}")
        elif not test:
            # Report the real size: the validation split excludes the
            # trailing test rows, so it is smaller than ``valid_ratio``.
            print(f"Number of validation images: {len(self.image_names)}")

        self.transform = self._build_transform(train, test)

    @staticmethod
    def _split_rows(train_count, valid_count, train, test, test_count=10):
        """Return the slice of row positions belonging to the chosen split.

        Args:
            train_count: number of leading rows in the training split.
            valid_count: number of trailing rows not used for training.
            train: truthy selects the training rows.
            test: with ``train`` falsy, truthy selects the test rows and
                falsy selects the validation rows.
            test_count: number of trailing rows reserved for testing.

        Returns:
            A ``slice`` usable on any sequence that has one entry per row.
        """
        if train:
            return slice(None, train_count)
        if not test:
            return slice(-valid_count, -test_count)
        return slice(-test_count, None)

    @staticmethod
    def _build_transform(train, test):
        """Return the torchvision transform pipeline for the chosen split."""
        if train:
            # Training: resize plus random flip/rotation augmentation.
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=45),
                transforms.ToTensor(),
            ])
        if not test:
            # Validation: resize only, no augmentation.
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.ToTensor(),
            ])
        # Test: images are converted to tensors without resizing.
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Load, transform and return the sample at ``index``.

        Returns:
            dict with ``'image'`` (float32 tensor produced by the split's
            transform) and ``'label'`` (float32 tensor of the row's labels).

        Raises:
            FileNotFoundError: if the image file is missing or unreadable.
        """
        image_path = f"{self.image_dir}/{self.image_names[index]}.jpg"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the offending path.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        # convert the image from BGR to RGB color format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # apply image transforms
        image = self.transform(image)
        targets = self.labels[index]

        return {
            # The transform already yields a tensor; as_tensor avoids the
            # copy-construct warning torch.tensor() emits for tensor input.
            'image': torch.as_tensor(image, dtype=torch.float32),
            'label': torch.tensor(targets, dtype=torch.float32),
        }