# VisAt/utils/data_process.py
import cv2
from PIL import Image
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
def preprocess_img(img_dir, channels=3, shape_r=288, shape_c=384):
    """Load an image and letterbox-resize it to (shape_r, shape_c).

    The image is resized preserving its aspect ratio, then centered on a
    padded canvas so the output always has the requested target shape.

    Args:
        img_dir: Path to an image file readable by cv2.imread.
        channels: 1 for grayscale, 3 for BGR color.
        shape_r: Target number of rows (default 288).
        shape_c: Target number of columns (default 384).

    Returns:
        np.uint8 array of shape (shape_r, shape_c) when channels == 1,
        or (shape_r, shape_c, channels) when channels == 3.

    Raises:
        ValueError: If channels is not 1 or 3, or the image cannot be read.
    """
    if channels == 1:
        img = cv2.imread(img_dir, 0)
    elif channels == 3:
        img = cv2.imread(img_dir)
    else:
        # Previously any other value left `img` undefined and failed later
        # with a confusing NameError.
        raise ValueError(f"channels must be 1 or 3, got {channels}")
    if img is None:
        # cv2.imread returns None (no exception) on unreadable paths.
        raise ValueError(f"could not read image: {img_dir}")

    # Padding canvas: ones for color, zeros for grayscale (matches the
    # original behavior; a single allocation per call instead of two).
    if channels == 1:
        img_padded = np.zeros((shape_r, shape_c), dtype=np.uint8)
    else:
        img_padded = np.ones((shape_r, shape_c, channels), dtype=np.uint8)

    original_shape = img.shape
    rows_rate = original_shape[0] / shape_r
    cols_rate = original_shape[1] / shape_c
    if rows_rate > cols_rate:
        # Height is the limiting dimension: fit rows, pad columns.
        new_cols = (original_shape[1] * shape_r) // original_shape[0]
        img = cv2.resize(img, (new_cols, shape_r))  # cv2 dsize is (width, height)
        if new_cols > shape_c:
            new_cols = shape_c
        left = (img_padded.shape[1] - new_cols) // 2
        img_padded[:, left:left + new_cols] = img
    else:
        # Width is the limiting dimension: fit columns, pad rows.
        new_rows = (original_shape[0] * shape_c) // original_shape[1]
        img = cv2.resize(img, (shape_c, new_rows))
        if new_rows > shape_r:
            new_rows = shape_r
        top = (img_padded.shape[0] - new_rows) // 2
        img_padded[top:top + new_rows, :] = img
    return img_padded
def postprocess_img(pred, org_dir):
    """Resize a prediction map back to the original image's size.

    Inverts the letterboxing done by preprocess_img: the prediction is
    resized preserving aspect ratio and the padded borders are cropped off.

    Args:
        pred: 2-D prediction map (array-like; converted to np.ndarray).
        org_dir: Path to the original image, used only to read its size.

    Returns:
        np.ndarray with the original image's height and width.

    Raises:
        ValueError: If the original image cannot be read.
    """
    pred = np.array(pred)
    org = cv2.imread(org_dir, 0)
    if org is None:
        # cv2.imread returns None (no exception) on unreadable paths;
        # previously this surfaced as an AttributeError on org.shape.
        raise ValueError(f"could not read image: {org_dir}")
    shape_r = org.shape[0]
    shape_c = org.shape[1]
    pred_r, pred_c = pred.shape[0], pred.shape[1]

    rows_rate = shape_r / pred_r
    cols_rate = shape_c / pred_c
    if rows_rate > cols_rate:
        # Original is relatively taller: scale to full height, crop width.
        new_cols = (pred_c * shape_r) // pred_r
        pred = cv2.resize(pred, (new_cols, shape_r))  # dsize is (width, height)
        left = (pred.shape[1] - shape_c) // 2
        img = pred[:, left:left + shape_c]
    else:
        # Original is relatively wider: scale to full width, crop height.
        new_rows = (pred_r * shape_c) // pred_c
        pred = cv2.resize(pred, (shape_c, new_rows))
        top = (pred.shape[0] - shape_r) // 2
        img = pred[top:top + shape_r, :]
    return img
class MyDataset(Dataset):
    """Saliency dataset of (stimulus image, saliency map, fixation map) triples."""

    def __init__(self, ids, stimuli_dir, saliency_dir, fixation_dir, transform=None):
        """
        Args:
            ids: Table (e.g. a pandas DataFrame) indexed by sample; column 0
                holds stimulus filenames, column 1 saliency-map filenames,
                column 2 fixation-map filenames.
            stimuli_dir (string): Path prefix for stimulus images.
            saliency_dir (string): Path prefix for saliency maps.
            fixation_dir (string): Path prefix for fixation maps.
            transform (callable, optional): Optional transform for a sample.
                Currently stored but not applied.
        """
        self.ids = ids
        self.stimuli_dir = stimuli_dir
        self.saliency_dir = saliency_dir
        self.fixation_dir = fixation_dir
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Stimulus: HWC uint8 -> CHW float64 tensor scaled to [0, 1].
        # NOTE: directory prefixes are concatenated, not joined, so the
        # prefixes are expected to end with a path separator.
        im_path = self.stimuli_dir + self.ids.iloc[idx, 0]
        image = Image.open(im_path).convert('RGB')
        img = np.array(image) / 255.
        img = torch.from_numpy(np.transpose(img, (2, 0, 1)))

        smap = self._load_map(self.saliency_dir + self.ids.iloc[idx, 1])
        fmap = self._load_map(self.fixation_dir + self.ids.iloc[idx, 2])
        return {'image': img, 'saliency': smap, 'fixation': fmap}

    @staticmethod
    def _load_map(path):
        """Load a single-channel map as a 1xHxW float tensor scaled to [0, 1]."""
        arr = np.expand_dims(np.array(Image.open(path)) / 255., axis=0)
        return torch.from_numpy(arr)