|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from typing import Iterable, Tuple, Union |
|
from pathlib import Path |
|
from torchvision import transforms |
|
import kornia |
|
from omegaconf import DictConfig |
|
|
|
from src.FaceDetector.face_detector import Detection |
|
from src.FaceAlign.face_align import align_face, inverse_transform_batch |
|
from src.PostProcess.utils import SoftErosion |
|
from src.model_loader import get_model |
|
from src.Misc.types import CheckpointType, FaceAlignmentType |
|
from src.Misc.utils import tensor2img |
|
|
|
|
|
class SimSwap:
    """Face-swap pipeline: detect faces, encode identity with ArcFace,
    re-generate the face with the SimSwap generator, and blend the result
    back into the original frame (optionally GFPGAN-enhanced).
    """

    def __init__(
        self,
        config: DictConfig,
        id_image: Union[np.ndarray, None] = None,
        specific_image: Union[np.ndarray, None] = None,
    ):
        """Load all sub-models and apply runtime parameters.

        Args:
            config: omegaconf config holding model weight paths, device and
                tunable runtime parameters (see ``set_parameters``).
            id_image: source image whose identity is transferred onto targets.
            specific_image: optional image of the single person whose face
                should be swapped in target frames.
        """
        # Source-identity image and its lazily computed ArcFace latent.
        self.id_image: Union[np.ndarray, None] = id_image
        self.id_latent: Union[torch.Tensor, None] = None
        # Optional "only swap this person" image and its latent.
        self.specific_id_image: Union[np.ndarray, None] = specific_image
        self.specific_latent: Union[torch.Tensor, None] = None

        self.use_mask: Union[bool, None] = True
        # All of the following are filled in by set_parameters(config) below.
        self.crop_size: Union[int, None] = None
        self.checkpoint_type: Union[CheckpointType, None] = None
        self.face_alignment_type: Union[FaceAlignmentType, None] = None
        self.smooth_mask_iter: Union[int, None] = None
        self.smooth_mask_kernel_size: Union[int, None] = None
        self.smooth_mask_threshold: Union[float, None] = None
        self.face_detector_threshold: Union[float, None] = None
        self.specific_latent_match_threshold: Union[float, None] = None
        self.device = torch.device(config.device)

        self.set_parameters(config)

        # ImageNet-normalized tensor transform (input to the parsing model).
        self.to_tensor_normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # Plain [0, 1] tensor conversion (no normalization).
        self.to_tensor = transforms.ToTensor()

        # NOTE(review): the keyword "load_state_dice" (here and below) looks
        # like a typo of "load_state_dict", but it is the keyword get_model
        # receives; kept as-is — confirm against get_model's signature.
        self.face_detector = get_model(
            "face_detector",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_detector_weights),
            det_thresh=self.face_detector_threshold,
            det_size=(640, 640),
            mode="ffhq",
        )

        # ArcFace identity encoder.
        self.face_id_net = get_model(
            "arcface",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_id_weights),
        )

        # BiSeNet-style face parsing model (19 classes) producing face masks.
        self.bise_net = get_model(
            "parsing_model",
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.parsing_model_weights),
            n_classes=19,
        )

        # Generator variant depends on the crop size; 512 uses the deep model.
        gen_model = "generator_512" if self.crop_size == 512 else "generator_224"
        self.simswap_net = get_model(
            gen_model,
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.simswap_weights),
            input_nc=3,
            output_nc=3,
            latent_size=512,
            n_blocks=9,
            deep=True if self.crop_size == 512 else False,
            use_last_act=True
            if self.checkpoint_type == CheckpointType.OFFICIAL_224
            else False,
        )

        # Blending module that composites the swapped face into the frame.
        self.blend = get_model(
            "blend_module",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.blend_module_weights)
        )

        # Optional GFPGAN enhancement of the generator output.
        self.enhance_output = config.enhance_output
        if config.enhance_output:
            self.gfpgan_net = get_model(
                "gfpgan",
                device=self.device,
                load_state_dice=True,
                model_path=Path(config.gfpgan_weights)
            )
|
|
|
def set_parameters(self, config) -> None: |
|
self.set_crop_size(config.crop_size) |
|
self.set_checkpoint_type(config.checkpoint_type) |
|
self.set_face_alignment_type(config.face_alignment_type) |
|
self.set_face_detector_threshold(config.face_detector_threshold) |
|
self.set_specific_latent_match_threshold(config.specific_latent_match_threshold) |
|
self.set_smooth_mask_kernel_size(config.smooth_mask_kernel_size) |
|
self.set_smooth_mask_threshold(config.smooth_mask_threshold) |
|
self.set_smooth_mask_iter(config.smooth_mask_iter) |
|
|
|
def set_crop_size(self, crop_size: int) -> None: |
|
if crop_size < 0: |
|
raise "Invalid crop_size! Must be a positive value." |
|
|
|
self.crop_size = crop_size |
|
|
|
def set_checkpoint_type(self, checkpoint_type: str) -> None: |
|
type = CheckpointType(checkpoint_type) |
|
if type not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL): |
|
raise "Invalid checkpoint_type! Must be one of the predefined values." |
|
|
|
self.checkpoint_type = type |
|
|
|
def set_face_alignment_type(self, face_alignment_type: str) -> None: |
|
type = FaceAlignmentType(face_alignment_type) |
|
if type not in ( |
|
FaceAlignmentType.FFHQ, |
|
FaceAlignmentType.DEFAULT, |
|
): |
|
raise "Invalid face_alignment_type! Must be one of the predefined values." |
|
|
|
self.face_alignment_type = type |
|
|
|
def set_face_detector_threshold(self, face_detector_threshold: float) -> None: |
|
if face_detector_threshold < 0.0 or face_detector_threshold > 1.0: |
|
raise "Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]." |
|
|
|
self.face_detector_threshold = face_detector_threshold |
|
|
|
def set_specific_latent_match_threshold( |
|
self, specific_latent_match_threshold: float |
|
) -> None: |
|
if specific_latent_match_threshold < 0.0: |
|
raise "Invalid specific_latent_match_th! Must be a positive value." |
|
|
|
self.specific_latent_match_threshold = specific_latent_match_threshold |
|
|
|
def re_initialize_soft_mask(self): |
|
self.smooth_mask = SoftErosion(kernel_size=self.smooth_mask_kernel_size, |
|
threshold=self.smooth_mask_threshold, |
|
iterations=self.smooth_mask_iter).to(self.device) |
|
|
|
def set_smooth_mask_kernel_size(self, smooth_mask_kernel_size: int) -> None: |
|
if smooth_mask_kernel_size < 0: |
|
raise "Invalid smooth_mask_kernel_size! Must be a positive value." |
|
smooth_mask_kernel_size += 1 if smooth_mask_kernel_size % 2 == 0 else 0 |
|
self.smooth_mask_kernel_size = smooth_mask_kernel_size |
|
self.re_initialize_soft_mask() |
|
|
|
def set_smooth_mask_threshold(self, smooth_mask_threshold: int) -> None: |
|
if smooth_mask_threshold < 0 or smooth_mask_threshold > 1.0: |
|
raise "Invalid smooth_mask_threshold! Must be within 0...1 range." |
|
self.smooth_mask_threshold = smooth_mask_threshold |
|
self.re_initialize_soft_mask() |
|
|
|
def set_smooth_mask_iter(self, smooth_mask_iter: float) -> None: |
|
if smooth_mask_iter < 0: |
|
raise "Invalid smooth_mask_iter! Must be a positive value.." |
|
self.smooth_mask_iter = smooth_mask_iter |
|
self.re_initialize_soft_mask() |
|
|
|
def run_detect_align(self, image: np.ndarray, for_id: bool = False) -> Tuple[Union[Iterable[np.ndarray], None], |
|
Union[Iterable[np.ndarray], None], |
|
np.ndarray]: |
|
detection: Detection = self.face_detector(image) |
|
|
|
if detection.bbox is None: |
|
if for_id: |
|
raise "Can't detect a face! Please change the ID image!" |
|
return None, None, detection.score |
|
|
|
kps = detection.key_points |
|
|
|
if for_id: |
|
max_score_ind = np.argmax(detection.score, axis=0) |
|
kps = detection.key_points[max_score_ind] |
|
kps = kps[None, ...] |
|
|
|
align_imgs, transforms = align_face( |
|
image, |
|
kps, |
|
crop_size=self.crop_size, |
|
mode="ffhq" |
|
if self.face_alignment_type == FaceAlignmentType.FFHQ |
|
else "none", |
|
) |
|
|
|
return align_imgs, transforms, detection.score |
|
|
|
    def __call__(self, att_image: np.ndarray) -> np.ndarray:
        """Swap the cached ID face onto face(s) detected in ``att_image``.

        Args:
            att_image: attribute (target) frame as an HxWxC numpy array.

        Returns:
            The frame with swapped face(s) blended in, as a numpy array; the
            input frame unchanged when no (matching) face is found.
        """
        # Compute the ID latent once and cache it for subsequent frames.
        if self.id_latent is None:
            align_id_imgs, id_transforms, _ = self.run_detect_align(
                self.id_image, for_id=True
            )
            # normalize=True here vs normalize=False for the matching latents
            # below — presumably intentional (generator conditioning vs
            # similarity comparison); confirm against face_id_net.
            self.id_latent: torch.Tensor = self.face_id_net(
                align_id_imgs, normalize=True
            )

        # Cache the latent of the "specific" person, if one was supplied.
        if self.specific_id_image is not None and self.specific_latent is None:
            align_specific_imgs, specific_transforms, _ = self.run_detect_align(
                self.specific_id_image, for_id=True
            )
            self.specific_latent: torch.Tensor = self.face_id_net(
                align_specific_imgs, normalize=False
            )

        align_att_imgs, att_transforms, att_detection_score = self.run_detect_align(
            att_image, for_id=False
        )

        # No face detected in the attribute frame: return it untouched.
        if align_att_imgs is None and att_transforms is None:
            return att_image

        # When a specific person is requested, keep only the detected face
        # closest in identity to that person — or bail out if none is close
        # enough.
        if self.specific_latent is not None:
            att_latent: torch.Tensor = self.face_id_net(align_att_imgs, normalize=False)
            # Per-face mean squared distance to the specific person's latent.
            latent_dist = torch.mean(
                F.mse_loss(
                    att_latent,
                    self.specific_latent.repeat(att_latent.shape[0], 1),
                    reduction="none",
                ),
                dim=-1,
            )

            att_detection_score = torch.tensor(
                att_detection_score, device=latent_dist.device
            )

            # Candidate selection weights the identity distance by the
            # detection score.
            min_index = torch.argmin(latent_dist * att_detection_score)
            min_value = latent_dist[min_index]

            if min_value < self.specific_latent_match_threshold:
                align_att_imgs = [align_att_imgs[min_index]]
                att_transforms = [att_transforms[min_index]]
            else:
                return att_image

        # Generate the identity-swapped crops.
        swapped_img: torch.Tensor = self.simswap_net(align_att_imgs, self.id_latent)

        if self.enhance_output:
            swapped_img = self.gfpgan_net.enhance(swapped_img, weight=0.5)

        # Batch of aligned crops, ImageNet-normalized, for the parsing model.
        align_att_img_batch_for_parsing_model: torch.Tensor = torch.stack(
            [self.to_tensor_normalize(x) for x in align_att_imgs], dim=0
        )
        align_att_img_batch_for_parsing_model = (
            align_att_img_batch_for_parsing_model.to(self.device)
        )

        # Per-face affine transforms as a float tensor batch.
        att_transforms: torch.Tensor = torch.stack(
            [torch.tensor(x).float() for x in att_transforms], dim=0
        )
        att_transforms = att_transforms.to(self.device, non_blocking=True)

        # Un-normalized batch of the same crops, used as fallback pixels when
        # a parsing mask is unusable.
        align_att_img_batch: torch.Tensor = torch.stack(
            [self.to_tensor(x) for x in align_att_imgs], dim=0
        )
        align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)

        # Face-region masks from the parsing model; ignore_mask_ids flags
        # crops whose mask should not be applied.
        face_mask, ignore_mask_ids = self.bise_net.get_mask(
            align_att_img_batch_for_parsing_model, self.crop_size
        )

        # Inverse transforms map crop coordinates back into frame coordinates.
        inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)

        # Soften mask edges for seamless blending.
        soft_face_mask, _ = self.smooth_mask(face_mask)

        # For crops with unusable masks, keep the original (un-swapped) pixels.
        swapped_img[ignore_mask_ids, ...] = align_att_img_batch[ignore_mask_ids, ...]

        frame_size = (att_image.shape[0], att_image.shape[1])

        att_image = self.to_tensor(att_image).to(self.device, non_blocking=True).unsqueeze(0)

        # Warp the swapped crops back into full-frame coordinates.
        target_image = kornia.geometry.transform.warp_affine(
            swapped_img,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        # Warp the soft mask the same way (zero padding outside the crop).
        soft_face_mask = kornia.geometry.transform.warp_affine(
            soft_face_mask,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        # Composite the warped face into the frame under the soft mask.
        result = self.blend(target_image, soft_face_mask, att_image)

        return tensor2img(result)
|
|