File size: 1,387 Bytes
712b45c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from config import *
class InpaintDataset(Dataset):
    """Dataset yielding ``(image, mask)`` tensor pairs for inpainting.

    Called without arguments, it wraps the single image/mask pair named by
    the ``INIMAGE`` and ``MASKIMAGE`` config values and resizes both to
    ``RESIZE_TO``. Explicit arguments override those config values.

    Args:
        imglist: An image path, or a sequence of image paths.
            Defaults to ``[INIMAGE]``.
        masklist: A mask path, or a sequence of mask paths, paired with
            ``imglist`` by position. Defaults to ``[MASKIMAGE]``.
        setsize: Target size handed to ``cv2.resize`` as ``dsize``
            (OpenCV interprets it as ``(width, height)``).
            Defaults to ``RESIZE_TO``.

    Raises:
        ValueError: If the number of images and masks differ.
    """

    def __init__(self, imglist=None, masklist=None, setsize=None):
        super().__init__()
        # Config constants are only looked up when the matching argument is
        # omitted, so the no-argument call behaves exactly as before.
        self.imglist = [INIMAGE] if imglist is None else self._as_list(imglist)
        self.masklist = (
            [MASKIMAGE] if masklist is None else self._as_list(masklist)
        )
        self.setsize = RESIZE_TO if setsize is None else setsize

        if len(self.imglist) != len(self.masklist):
            raise ValueError(
                "imglist and masklist must have the same length, got "
                f"{len(self.imglist)} and {len(self.masklist)}"
            )

    @staticmethod
    def _as_list(paths):
        """Wrap a single path string in a list; copy any other iterable."""
        return [paths] if isinstance(paths, str) else list(paths)

    @staticmethod
    def _read(path, what):
        """Load ``path`` with ``cv2.imread``, raising if nothing was decoded."""
        data = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if data is None:
            raise OSError(f"Could not read {what} file: {path!r}")
        return data

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imglist)

    def __getitem__(self, index):
        """Load, resize and tensorize the pair at ``index``.

        Returns:
            A tuple ``(img, mask)`` of float32 tensors. ``img`` is the image
            converted from BGR to RGB, divided by 255 and permuted to
            channel-first layout. ``mask`` is the first channel of the mask
            file, divided by 255, with a leading channel axis added.

        Raises:
            OSError: If the image or mask file cannot be read.
        """
        img = self._read(self.imglist[index], "image")
        # Only the first channel of the mask file is used.
        mask = self._read(self.masklist[index], "mask")[:, :, 0]

        # Both are resized to the configured size (not to each other's size).
        # NOTE: cv2.resize uses its default interpolation here, so resized
        # mask edges may hold values strictly between 0 and 255.
        img = cv2.resize(img, self.setsize)
        mask = cv2.resize(mask, self.setsize)

        # NOTE: an earlier step that expanded every mask contour to its
        # bounding rectangle is disabled.

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = (
            torch.from_numpy(img.astype(np.float32) / 255.0)
            .permute(2, 0, 1)
            .contiguous()
        )
        mask = (
            torch.from_numpy(mask.astype(np.float32) / 255.0)
            .unsqueeze(0)
            .contiguous()
        )
        return img, mask
|