# test_dataset.py -- uploaded by NeuralFalcon (commit e214599, verified).
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from config import *
class InpaintDataset(Dataset):
    """Single-sample dataset that serves one image/mask pair for inpainting.

    Each item is a tuple ``(img, mask)``:

    * ``img``  -- float32 tensor of shape ``(3, H, W)`` in RGB channel order,
      with 8-bit pixel values scaled to ``[0, 1]``.
    * ``mask`` -- float32 tensor of shape ``(1, H, W)`` taken from the first
      channel of the mask file, with values scaled to ``[0, 1]``.

    ``(H, W)`` is determined by ``resize_to``. That value is passed unchanged
    to ``cv2.resize`` as its ``dsize`` argument, so it is ``(width, height)``.
    """

    def __init__(self, in_image, mask_image, resize_to=None):
        """Remember the image/mask paths and the target size.

        Args:
            in_image: Path of the image to inpaint.
            mask_image: Path of the mask image.
            resize_to: ``(width, height)`` that both image and mask are resized
                to. When ``None`` (the default), ``RESIZE_TO`` from the
                ``config`` module is used instead.
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        self.imglist = [in_image]
        self.masklist = [mask_image]
        self.setsize = resize_to

    def __len__(self):
        """Return the number of samples (a single pair for this dataset)."""
        return len(self.imglist)

    @staticmethod
    def _imread(path):
        """Read ``path`` with OpenCV, raising instead of returning ``None``."""
        # cv2.imread reports a missing/unreadable file by returning None, which
        # would otherwise surface later as an unrelated TypeError / cv2.error.
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path!r}")
        return image

    def __getitem__(self, index):
        """Load, resize and convert the sample at ``index`` into tensors.

        Returns:
            ``(img, mask)`` tensors as described in the class docstring.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        img = self._imread(self.imglist[index])
        # imread loads 3 channels by default (a grayscale file is expanded),
        # so keep only channel 0 as the single-channel mask.
        mask = self._imread(self.masklist[index])[:, :, 0]

        img = cv2.resize(img, self.setsize)
        # NOTE(review): cv2.resize defaults to bilinear interpolation, which can
        # create intermediate values along mask edges; confirm this is intended
        # rather than cv2.INTER_NEAREST.
        mask = cv2.resize(mask, self.setsize)

        # OpenCV decodes to BGR; convert to RGB before building the tensor.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # HWC uint8 -> CHW float32 in [0, 1].
        img = (
            torch.from_numpy(img.astype(np.float32) / 255.0)
            .permute(2, 0, 1)
            .contiguous()
        )
        # HW uint8 -> 1HW float32 in [0, 1].
        mask = (
            torch.from_numpy(mask.astype(np.float32) / 255.0)
            .unsqueeze(0)
            .contiguous()
        )
        return img, mask