|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
from config import *
|
|
|
|
|
|
class InpaintDataset(Dataset):
    """Single-sample dataset wrapping one image and its inpainting mask.

    Both files are read lazily in ``__getitem__`` and resized to a common
    size before being converted to float tensors.
    """

    def __init__(self, in_image, mask_image, resize_to=None):
        """Store the paths of the image/mask pair and the target size.

        Args:
            in_image: Path of the image file, as accepted by ``cv2.imread``.
            mask_image: Path of the mask file, as accepted by ``cv2.imread``.
            resize_to: Target size handed to ``cv2.resize`` as ``dsize``
                (OpenCV interprets it as ``(width, height)``). ``None``
                falls back to ``RESIZE_TO`` from ``config``.
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        self.imglist = [in_image]
        self.masklist = [mask_image]
        self.setsize = resize_to

    def __len__(self):
        """Return the number of image/mask pairs held (always one)."""
        return len(self.imglist)

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, raising instead of returning ``None``."""
        # cv2.imread signals a missing or undecodable file by returning None
        # rather than raising; fail here so the error names the file.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not read image file: {path!r}")
        return image

    def __getitem__(self, index):
        """Load, resize and tensorize the image/mask pair at ``index``.

        Returns:
            A tuple ``(img, mask)`` of float32 tensors scaled by 1/255:
            ``img`` is RGB in channel-first ``(3, H, W)`` layout and
            ``mask`` is ``(1, H, W)``.

        Raises:
            OSError: If either file cannot be read by OpenCV.
        """
        img = self._read_image(self.imglist[index])
        # Only the first channel of the mask file is used.
        mask = self._read_image(self.masklist[index])[:, :, 0]

        img = cv2.resize(img, self.setsize)
        # NOTE(review): default (bilinear) interpolation can yield
        # intermediate values along mask edges; kept as-is to preserve
        # the existing output.
        mask = cv2.resize(mask, self.setsize)

        # OpenCV decodes to BGR; convert to RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # HWC uint8 -> CHW float32 in [0, 1].
        img = (
            torch.from_numpy(img.astype(np.float32) / 255.0)
            .permute(2, 0, 1)
            .contiguous()
        )
        # HW uint8 -> 1HW float32 in [0, 1].
        mask = (
            torch.from_numpy(mask.astype(np.float32) / 255.0)
            .unsqueeze(0)
            .contiguous()
        )
        return img, mask
|
|
|