import argparse
import os
import shutil

import cv2
import gdown
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torchvision.utils import save_image
from inplace_abn import InPlaceABN

from dml_csr import dml_csr
from dml_csr import transforms as dml_transforms


def parse_args():
    parser = argparse.ArgumentParser(description="Plot segmentation mask of an image.")
    parser.add_argument(
        "--image_path", type=str, default=None,
        help="Path to the image file."
    )
    parser.add_argument("--size", type=int, default=512)
    parser.add_argument(
        "--checkpoint_path", type=str, default='ckpt/DML_CSR/dml_csr_celebA.pth',
        help="Path to the DML-CSR pretrained model."
    )
    parser.add_argument(
        "--output_dir", type=str, default="output/masks/",
        help="Folder to save segmentation mask."
    )
    return parser.parse_args()


def download_checkpoint():
    """Download and unpack the pretrained DML-CSR checkpoints from Google Drive."""
    os.makedirs('ckpt', exist_ok=True)
    file_id = "1xttWuAj633-ujp_vcm5DtL98PP0b-sUm"
    gdown.download(id=file_id, output='ckpt/DML_CSR.zip')
    shutil.unpack_archive('ckpt/DML_CSR.zip', 'ckpt')
    os.remove('ckpt/DML_CSR.zip')


def box2cs(box: list) -> tuple:
    """Convert an (x, y, w, h) box to a center point and scale."""
    x, y, w, h = box[:4]
    return xywh2cs(x, y, w, h)


def xywh2cs(x: float, y: float, w: float, h: float) -> tuple:
    center = np.zeros((2,), dtype=np.float32)
    center[0] = x + w * 0.5
    center[1] = y + h * 0.5
    # Pad the shorter side so the crop is square.
    if w > h:
        h = w
    elif w < h:
        w = h
    scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
    return center, scale


def labelcolormap(N):
    if N == 19:  # CelebAMask-HQ palette; paired left/right parts share a color.
        cmap = np.array([(0, 0, 0), (204, 0, 0), (76, 153, 0),
                         (204, 204, 0), (204, 0, 204), (204, 0, 204),
                         (255, 204, 204), (255, 204, 204), (102, 51, 0),
                         (102, 51, 0), (102, 204, 0), (255, 255, 0),
                         (0, 0, 153), (0, 0, 204), (255, 51, 153),
                         (0, 204, 204), (0, 51, 0), (255, 153, 51),
                         (0, 204, 0)],
                        dtype=np.uint8)
    else:
        def uint82bin(n, count=8):
            """Returns the binary of integer n; count refers to the number of bits."""
            return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])

        # Fallback: derive a palette by spreading each label's bits across RGB.
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            label_id = i
            for j in range(7):
                str_id = uint82bin(label_id)
                r = r ^ (np.uint8(str_id[-1]) << (7 - j))
                g = g ^ (np.uint8(str_id[-2]) << (7 - j))
                b = b ^ (np.uint8(str_id[-3]) << (7 - j))
                label_id = label_id >> 3
            cmap[i, 0] = r
            cmap[i, 1] = g
            cmap[i, 2] = b
    return cmap


class Colorize(object):
    def __init__(self, n=19):
        self.cmap = labelcolormap(n)
        self.cmap = torch.from_numpy(self.cmap[:n])

    def __call__(self, gray_image):
        size = gray_image.size()
        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
        # Paint every pixel of each label with its palette color.
        for label in range(0, len(self.cmap)):
            mask = (label == gray_image[0]).cpu()
            color_image[0][mask] = self.cmap[label][0]
            color_image[1][mask] = self.cmap[label][1]
            color_image[2][mask] = self.cmap[label][2]
        return color_image


def tensor2label(label_tensor, n_label):
    label_tensor = label_tensor.cpu().float()
    if label_tensor.size()[0] > 1:
        label_tensor = label_tensor.max(0, keepdim=True)[1]
    label_tensor = Colorize(n_label)(label_tensor)
    label_numpy = label_tensor.numpy()
    label_numpy = label_numpy / 255.0
    return label_numpy


def generate_label(inputs, imsize):
    # Argmax over the 19 class channels to get a per-pixel label map.
    pred_batch = []
    for input in inputs:
        input = input.view(1, 19, imsize, imsize)
        pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
        pred_batch.append(pred)

    pred_batch = np.array(pred_batch)
    pred_batch = torch.from_numpy(pred_batch)

    # Colorize each label map into an RGB image in [0, 1].
    label_batch = []
    for p in pred_batch:
        p = p.view(1, imsize, imsize)
        label_batch.append(tensor2label(p, 19))

    label_batch = np.array(label_batch)
    label_batch = torch.from_numpy(label_batch)
    return label_batch


def get_mask(model, image, input_size):
    interp = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
    image = image.unsqueeze(0)
    with torch.no_grad():
        outputs = model(image.cuda())
    labels = generate_label(interp(outputs), input_size[0])
    return labels[0]


def save_mask(args):
    os.makedirs(args.output_dir, exist_ok=True)

    cudnn.benchmark = True
    cudnn.enabled = True

    model = dml_csr.DML_CSR(19, InPlaceABN, False)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    input_size = (args.size, args.size)

    # Crop the full image to a square and warp it to the network input size.
    image = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    h, w, _ = image.shape
    center, s = box2cs([0, 0, w - 1, h - 1])
    r = 0
    crop_size = np.asarray(input_size)
    trans = dml_transforms.get_affine_transform(center, s, r, crop_size)
    image = cv2.warpAffine(image, trans,
                           (int(crop_size[1]), int(crop_size[0])),
                           flags=cv2.INTER_LINEAR,
                           borderMode=cv2.BORDER_CONSTANT,
                           borderValue=(0, 0, 0))
    image = transform(image)

    # Fetch the pretrained checkpoint on first run.
    if not os.path.exists(args.checkpoint_path):
        download_checkpoint()
    state_dict = torch.load(args.checkpoint_path, map_location='cuda:0')
    model.load_state_dict(state_dict)
    model.cuda()
    model.eval()

    mask = get_mask(model, image, input_size)
    filename = os.path.join(
        args.output_dir,
        os.path.basename(args.image_path).split('.')[0] + '.png')
    save_image(mask, filename)
    print(f'Mask saved in {filename}')


if __name__ == '__main__':
    args = parse_args()
    save_mask(args)
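
# Example invocation (a sketch, assuming a CUDA-capable GPU, the dml_csr
# package importable on PYTHONPATH, and this file saved as plot_mask.py --
# the script filename is hypothetical, not given in the source):
#
#   python plot_mask.py --image_path examples/face.jpg --size 512 \
#       --output_dir output/masks/
#
# The pretrained checkpoint is downloaded automatically if it is missing, and
# the colorized 19-class mask is written to output/masks/face.png.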