|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from typing import Tuple |
|
|
|
|
|
class SoftErosion(torch.nn.Module):
    """Blur a mask with a radial kernel and split it into a hard core and a soft rim.

    The kernel is largest at its centre and falls off linearly with Euclidean
    distance, then is normalised to sum to one. ``forward`` applies
    ``iterations - 1`` rounds of ``min(x, blur(x))`` followed by one final
    blur. Values at or above ``threshold`` are set to 1; the remaining values
    are rescaled by their maximum.

    Args:
        kernel_size: Side length of the square kernel. Padding is
            ``kernel_size // 2``, so only odd sizes keep the spatial size
            unchanged. A size of 1 yields an all-zero kernel, which cannot be
            normalised.
        threshold: Blurred values ``>= threshold`` are treated as fully
            inside the mask.
        iterations: Total number of blur passes (the last one is a plain
            blur, the earlier ones are min-with-blur passes).
    """

    def __init__(
        self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1
    ):
        super().__init__()
        radius = kernel_size // 2
        self.padding = radius
        self.iterations = iterations
        self.threshold = threshold

        # Distance of every kernel cell from the centre, built by broadcasting
        # a row of offsets against a column of offsets (no meshgrid needed).
        offsets = torch.arange(0.0, kernel_size) - radius
        dist = torch.sqrt(offsets[None, :] ** 2 + offsets[:, None] ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        kernel = kernel.view(1, 1, *kernel.shape)
        # A buffer follows the module across .to()/.cuda() without being trained.
        self.register_buffer("weight", kernel)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(soft_mask, hard_mask)`` for ``x``.

        Args:
            x: Tensor whose dimension 1 is the channel dimension (it is used
                as the ``groups`` argument of the convolution). The input is
                not modified.

        Returns:
            A tuple of the softened tensor (1 where the blurred value reached
            the threshold, rescaled blurred values elsewhere) and the boolean
            tensor marking where the threshold was reached.
        """
        channels = x.shape[1]
        # A grouped conv needs one (1, k, k) filter per channel; reuse the
        # same kernel for each so multi-channel inputs are supported.
        weight = self.weight.repeat(channels, 1, 1, 1)

        for _ in range(self.iterations - 1):
            blurred = F.conv2d(x, weight=weight, groups=channels, padding=self.padding)
            x = torch.min(x, blurred)
        x = F.conv2d(x, weight=weight, groups=channels, padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0

        # Rescale the sub-threshold rim by its maximum. Skip when nothing is
        # below the threshold: max() of an empty selection raises.
        outside = ~mask
        if outside.any():
            x[outside] /= x[outside].max() + 1e-7

        return x, mask
|
|
|
|
|
def encode_segmentation_rgb(
    segmentation: np.ndarray, no_neck: bool = True
) -> np.ndarray:
    """Turn a 2-D label map into a two-channel 0/255 mask.

    Channel 0 is 255 wherever the label belongs to the face label set and
    channel 1 is 255 wherever the label equals 11 (the mouth label);
    everything else is 0.

    Args:
        segmentation: Label map indexed as ``[row, column]``.
        no_neck: Selects the shorter face label set when true; when false
            labels 7, 8 and 14 are included as well.

    Returns:
        Float array of shape ``(rows, columns, 2)``.
    """
    mouth_label = 11
    if no_neck:
        face_labels = [1, 2, 3, 4, 5, 6, 10, 12, 13]
    else:
        face_labels = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]

    rows, cols = segmentation.shape[0], segmentation.shape[1]
    face_channel = np.zeros([rows, cols])
    mouth_channel = np.zeros([rows, cols])

    # Boolean-mask assignment marks every matching pixel in one step.
    face_channel[np.isin(segmentation, face_labels)] = 255
    mouth_channel[segmentation == mouth_label] = 255

    return np.stack([face_channel, mouth_channel], axis=2)
|
|
|
|
|
def encode_segmentation_rgb_batch(
    segmentation: torch.Tensor, no_neck: bool = True
) -> torch.Tensor:
    """Turn a label tensor into face and mouth 0/255 masks joined on dim 1.

    The input is cast to int32. The face mask is 255 wherever the label
    belongs to the face label set, the mouth mask is 255 wherever the label
    equals 11, and both are 0 elsewhere.

    Args:
        segmentation: Label tensor; it needs at least two dimensions because
            the masks are concatenated along dimension 1.
        no_neck: Selects the shorter face label set when true; when false
            labels 7, 8 and 14 are included as well.

    Returns:
        int32 tensor: the face mask followed by the mouth mask along dim 1.
    """
    mouth_label = 11
    if no_neck:
        face_labels = [1, 2, 3, 4, 5, 6, 10, 12, 13]
    else:
        face_labels = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]

    labels = segmentation.int()
    blank = torch.zeros_like(labels)
    filled = blank + 255

    # Accumulate one boolean mask over all face labels, then paint it once.
    is_face = torch.zeros_like(labels, dtype=torch.bool)
    for label in face_labels:
        is_face = is_face | (labels == label)

    face_channel = torch.where(is_face, filled, blank)
    mouth_channel = torch.where(labels == mouth_label, filled, blank)

    return torch.cat([face_channel, mouth_channel], dim=1)
|
|
|
|
|
def postprocess(
    swapped_face: np.ndarray,
    target: np.ndarray,
    target_mask: np.ndarray,
    smooth_mask: torch.nn.Module,
) -> np.ndarray:
    """Blend ``swapped_face`` over ``target`` using a softened mask.

    Channels 0 and 1 of ``target_mask`` are scaled from 0..255 to 0..1 and
    summed, then passed through ``smooth_mask``. The resulting soft mask
    weights ``swapped_face`` against ``target`` per pixel, and the last axis
    of the blend is reversed before returning.

    Args:
        swapped_face: Image indexed as ``[row, column, channel]``.
        target: Image with the same layout as ``swapped_face``.
        target_mask: Mask indexed as ``[row, column, channel]`` with at least
            two channels holding values in 0..255.
        smooth_mask: Module called with a ``(1, 1, rows, columns)`` float
            tensor; it must return a pair whose first item is the softened
            mask with ``(rows, columns)`` as its last two dimensions.

    Returns:
        The blended image with its last axis reversed (a view, not a copy).
    """
    # Run on whatever device the smoothing module lives on rather than
    # assuming CUDA, so CPU-only hosts and CPU-resident modules work too.
    device = _module_device(smooth_mask)

    # (rows, cols, ch) in 0..255 -> (ch, rows, cols) float in 0..1.
    mask_tensor = (
        torch.from_numpy(target_mask.copy().transpose((2, 0, 1)))
        .float()
        .mul_(1 / 255.0)
        .to(device)
    )
    face_mask_tensor = mask_tensor[0] + mask_tensor[1]

    # Inference only: no autograd graph is needed, and numpy() requires a
    # tensor that does not track gradients.
    with torch.no_grad():
        soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor[None, None])

    # Keep only the two spatial dimensions. A bare squeeze() would also drop
    # a spatial dimension of size 1.
    spatial_shape = soft_face_mask_tensor.shape[-2:]
    soft_face_mask = soft_face_mask_tensor.reshape(spatial_shape).cpu().numpy()
    soft_face_mask = soft_face_mask[:, :, np.newaxis]

    result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
    # Reverse the channel order of the blended image.
    return result[:, :, ::-1]


def _module_device(module: torch.nn.Module) -> torch.device:
    """Return the device of the module's first parameter or buffer.

    Falls back to CUDA when available (CPU otherwise) for modules that hold
    no tensors.
    """
    first = next(module.parameters(), None)
    if first is None:
        first = next(module.buffers(), None)
    if first is not None:
        return first.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|