|
import cv2 |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
def preprocess_img(img_dir, channels=3, shape_r=288, shape_c=384):
    """Load an image and letterbox it to (shape_r, shape_c), keeping aspect ratio.

    The image is resized so its larger relative dimension exactly fills the
    target, then centered; the leftover border keeps the fill value of the
    padding canvas (ones for color, zeros for grayscale, as in the original
    pipeline).

    Args:
        img_dir: Path to the image file.
        channels: 1 for grayscale, 3 for color (cv2's default BGR ordering).
        shape_r: Target height in pixels. Defaults to 288.
        shape_c: Target width in pixels. Defaults to 384.

    Returns:
        np.uint8 array of shape (shape_r, shape_c, 3) when channels == 3,
        or (shape_r, shape_c) when channels == 1.

    Raises:
        ValueError: If channels is not 1 or 3.
        FileNotFoundError: If the image cannot be read.
    """
    if channels == 1:
        img = cv2.imread(img_dir, 0)
    elif channels == 3:
        img = cv2.imread(img_dir)
    else:
        raise ValueError("channels must be 1 or 3, got %r" % (channels,))
    if img is None:
        # cv2.imread signals failure by returning None, not by raising.
        raise FileNotFoundError("could not read image: %s" % img_dir)

    if channels == 1:
        img_padded = np.zeros((shape_r, shape_c), dtype=np.uint8)
    else:
        img_padded = np.ones((shape_r, shape_c, channels), dtype=np.uint8)

    original_shape = img.shape
    rows_rate = original_shape[0] / shape_r
    cols_rate = original_shape[1] / shape_c

    if rows_rate > cols_rate:
        # Height is the limiting dimension: fit rows exactly, pad columns.
        new_cols = (original_shape[1] * shape_r) // original_shape[0]
        img = cv2.resize(img, (new_cols, shape_r))
        if new_cols > shape_c:
            new_cols = shape_c
        left = (img_padded.shape[1] - new_cols) // 2
        img_padded[:, left:left + new_cols] = img
    else:
        # Width is the limiting dimension: fit columns exactly, pad rows.
        new_rows = (original_shape[0] * shape_c) // original_shape[1]
        img = cv2.resize(img, (shape_c, new_rows))
        if new_rows > shape_r:
            new_rows = shape_r
        top = (img_padded.shape[0] - new_rows) // 2
        img_padded[top:top + new_rows, :] = img

    return img_padded
|
|
|
|
|
def postprocess_img(pred, org_dir):
    """Map a prediction back to the original image's resolution.

    Inverse of the letterboxing done in preprocess_img: the prediction is
    scaled so the original image's content region fills it, then the padded
    border is cropped away with a centered crop.

    Args:
        pred: 2-D prediction map (anything np.array accepts).
        org_dir: Path to the original image; only its dimensions are used.

    Returns:
        np.ndarray of shape (org_rows, org_cols).

    Raises:
        FileNotFoundError: If the original image cannot be read.
    """
    pred = np.array(pred)
    org = cv2.imread(org_dir, 0)
    if org is None:
        # cv2.imread signals failure by returning None, not by raising.
        raise FileNotFoundError("could not read image: %s" % org_dir)
    shape_r = org.shape[0]
    shape_c = org.shape[1]
    predictions_shape = pred.shape

    rows_rate = shape_r / predictions_shape[0]
    cols_rate = shape_c / predictions_shape[1]

    if rows_rate > cols_rate:
        # Rows were the limiting dimension: scale to full height, crop columns.
        new_cols = (predictions_shape[1] * shape_r) // predictions_shape[0]
        pred = cv2.resize(pred, (new_cols, shape_r))
        left = (pred.shape[1] - shape_c) // 2
        img = pred[:, left:left + shape_c]
    else:
        # Columns were the limiting dimension: scale to full width, crop rows.
        new_rows = (predictions_shape[0] * shape_c) // predictions_shape[1]
        pred = cv2.resize(pred, (shape_c, new_rows))
        top = (pred.shape[0] - shape_r) // 2
        img = pred[top:top + shape_r, :]

    return img
|
|
|
|
|
class MyDataset(Dataset):
    """Saliency dataset yielding a stimulus image, saliency map and fixation map."""

    def __init__(self, ids, stimuli_dir, saliency_dir, fixation_dir, transform=None):
        """
        Args:
            ids: Table of filenames indexable via ``.iloc[idx, col]`` (e.g. a
                pandas DataFrame) with columns [stimulus, saliency, fixation].
            stimuli_dir (string): Directory holding the stimulus images.
                Filenames are appended by plain concatenation, so this must
                end with a path separator.
            saliency_dir (string): Directory holding the saliency maps
                (same trailing-separator requirement).
            fixation_dir (string): Directory holding the fixation maps
                (same trailing-separator requirement).
            transform (callable, optional): Optional transform to be applied
                on a sample.  NOTE(review): stored but currently never used
                in __getitem__ — confirm whether it should be applied.
        """
        self.ids = ids
        self.stimuli_dir = stimuli_dir
        self.saliency_dir = saliency_dir
        self.fixation_dir = fixation_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys:
                'image': float64 tensor, shape (3, H, W), values in [0, 1].
                'saliency': float64 tensor, shape (1, H, W), values in [0, 1].
                'fixation': float64 tensor, shape (1, H, W), values in [0, 1].
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Close each image explicitly: Image.open is lazy and otherwise keeps
        # the file handle alive, which leaks descriptors under a DataLoader.
        im_path = self.stimuli_dir + self.ids.iloc[idx, 0]
        with Image.open(im_path) as image:
            img = np.array(image.convert('RGB')) / 255.
        img = np.transpose(img, (2, 0, 1))
        img = torch.from_numpy(img)

        smap_path = self.saliency_dir + self.ids.iloc[idx, 1]
        with Image.open(smap_path) as saliency:
            smap = np.expand_dims(np.array(saliency) / 255., axis=0)
        smap = torch.from_numpy(smap)

        fmap_path = self.fixation_dir + self.ids.iloc[idx, 2]
        with Image.open(fmap_path) as fixation:
            fmap = np.expand_dims(np.array(fixation) / 255., axis=0)
        fmap = torch.from_numpy(fmap)

        sample = {'image': img, 'saliency': smap, 'fixation': fmap}

        return sample
|
|
|
|
|
|
|
|
|
|
|
|