import numpy as np
import torch
import torch.nn.functional as F
from typing import Iterable, Tuple, Union
from pathlib import Path
from torchvision import transforms
import kornia
from omegaconf import DictConfig

from src.FaceDetector.face_detector import Detection
from src.FaceAlign.face_align import align_face, inverse_transform_batch
from src.PostProcess.utils import SoftErosion
from src.model_loader import get_model
from src.Misc.types import CheckpointType, FaceAlignmentType
from src.Misc.utils import tensor2img


class SimSwap:
    """Face-swapping pipeline around a SimSwap generator.

    For each attribute (target) frame the pipeline:
      1. detects and aligns every face (``face_detector`` + ``align_face``),
      2. optionally keeps only the face closest to ``specific_image``,
      3. generates swapped crops with ``simswap_net`` conditioned on the
         ArcFace latent of ``id_image`` (optionally enhanced by GFPGAN),
      4. builds a soft face mask from the parsing model (``bise_net``) and
         ``SoftErosion``,
      5. warps crops and masks back into the frame and blends them.

    The ID latents are computed lazily on the first call and cached.
    """

    def __init__(
        self,
        config: DictConfig,
        id_image: Union[np.ndarray, None] = None,
        specific_image: Union[np.ndarray, None] = None,
    ):
        """Load all networks and validate the parameters found in ``config``.

        Args:
            config: Reads ``device``, the ``*_weights`` paths, ``enhance_output``
                and the parameters consumed by ``set_parameters``.
            id_image: Image providing the identity to paste onto target faces.
            specific_image: Optional image selecting which target face to swap.

        Raises:
            ValueError: If any configured parameter fails validation.
        """
        self.id_image: Union[np.ndarray, None] = id_image
        self.id_latent: Union[torch.Tensor, None] = None
        self.specific_id_image: Union[np.ndarray, None] = specific_image
        self.specific_latent: Union[torch.Tensor, None] = None

        self.use_mask: Union[bool, None] = True
        self.crop_size: Union[int, None] = None
        self.checkpoint_type: Union[CheckpointType, None] = None
        self.face_alignment_type: Union[FaceAlignmentType, None] = None
        self.smooth_mask_iter: Union[int, None] = None
        self.smooth_mask_kernel_size: Union[int, None] = None
        self.smooth_mask_threshold: Union[float, None] = None
        self.face_detector_threshold: Union[float, None] = None
        self.specific_latent_match_threshold: Union[float, None] = None
        # Built by re_initialize_soft_mask once all smooth-mask settings are known.
        self.smooth_mask: Union[SoftErosion, None] = None

        self.device = torch.device(config.device)

        # Must run before the models are built: the detector threshold,
        # crop size and checkpoint type below come from these setters.
        self.set_parameters(config)

        # For BiSeNet and for official_224 SimSwap
        self.to_tensor_normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # For SimSwap models trained with the updated code
        self.to_tensor = transforms.ToTensor()

        # NOTE: the keyword is spelled `load_state_dice` to match get_model's signature.
        self.face_detector = get_model(
            "face_detector",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_detector_weights),
            det_thresh=self.face_detector_threshold,
            det_size=(640, 640),
            mode="ffhq",
        )

        self.face_id_net = get_model(
            "arcface",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_id_weights),
        )

        self.bise_net = get_model(
            "parsing_model",
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.parsing_model_weights),
            n_classes=19,
        )

        is_512 = self.crop_size == 512
        gen_model = "generator_512" if is_512 else "generator_224"
        self.simswap_net = get_model(
            gen_model,
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.simswap_weights),
            input_nc=3,
            output_nc=3,
            latent_size=512,
            n_blocks=9,
            deep=is_512,
            use_last_act=(self.checkpoint_type == CheckpointType.OFFICIAL_224),
        )

        self.blend = get_model(
            "blend_module",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.blend_module_weights),
        )

        self.enhance_output = config.enhance_output
        self.gfpgan_net = None
        if config.enhance_output:
            self.gfpgan_net = get_model(
                "gfpgan",
                device=self.device,
                load_state_dice=True,
                model_path=Path(config.gfpgan_weights),
            )

    def set_parameters(self, config) -> None:
        """Validate and apply every tunable parameter read from ``config``.

        Crop size, checkpoint type and detector threshold are consumed when the
        models are built in ``__init__``. Changing them afterwards updates only
        the attributes, not the already-built networks.
        """
        self.set_crop_size(config.crop_size)
        self.set_checkpoint_type(config.checkpoint_type)
        self.set_face_alignment_type(config.face_alignment_type)
        self.set_face_detector_threshold(config.face_detector_threshold)
        self.set_specific_latent_match_threshold(config.specific_latent_match_threshold)
        self.set_smooth_mask_kernel_size(config.smooth_mask_kernel_size)
        self.set_smooth_mask_threshold(config.smooth_mask_threshold)
        self.set_smooth_mask_iter(config.smooth_mask_iter)

    def set_crop_size(self, crop_size: int) -> None:
        """Set the aligned-crop size.

        Raises:
            ValueError: If ``crop_size`` is not strictly positive.
        """
        if crop_size <= 0:
            raise ValueError("Invalid crop_size! Must be a positive value.")
        self.crop_size = crop_size

    def set_checkpoint_type(self, checkpoint_type: str) -> None:
        """Parse and store the SimSwap checkpoint type.

        Raises:
            ValueError: If the parsed type is not OFFICIAL_224 or UNOFFICIAL.
        """
        checkpoint = CheckpointType(checkpoint_type)
        if checkpoint not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL):
            raise ValueError(
                "Invalid checkpoint_type! Must be one of the predefined values."
            )
        self.checkpoint_type = checkpoint

    def set_face_alignment_type(self, face_alignment_type: str) -> None:
        """Parse and store the face alignment mode used by ``align_face``.

        Raises:
            ValueError: If the parsed type is not FFHQ or DEFAULT.
        """
        alignment = FaceAlignmentType(face_alignment_type)
        if alignment not in (FaceAlignmentType.FFHQ, FaceAlignmentType.DEFAULT):
            raise ValueError(
                "Invalid face_alignment_type! Must be one of the predefined values."
            )
        self.face_alignment_type = alignment

    def set_face_detector_threshold(self, face_detector_threshold: float) -> None:
        """Store the face detection threshold.

        The value is passed to the detector only in ``__init__``. Calling this
        later does not, as far as this class shows, reconfigure ``face_detector``.

        Raises:
            ValueError: If the value is outside [0.0, 1.0].
        """
        if not 0.0 <= face_detector_threshold <= 1.0:
            raise ValueError(
                "Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]."
            )
        self.face_detector_threshold = face_detector_threshold

    def set_specific_latent_match_threshold(
        self, specific_latent_match_threshold: float
    ) -> None:
        """Store the maximum latent MSE accepted as a match for ``specific_image``.

        Raises:
            ValueError: If the value is negative.
        """
        if specific_latent_match_threshold < 0.0:
            raise ValueError("Invalid specific_latent_match_th! Must be a positive value.")
        self.specific_latent_match_threshold = specific_latent_match_threshold

    def re_initialize_soft_mask(self):
        """(Re)build the SoftErosion module from the current smooth-mask settings.

        Does nothing until kernel size, threshold and iteration count are all
        set. This avoids constructing SoftErosion with ``None`` arguments while
        ``set_parameters`` is still applying them one by one.
        """
        settings = (
            self.smooth_mask_kernel_size,
            self.smooth_mask_threshold,
            self.smooth_mask_iter,
        )
        if any(value is None for value in settings):
            return
        self.smooth_mask = SoftErosion(
            kernel_size=self.smooth_mask_kernel_size,
            threshold=self.smooth_mask_threshold,
            iterations=self.smooth_mask_iter,
        ).to(self.device)

    def set_smooth_mask_kernel_size(self, smooth_mask_kernel_size: int) -> None:
        """Store the erosion kernel size, rounding even values up to the next odd one.

        Raises:
            ValueError: If the value is negative.
        """
        if smooth_mask_kernel_size < 0:
            raise ValueError("Invalid smooth_mask_kernel_size! Must be a positive value.")
        # Presumably an odd size keeps the erosion kernel centred — confirm in SoftErosion.
        if smooth_mask_kernel_size % 2 == 0:
            smooth_mask_kernel_size += 1
        self.smooth_mask_kernel_size = smooth_mask_kernel_size
        self.re_initialize_soft_mask()

    def set_smooth_mask_threshold(self, smooth_mask_threshold: float) -> None:
        """Store the soft-mask threshold.

        Raises:
            ValueError: If the value is outside [0, 1].
        """
        if not 0 <= smooth_mask_threshold <= 1.0:
            raise ValueError("Invalid smooth_mask_threshold! Must be within 0...1 range.")
        self.smooth_mask_threshold = smooth_mask_threshold
        self.re_initialize_soft_mask()

    def set_smooth_mask_iter(self, smooth_mask_iter: int) -> None:
        """Store the number of erosion iterations.

        Raises:
            ValueError: If the value is negative.
        """
        if smooth_mask_iter < 0:
            raise ValueError("Invalid smooth_mask_iter! Must be a positive value..")
        self.smooth_mask_iter = smooth_mask_iter
        self.re_initialize_soft_mask()

    def run_detect_align(
        self, image: np.ndarray, for_id: bool = False
    ) -> Tuple[
        Union[Iterable[np.ndarray], None], Union[Iterable[np.ndarray], None], np.ndarray
    ]:
        """Detect faces in ``image`` and return aligned crops plus their transforms.

        Args:
            image: Frame to search for faces.
            for_id: If True, only the highest-scoring face is aligned, and a
                missing face is an error.

        Returns:
            ``(aligned_crops, align_transforms, detection_scores)``. When no face
            is found and ``for_id`` is False, the first two items are ``None``.

        Raises:
            ValueError: If ``for_id`` is True and no face is detected.
        """
        detection: Detection = self.face_detector(image)

        if detection.bbox is None:
            if for_id:
                raise ValueError("Can't detect a face! Please change the ID image!")
            return None, None, detection.score

        kps = detection.key_points
        if for_id:
            # Keep only the most confident face; the [None, ...] restores the batch axis.
            best_index = np.argmax(detection.score, axis=0)
            kps = detection.key_points[best_index][None, ...]

        align_imgs, align_transforms = align_face(
            image,
            kps,
            crop_size=self.crop_size,
            mode="ffhq" if self.face_alignment_type == FaceAlignmentType.FFHQ else "none",
        )

        return align_imgs, align_transforms, detection.score

    def _ensure_id_latents(self) -> None:
        """Compute and cache the ID latent (and the specific-face latent) on first use.

        Raises:
            ValueError: If no ID image was provided, or no face is found in it.
        """
        if self.id_latent is None:
            if self.id_image is None:
                raise ValueError("An ID image must be provided before swapping faces.")
            align_id_imgs, _, _ = self.run_detect_align(self.id_image, for_id=True)
            # normalize=True, because official SimSwap model trained with normalized id_lattent
            self.id_latent = self.face_id_net(align_id_imgs, normalize=True)

        if self.specific_id_image is not None and self.specific_latent is None:
            align_specific_imgs, _, _ = self.run_detect_align(
                self.specific_id_image, for_id=True
            )
            self.specific_latent = self.face_id_net(align_specific_imgs, normalize=False)

    def _select_specific_face(self, align_att_imgs, att_transforms, att_detection_score):
        """Pick the target face that best matches ``specific_latent``.

        Returns:
            ``([crop], [transform])`` for the chosen face, or ``None`` if its
            latent MSE is not below ``specific_latent_match_threshold``.
        """
        att_latent: torch.Tensor = self.face_id_net(align_att_imgs, normalize=False)
        latent_dist = torch.mean(
            F.mse_loss(
                att_latent,
                self.specific_latent.repeat(att_latent.shape[0], 1),
                reduction="none",
            ),
            dim=-1,
        )

        scores = torch.tensor(att_detection_score, device=latent_dist.device)
        # NOTE(review): the distance is multiplied by the detection score, which
        # favours low-confidence detections; kept as-is — confirm this is intended.
        min_index = int(torch.argmin(latent_dist * scores))

        # The threshold test uses the unweighted distance.
        if latent_dist[min_index] < self.specific_latent_match_threshold:
            return [align_att_imgs[min_index]], [att_transforms[min_index]]
        return None

    @torch.no_grad()
    def __call__(self, att_image: np.ndarray) -> np.ndarray:
        """Swap the ID face onto the faces in ``att_image`` and return the result.

        The frame is returned unchanged if no face is detected, or if a specific
        image was given and no detected face matches it.

        Args:
            att_image: Target frame, presumably an HxWxC array (the layout
                ``transforms.ToTensor`` expects).

        Returns:
            The blended frame as produced by ``tensor2img``.
        """
        self._ensure_id_latents()

        # for_id=False, because we want to get all faces
        align_att_imgs, att_transforms, att_detection_score = self.run_detect_align(
            att_image, for_id=False
        )

        if align_att_imgs is None and att_transforms is None:
            return att_image

        # Select specific crop from the target image
        if self.specific_latent is not None:
            selected = self._select_specific_face(
                align_att_imgs, att_transforms, att_detection_score
            )
            if selected is None:
                return att_image
            align_att_imgs, att_transforms = selected

        swapped_img: torch.Tensor = self.simswap_net(align_att_imgs, self.id_latent)

        if self.enhance_output:
            swapped_img = self.gfpgan_net.enhance(swapped_img, weight=0.5)

        # Put all crops/transformations into a batch
        align_att_img_batch_for_parsing_model: torch.Tensor = torch.stack(
            [self.to_tensor_normalize(x) for x in align_att_imgs], dim=0
        ).to(self.device)

        att_transforms: torch.Tensor = torch.stack(
            [torch.tensor(x).float() for x in att_transforms], dim=0
        ).to(self.device, non_blocking=True)

        align_att_img_batch: torch.Tensor = torch.stack(
            [self.to_tensor(x) for x in align_att_imgs], dim=0
        ).to(self.device, non_blocking=True)

        # Get face masks for the attribute image
        face_mask, ignore_mask_ids = self.bise_net.get_mask(
            align_att_img_batch_for_parsing_model, self.crop_size
        )

        inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)

        soft_face_mask, _ = self.smooth_mask(face_mask)

        # Crops flagged by the parsing model (presumably no usable face mask —
        # confirm in bise_net.get_mask) get their original pixels back.
        swapped_img[ignore_mask_ids, ...] = align_att_img_batch[ignore_mask_ids, ...]

        frame_size = (att_image.shape[0], att_image.shape[1])

        att_image = self.to_tensor(att_image).to(self.device, non_blocking=True).unsqueeze(0)

        # Warp the swapped crops and their masks back into frame coordinates.
        target_image = kornia.geometry.transform.warp_affine(
            swapped_img,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        soft_face_mask = kornia.geometry.transform.warp_affine(
            soft_face_mask,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        result = self.blend(target_image, soft_face_mask, att_image)

        return tensor2img(result)