|
import os |
|
import cv2 |
|
import gdown |
|
import shutil |
|
import argparse |
|
import numpy as np |
|
import torch |
|
import torch.backends.cudnn as cudnn |
|
import torchvision.transforms as transforms |
|
from torchvision.utils import save_image |
|
|
|
from inplace_abn import InPlaceABN |
|
from dml_csr import dml_csr |
|
from dml_csr import transforms as dml_transforms |
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the mask-plotting script.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``image_path``, ``size``, ``checkpoint_path`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="Plot segmentation mask of an image.")

    # (flag, keyword arguments) pairs, registered in declaration order.
    options = (
        ("--image_path", dict(type=str, default=None,
                              help="Path to the image file.")),
        ("--size", dict(type=int, default=512)),
        ("--checkpoint_path", dict(type=str,
                                   default='ckpt/DML_CSR/dml_csr_celebA.pth',
                                   help="Path to the DML-CSR pretrained model.")),
        ("--output_dir", dict(type=str, default="output/masks/",
                              help="Folder to save segmentation mask.")),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
|
|
def download_checkpoint():
    """Download the DML-CSR checkpoint archive and unpack it under ``ckpt``.

    The archive is fetched with gdown by its file id, extracted into the
    ``ckpt`` directory, and the downloaded zip is deleted afterwards.
    """
    target_dir = 'ckpt'
    archive_path = 'ckpt/DML_CSR.zip'
    file_id = "1xttWuAj633-ujp_vcm5DtL98PP0b-sUm"

    os.makedirs(target_dir, exist_ok=True)
    gdown.download(id=file_id, output=archive_path)
    shutil.unpack_archive(archive_path, target_dir)
    # The zip is only a transport container; drop it once extracted.
    os.remove(archive_path)
|
|
|
def box2cs(box: list) -> tuple:
    """Convert a box to a ``(center, scale)`` pair.

    Args:
        box: Sequence whose first four entries are ``x, y, w, h``; any
            further entries are ignored.

    Returns:
        The ``(center, scale)`` tuple produced by ``xywh2cs``.
    """
    left, top, width, height = box[:4]
    return xywh2cs(left, top, width, height)
|
|
|
def xywh2cs(x: float, y: float, w: float, h: float) -> tuple:
    """Turn an ``x, y, w, h`` box into its center and a square scale.

    Args:
        x: Box origin along the first axis.
        y: Box origin along the second axis.
        w: Box width.
        h: Box height.

    Returns:
        ``(center, scale)`` as two float32 arrays of length 2. ``center`` is
        the box midpoint; ``scale`` holds the longer box side in both
        entries.
    """
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)

    # Grow the shorter side so the scale describes a square region.
    width, height = w, h
    if width > height:
        height = width
    elif height > width:
        width = height

    scale = np.array([width * 1.0, height * 1.0], dtype=np.float32)
    return center, scale
|
|
|
def labelcolormap(N):
    """Build an ``(N, 3)`` uint8 color table for segmentation labels.

    Args:
        N: Number of labels.

    Returns:
        numpy.ndarray of dtype uint8. For ``N == 19`` a fixed, hand-picked
        palette is returned; for any other ``N`` the colors are generated by
        spreading the bits of the label index over the three channels.
    """
    if N == 19:
        return np.array(
            [
                (0, 0, 0), (204, 0, 0), (76, 153, 0),
                (204, 204, 0), (204, 0, 204), (204, 0, 204), (255, 204, 204),
                (255, 204, 204), (102, 51, 0), (102, 51, 0), (102, 204, 0),
                (255, 255, 0), (0, 0, 153), (0, 0, 204), (255, 51, 153),
                (0, 204, 204), (0, 51, 0), (255, 153, 51), (0, 204, 0),
            ],
            dtype=np.uint8,
        )

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for index in range(N):
        red = green = blue = 0
        remaining = index
        for step in range(7):
            # Each step consumes the three lowest bits of the index and
            # writes them one position lower in the channels (bit 7 first).
            shift = 7 - step
            red ^= (remaining & 1) << shift
            green ^= ((remaining >> 1) & 1) << shift
            blue ^= ((remaining >> 2) & 1) << shift
            remaining >>= 3
        cmap[index] = (red, green, blue)
    return cmap
|
|
|
class Colorize(object):
    """Callable that paints a label map with a per-label color table."""

    def __init__(self, n=19):
        """Build the color table for ``n`` labels via ``labelcolormap``."""
        palette = labelcolormap(n)
        self.cmap = torch.from_numpy(palette[:n])

    def __call__(self, gray_image):
        """Colorize a label tensor.

        Args:
            gray_image: Tensor whose first entry along dim 0 holds the label
                indices; dims 1 and 2 give the output height and width.

        Returns:
            A uint8 tensor of shape ``(3, H, W)``. Pixels whose label has no
            entry in the color table stay zero.
        """
        dims = gray_image.size()
        color_image = torch.ByteTensor(3, dims[1], dims[2]).zero_()

        label_plane = gray_image[0]
        for label, rgb in enumerate(self.cmap):
            # Boolean mask on CPU so it can index the CPU output image.
            selected = (label_plane == label).cpu()
            for channel in range(3):
                color_image[channel][selected] = rgb[channel]

        return color_image
|
|
|
def tensor2label(label_tensor, n_label):
    """Convert a label tensor to a color image with values in ``[0, 1]``.

    Args:
        label_tensor: Either a single-entry label map or, when dim 0 has more
            than one entry, per-class scores that are reduced with an argmax
            over dim 0.
        n_label: Number of labels used to build the color table.

    Returns:
        numpy.ndarray holding the ``Colorize`` output divided by 255.
    """
    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # Keep only the index of the highest-scoring entry along dim 0.
        _, labels = labels.max(0, keepdim=True)

    colored = Colorize(n_label)(labels)
    return colored.numpy() / 255.0
|
|
|
def generate_label(inputs, imsize):
    """Turn a batch of 19-class score maps into colorized label images.

    Args:
        inputs: Iterable of tensors, each viewable as
            ``(1, 19, imsize, imsize)``.
        imsize: Side length of the (square) score maps.

    Returns:
        torch.Tensor stacking one ``tensor2label`` result per input.
    """
    # Per-sample argmax over the class dimension -> (imsize, imsize) indices.
    predictions = [
        np.squeeze(
            scores.view(1, 19, imsize, imsize).data.max(1)[1].cpu().numpy(),
            axis=0,
        )
        for scores in inputs
    ]
    pred_batch = torch.from_numpy(np.array(predictions))

    # Colorize each index map individually, then restack as one tensor.
    colored = [
        tensor2label(pred.view(1, imsize, imsize), 19)
        for pred in pred_batch
    ]
    return torch.from_numpy(np.array(colored))
|
|
|
def get_mask(model, image, input_size):
    """Run the model on one image and return its colorized mask.

    Args:
        model: Callable network; its output is upsampled to ``input_size``.
        image: Image tensor without a batch dimension (one is added here).
        input_size: Target ``(height, width)``; ``input_size[0]`` is also
            passed to ``generate_label`` as the side length.

    Returns:
        The first (only) entry of the batch returned by ``generate_label``.
    """
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
    batch = image.unsqueeze(0)

    with torch.no_grad():
        # The input is moved to the GPU; the model is expected to live there.
        logits = upsample(model(batch.cuda()))
        colored = generate_label(logits, input_size[0])

    return colored[0]
|
|
|
def save_mask(args):
    """Segment one image with DML-CSR and save the colorized mask as a PNG.

    Args:
        args: Namespace-like object providing ``image_path``, ``size``,
            ``checkpoint_path`` and ``output_dir``.

    Raises:
        ValueError: If ``args.image_path`` is not set.
        FileNotFoundError: If OpenCV cannot read the image at
            ``args.image_path``.
    """
    # --image_path defaults to None; fail early with a clear message instead
    # of handing None to cv2.imread.
    if not args.image_path:
        raise ValueError("No image given: pass --image_path.")

    os.makedirs(args.output_dir, exist_ok=True)

    cudnn.benchmark = True
    cudnn.enabled = True

    # 19 matches the class count hard-coded in generate_label.
    model = dml_csr.DML_CSR(19, InPlaceABN, False)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    input_size = (args.size, args.size)
    image = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {args.image_path}")

    # Warp the whole image (box covering every pixel) to the network size.
    h, w, _ = image.shape
    center, s = box2cs([0, 0, w - 1, h - 1])
    r = 0
    crop_size = np.asarray(input_size)
    trans = dml_transforms.get_affine_transform(center, s, r, crop_size)
    image = cv2.warpAffine(image, trans, (int(crop_size[1]), int(crop_size[0])),
                           flags=cv2.INTER_LINEAR,
                           borderMode=cv2.BORDER_CONSTANT,
                           borderValue=(0, 0, 0))
    image = transform(image)

    if not os.path.exists(args.checkpoint_path):
        download_checkpoint()
    state_dict = torch.load(args.checkpoint_path, map_location='cuda:0')
    model.load_state_dict(state_dict)

    model.cuda()
    model.eval()

    mask = get_mask(model, image, input_size)

    # splitext strips only the final extension, so "a.b.jpg" maps to
    # "a.b.png" instead of being truncated to "a.png".
    stem = os.path.splitext(os.path.basename(args.image_path))[0]
    filename = os.path.join(args.output_dir, stem + '.png')
    save_image(mask, filename)
    print(f'Mask saved in {filename}')
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI options and write the mask.
    save_mask(parse_args())