import cv2
from PIL import Image
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader


def preprocess_img(img_dir, channels=3):
    """Resize an image to 288x384 while preserving its aspect ratio,
    centring the resized image on a padded canvas."""
    if channels == 1:
        img = cv2.imread(img_dir, 0)
    elif channels == 3:
        img = cv2.imread(img_dir)
    else:
        raise ValueError("channels must be 1 or 3")
    if img is None:
        raise FileNotFoundError("Could not read image: " + img_dir)

    shape_r = 288
    shape_c = 384
    img_padded = np.ones((shape_r, shape_c, channels), dtype=np.uint8)
    if channels == 1:
        img_padded = np.zeros((shape_r, shape_c), dtype=np.uint8)

    original_shape = img.shape
    rows_rate = original_shape[0] / shape_r
    cols_rate = original_shape[1] / shape_c

    if rows_rate > cols_rate:
        # Height is the limiting dimension: fit the rows, pad the columns.
        new_cols = (original_shape[1] * shape_r) // original_shape[0]
        img = cv2.resize(img, (new_cols, shape_r))
        if new_cols > shape_c:
            new_cols = shape_c
        img_padded[:, ((img_padded.shape[1] - new_cols) // 2):((img_padded.shape[1] - new_cols) // 2 + new_cols)] = img
    else:
        # Width is the limiting dimension: fit the columns, pad the rows.
        new_rows = (original_shape[0] * shape_c) // original_shape[1]
        img = cv2.resize(img, (shape_c, new_rows))
        if new_rows > shape_r:
            new_rows = shape_r
        img_padded[((img_padded.shape[0] - new_rows) // 2):((img_padded.shape[0] - new_rows) // 2 + new_rows), :] = img

    return img_padded


def postprocess_img(pred, org_dir):
    """Resize a predicted map back to the size of the original image,
    undoing the aspect-ratio padding applied in preprocess_img."""
    pred = np.array(pred)
    org = cv2.imread(org_dir, 0)
    if org is None:
        raise FileNotFoundError("Could not read image: " + org_dir)

    shape_r = org.shape[0]
    shape_c = org.shape[1]
    predictions_shape = pred.shape

    rows_rate = shape_r / predictions_shape[0]
    cols_rate = shape_c / predictions_shape[1]

    if rows_rate > cols_rate:
        # Scale so the rows match, then crop the padded columns.
        new_cols = (predictions_shape[1] * shape_r) // predictions_shape[0]
        pred = cv2.resize(pred, (new_cols, shape_r))
        img = pred[:, ((pred.shape[1] - shape_c) // 2):((pred.shape[1] - shape_c) // 2 + shape_c)]
    else:
        # Scale so the columns match, then crop the padded rows.
        new_rows = (predictions_shape[0] * shape_c) // predictions_shape[1]
        pred = cv2.resize(pred, (shape_c, new_rows))
        img = pred[((pred.shape[0] - shape_r) // 2):((pred.shape[0] - shape_r) // 2 + shape_r), :]

    return img


class MyDataset(Dataset):
    """Saliency dataset: returns each stimulus image together with its
    saliency map and fixation map as tensors."""

    def __init__(self, ids, stimuli_dir, saliency_dir, fixation_dir, transform=None):
        """
        Args:
            ids (DataFrame): Table whose columns hold the stimulus,
                saliency-map and fixation-map file names, in that order.
            stimuli_dir (string): Directory with the stimulus images.
            saliency_dir (string): Directory with the saliency maps.
            fixation_dir (string): Directory with the fixation maps.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.ids = ids
        self.stimuli_dir = stimuli_dir
        self.saliency_dir = saliency_dir
        self.fixation_dir = fixation_dir
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Stimulus image: HxWxC uint8 -> CxHxW float in [0, 1].
        im_path = self.stimuli_dir + self.ids.iloc[idx, 0]
        image = Image.open(im_path).convert('RGB')
        img = np.array(image) / 255.
        img = np.transpose(img, (2, 0, 1))
        img = torch.from_numpy(img)
        # if self.transform:
        #     img = self.transform(image)

        # Ground-truth saliency map: 1xHxW float in [0, 1].
        smap_path = self.saliency_dir + self.ids.iloc[idx, 1]
        saliency = Image.open(smap_path)
        smap = np.expand_dims(np.array(saliency) / 255., axis=0)
        smap = torch.from_numpy(smap)

        # Fixation map: 1xHxW float in [0, 1].
        fmap_path = self.fixation_dir + self.ids.iloc[idx, 2]
        fixation = Image.open(fmap_path)
        fmap = np.expand_dims(np.array(fixation) / 255., axis=0)
        fmap = torch.from_numpy(fmap)

        sample = {'image': img, 'saliency': smap, 'fixation': fmap}

        return sample
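

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original pipeline): builds a
# MyDataset from a CSV listing the stimulus / saliency / fixation file names
# and iterates over it with a DataLoader. The file and directory names used
# here ('train_ids.csv', 'datasets/train/...') are assumptions for
# demonstration only; adapt them to the actual dataset layout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import pandas as pd

    # Assumed CSV layout: three columns holding the stimulus, saliency-map
    # and fixation-map file names, matching the column order in __getitem__.
    ids = pd.read_csv('train_ids.csv')

    dataset = MyDataset(ids=ids,
                        stimuli_dir='datasets/train/stimuli/',
                        saliency_dir='datasets/train/saliency/',
                        fixation_dir='datasets/train/fixation/')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    # Fetch one batch and confirm the tensor shapes.
    batch = next(iter(loader))
    print(batch['image'].shape, batch['saliency'].shape, batch['fixation'].shape)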