diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..d1e7c358b3f7bbc2fddc2b4918a811eef70b6ac8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,56 @@ +name: Bug Report +description: You think something is broken +labels: ["bug", "new"] + +body: + - type: checkboxes + attributes: + label: First, confirm + description: Before creating a new issue, make sure you are using the latest version of the ReActor extension and have searched the existing issues for the bug you encountered. + options: + - label: I have read the [instructions](https://github.com/Gourieff/comfyui-reactor-node/blob/main/README.md) carefully + required: true + - label: I have searched the existing issues + required: true + - label: I have updated the extension to the latest version + required: true + - type: markdown + attributes: + value: | + *Please fill this form with as much information as possible and **provide screenshots** if possible* + - type: textarea + id: what-did + attributes: + label: What happened? + description: Describe what happened in a clear and simple way + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce the problem + description: Please provide precise step-by-step instructions on how to reproduce the bug + value: | + Your workflow + validations: + required: true + - type: textarea + id: sysinfo + attributes: + label: Sysinfo + description: Describe your platform: OS, browser, GPU, and which other nodes are enabled. + validations: + required: true + - type: textarea + id: logs + attributes: + label: Relevant console log + description: Please provide the cmd/terminal log from the moment you started the UI to the moment you got the error. This will be automatically formatted into code, so no need for backticks. + render: Shell + validations: + required: true + - type: textarea + id: misc + attributes: + label: Additional information + description: Please provide any relevant additional info or context. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..649c432c64704ead1fefabf627ef8efcf6345057 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: ReActor Node Community Support + url: https://github.com/Gourieff/comfyui-reactor-node/discussions + about: Please ask and answer questions here.
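The two templates above follow GitHub's issue-forms schema: top-level `name`, `description`, and `labels` keys plus a `body` list of typed blocks (`checkboxes`, `markdown`, `textarea`), while `config.yml` only disables blank issues and adds a contact link. As a quick local sanity check before committing template changes, a minimal sketch like the following can parse each form and confirm the expected structure. This script is illustrative only, assumes PyYAML is installed, and is not part of the ReActor repository.

```python
# Minimal sketch: sanity-check GitHub issue-form YAML files before committing.
# Assumes PyYAML (`pip install pyyaml`); not part of the ReActor repository.
import sys
from pathlib import Path

import yaml

REQUIRED_TOP_LEVEL = {"name", "description", "body"}


def check_issue_form(path: Path) -> bool:
    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    missing = REQUIRED_TOP_LEVEL - set(data)
    if missing:
        print(f"{path}: missing keys {sorted(missing)}")
        return False
    # Every body entry must declare a block type (markdown, textarea, checkboxes, ...).
    untyped = [i for i, block in enumerate(data["body"]) if "type" not in block]
    if untyped:
        print(f"{path}: body entries without a 'type': {untyped}")
        return False
    return True


if __name__ == "__main__":
    forms = Path(".github/ISSUE_TEMPLATE").glob("*.yml")
    # config.yml has a different schema (contact_links), so it is skipped here.
    ok = all(check_issue_form(p) for p in forms if p.name != "config.yml")
    sys.exit(0 if ok else 1)
```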
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..6ac4f2cce9bb99741e020d22599648ee8364233b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,16 @@ +name: Feature request +description: Suggest an idea for this project +title: "[Feature]: " +labels: ["enhancement", "new"] + +body: + - type: textarea + id: description + attributes: + label: Feature description + description: Describe the feature in a clear and simple way + value: + - type: markdown + attributes: + value: | + The best way to propose an idea is to start a new discussion via the [Discussions](https://github.com/Gourieff/comfyui-reactor-node/discussions) section (choose the "Idea" category) diff --git a/.github/workflows/publish_action.yml b/.github/workflows/publish_action.yml new file mode 100644 index 0000000000000000000000000000000000000000..eb2a326921975540c9d9e0c16bf5788e947a024b --- /dev/null +++ b/.github/workflows/publish_action.yml @@ -0,0 +1,20 @@ +name: Publish to Comfy registry +on: + workflow_dispatch: + push: + branches: + - main + paths: + - "pyproject.toml" + +jobs: + publish-node: + name: Publish Custom Node to registry + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Publish Custom Node + uses: Comfy-Org/publish-node-action@main + with: + personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github secrets and reference it here. \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/__pycache__/__init__.cpython-311.pyc b/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55c5d692ef90a2842a2f7fbf0a2dc020af5436b2 Binary files /dev/null and b/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/modules/__pycache__/processing.cpython-311.pyc b/modules/__pycache__/processing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f367d1b63edcaa800de25802f0b919b4108b0f8 Binary files /dev/null and b/modules/__pycache__/processing.cpython-311.pyc differ diff --git a/modules/__pycache__/scripts.cpython-311.pyc b/modules/__pycache__/scripts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f92f78fbe21cf5408574ced8cf13c531a689aadc Binary files /dev/null and b/modules/__pycache__/scripts.cpython-311.pyc differ diff --git a/modules/__pycache__/scripts_postprocessing.cpython-311.pyc b/modules/__pycache__/scripts_postprocessing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b5627629fa9885ba895507d7d1ad326efc2e90b Binary files /dev/null and b/modules/__pycache__/scripts_postprocessing.cpython-311.pyc differ diff --git a/modules/__pycache__/shared.cpython-311.pyc b/modules/__pycache__/shared.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..badf51be31fb4970cb3174d7c7a6c31f425ab15c Binary files /dev/null and b/modules/__pycache__/shared.cpython-311.pyc differ diff --git a/modules/images.py b/modules/images.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/processing.py b/modules/processing.py new file mode 100644 index 
0000000000000000000000000000000000000000..649c3a118d9420673fe9f29d352fa0784e76a1db --- /dev/null +++ b/modules/processing.py @@ -0,0 +1,13 @@ +class StableDiffusionProcessing: + + def __init__(self, init_imgs): + self.init_images = init_imgs + self.width = init_imgs[0].width + self.height = init_imgs[0].height + self.extra_generation_params = {} + + +class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): + + def __init__(self, init_img): + super().__init__(init_img) diff --git a/modules/scripts.py b/modules/scripts.py new file mode 100644 index 0000000000000000000000000000000000000000..276dfb4a266a01d9bd3f62659ab031c3656bc708 --- /dev/null +++ b/modules/scripts.py @@ -0,0 +1,13 @@ +import os + + +class Script: + pass + + +def basedir(): + return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class PostprocessImageArgs: + pass diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/shared.py b/modules/shared.py new file mode 100644 index 0000000000000000000000000000000000000000..a1dc8710f3ea0ff4781174867472be645386b534 --- /dev/null +++ b/modules/shared.py @@ -0,0 +1,19 @@ +class Options: + img2img_background_color = "#ffffff" # Set to white for now + + +class State: + interrupted = False + + def begin(self): + pass + + def end(self): + pass + + +opts = Options() +state = State() +cmd_opts = None +sd_upscalers = [] +face_restorers = [] diff --git a/r_chainner/__pycache__/model_loading.cpython-311.pyc b/r_chainner/__pycache__/model_loading.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..262b8d3497ea867f48a999ec7e93028c888f695e Binary files /dev/null and b/r_chainner/__pycache__/model_loading.cpython-311.pyc differ diff --git a/r_chainner/__pycache__/types.cpython-311.pyc b/r_chainner/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80dd90f2c56f46503f9eb0a9535e8df730a70101 Binary files /dev/null and b/r_chainner/__pycache__/types.cpython-311.pyc differ diff --git a/r_chainner/archs/face/__pycache__/gfpganv1_clean_arch.cpython-311.pyc b/r_chainner/archs/face/__pycache__/gfpganv1_clean_arch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1439b942d6db37e794422225a3a746c5bdb139d3 Binary files /dev/null and b/r_chainner/archs/face/__pycache__/gfpganv1_clean_arch.cpython-311.pyc differ diff --git a/r_chainner/archs/face/__pycache__/stylegan2_clean_arch.cpython-311.pyc b/r_chainner/archs/face/__pycache__/stylegan2_clean_arch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97fab96c853ff3f3bb30617a2b409e3fc90d7b66 Binary files /dev/null and b/r_chainner/archs/face/__pycache__/stylegan2_clean_arch.cpython-311.pyc differ diff --git a/r_chainner/archs/face/gfpganv1_clean_arch.py b/r_chainner/archs/face/gfpganv1_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c4d4983c24e04a0f6f1fda44217df79a06fd3a --- /dev/null +++ b/r_chainner/archs/face/gfpganv1_clean_arch.py @@ -0,0 +1,370 @@ +# pylint: skip-file +# type: ignore +import math +import random + +import torch +from torch import nn +from torch.nn import functional as F + +from r_chainner.archs.face.stylegan2_clean_arch import StyleGAN2GeneratorClean + + +class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): + """StyleGAN2 Generator with SFT modulation (Spatial Feature 
Transform). + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + narrow=1, + sft_half=False, + ): + super(StyleGAN2GeneratorCSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + ) + self.sft_half = sft_half + + def forward( + self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False, + ): + """Forward function for StyleGAN2GeneratorCSFT. + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [ + getattr(self.noises, f"noise{i}") for i in range(self.num_layers) + ] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append( + truncation_latent + truncation * (style - truncation_latent) + ) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = ( + styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + ) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs, + ): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 
2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ResBlock(nn.Module): + """Residual block with bilinear upsampling/downsampling. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + mode (str): Upsampling/downsampling mode. Options: down | up. Default: down. + """ + + def __init__(self, in_channels, out_channels, mode="down"): + super(ResBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) + if mode == "down": + self.scale_factor = 0.5 + elif mode == "up": + self.scale_factor = 2 + + def forward(self, x): + out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) + # upsample/downsample + out = F.interpolate( + out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False + ) + out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) + # skip + x = F.interpolate( + x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False + ) + skip = self.skip(x) + out = out + skip + return out + + +class GFPGANv1Clean(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + num_mlp (int): Layer number of MLP style layers. Default: 8. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+ """ + + def __init__( + self, + state_dict, + ): + super(GFPGANv1Clean, self).__init__() + + out_size = 512 + num_style_feat = 512 + channel_multiplier = 2 + decoder_load_path = None + fix_decoder = False + num_mlp = 8 + input_is_latent = True + different_w = True + narrow = 1 + sft_half = True + + self.model_arch = "GFPGAN" + self.sub_type = "Face SR" + self.scale = 8 + self.in_nc = 3 + self.out_nc = 3 + self.state = state_dict + + self.supports_fp16 = False + self.supports_bf16 = True + self.min_size_restriction = 512 + + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + "4": int(512 * unet_narrow), + "8": int(512 * unet_narrow), + "16": int(512 * unet_narrow), + "32": int(512 * unet_narrow), + "64": int(256 * channel_multiplier * unet_narrow), + "128": int(128 * channel_multiplier * unet_narrow), + "256": int(64 * channel_multiplier * unet_narrow), + "512": int(32 * channel_multiplier * unet_narrow), + "1024": int(16 * channel_multiplier * unet_narrow), + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2 ** (int(math.log(out_size, 2))) + + self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1) + + # downsample + in_channels = channels[f"{first_out_size}"] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f"{2**(i - 1)}"] + self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down")) + in_channels = out_channels + + self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1) + + # upsample + in_channels = channels["4"] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up")) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorCSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + sft_half=sft_half, + ) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load( + decoder_load_path, map_location=lambda storage, loc: storage + )["params_ema"] + ) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1), + ) + ) + self.condition_shift.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.LeakyReLU(0.2, True), + 
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1), + ) + ) + self.load_state_dict(state_dict) + + def forward( + self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs + ): + """Forward function for GFPGANv1Clean. + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder( + [style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise, + ) + + return image, out_rgbs diff --git a/r_chainner/archs/face/stylegan2_clean_arch.py b/r_chainner/archs/face/stylegan2_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..a68655a4b3de22e6dc94f126331159f3425f3669 --- /dev/null +++ b/r_chainner/archs/face/stylegan2_clean_arch.py @@ -0,0 +1,453 @@ +# pylint: skip-file +# type: ignore +import math + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import init +from torch.nn.modules.batchnorm import _BatchNorm + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +class NormStyleCode(nn.Module): + def forward(self, x): + """Normalize the style codes. + Args: + x (Tensor): Style codes with shape (b, c). + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + There is no bias in ModulatedConv2d. 
+ Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + eps (float): A value added to the denominator for numerical stability. Default: 1e-8. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8, + ): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + # modulation inside each modulated conv + self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) + # initialization + default_init_weights( + self.modulation, + scale=1, + bias_fill=1, + a=0, + mode="fan_in", + nonlinearity="linear", + ) + + self.weight = nn.Parameter( + torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) + / math.sqrt(in_channels * kernel_size**2) + ) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + Returns: + Tensor: Modulated tensor after convolution. + """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view( + b * self.out_channels, c, self.kernel_size, self.kernel_size + ) + + # upsample or downsample if necessary + if self.sample_mode == "upsample": + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) + elif self.sample_mode == "downsample": + x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})" + ) + + +class StyleConv(nn.Module): + """Style conv used in StyleGAN2. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. 
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + ): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + ) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) * 2**0.5 # for conversion + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # add bias + out = out + self.bias + # activation + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB (image space) from features. + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d( + in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None, + ) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate( + skip, scale_factor=2, mode="bilinear", align_corners=False + ) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +class StyleGAN2GeneratorClean(nn.Module): + """Clean version of StyleGAN2 Generator. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): Narrow ratio for channels. Default: 1.0. 
+ """ + + def __init__( + self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1 + ): + super(StyleGAN2GeneratorClean, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.extend( + [ + nn.Linear(num_style_feat, num_style_feat, bias=True), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ] + ) + self.style_mlp = nn.Sequential(*style_mlp_layers) + # initialization + default_init_weights( + self.style_mlp, + scale=1, + bias_fill=0, + a=0.2, + mode="fan_in", + nonlinearity="leaky_relu", + ) + + # channel list + channels = { + "4": int(512 * narrow), + "8": int(512 * narrow), + "16": int(512 * narrow), + "32": int(512 * narrow), + "64": int(256 * channel_multiplier * narrow), + "128": int(128 * channel_multiplier * narrow), + "256": int(64 * channel_multiplier * narrow), + "512": int(32 * channel_multiplier * narrow), + "1024": int(16 * channel_multiplier * narrow), + } + self.channels = channels + + self.constant_input = ConstantInput(channels["4"], size=4) + self.style_conv1 = StyleConv( + channels["4"], + channels["4"], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + ) + self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels["4"] + # noise + for layer_idx in range(self.num_layers): + resolution = 2 ** ((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode="upsample", + ) + ) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + ) + ) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn( + num_latent, self.num_style_feat, device=self.constant_input.weight.device + ) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward( + self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False, + ): + """Forward function for StyleGAN2GeneratorClean. + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. 
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [ + getattr(self.noises, f"noise{i}") for i in range(self.num_layers) + ] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append( + truncation_latent + truncation * (style - truncation_latent) + ) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = ( + styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + ) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs, + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None diff --git a/r_chainner/model_loading.py b/r_chainner/model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..21fd51d7db5eab933583f6dd72a707ed30707b74 --- /dev/null +++ b/r_chainner/model_loading.py @@ -0,0 +1,28 @@ +from r_chainner.archs.face.gfpganv1_clean_arch import GFPGANv1Clean +from r_chainner.types import PyTorchModel + + +class UnsupportedModel(Exception): + pass + + +def load_state_dict(state_dict) -> PyTorchModel: + + state_dict_keys = list(state_dict.keys()) + + if "params_ema" in state_dict_keys: + state_dict = state_dict["params_ema"] + elif "params-ema" in state_dict_keys: + state_dict = state_dict["params-ema"] + elif "params" in state_dict_keys: + state_dict = state_dict["params"] + + state_dict_keys = list(state_dict.keys()) + + # GFPGAN + if ( + "toRGB.0.weight" in state_dict_keys + and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys + ): + model = GFPGANv1Clean(state_dict) + return model diff --git a/r_chainner/types.py b/r_chainner/types.py new file mode 100644 index 0000000000000000000000000000000000000000..73e6a28882d04d1a365ce3880a4f97597e8a9ec2 --- /dev/null +++ b/r_chainner/types.py @@ -0,0 +1,18 @@ +from typing import Union + +from r_chainner.archs.face.gfpganv1_clean_arch import GFPGANv1Clean + + +PyTorchFaceModels = (GFPGANv1Clean,) +PyTorchFaceModel = Union[GFPGANv1Clean] + + +def is_pytorch_face_model(model: object): + return isinstance(model, PyTorchFaceModels) + +PyTorchModels = (*PyTorchFaceModels, ) 
+PyTorchModel = Union[PyTorchFaceModel] + + +def is_pytorch_model(model: object): + return isinstance(model, PyTorchModels) diff --git a/r_facelib/__init__.py b/r_facelib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/r_facelib/__pycache__/__init__.cpython-311.pyc b/r_facelib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73adb311256de057f02674f7304eecf917151f4d Binary files /dev/null and b/r_facelib/__pycache__/__init__.cpython-311.pyc differ diff --git a/r_facelib/detection/__init__.py b/r_facelib/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c953bd55159136cd4728da3c4fa3f658514cac2 --- /dev/null +++ b/r_facelib/detection/__init__.py @@ -0,0 +1,102 @@ +import os +import torch +from torch import nn +from copy import deepcopy +import pathlib + +from r_facelib.utils import load_file_from_url +from r_facelib.utils import download_pretrained_models +from r_facelib.detection.yolov5face.models.common import Conv + +from .retinaface.retinaface import RetinaFace +from .yolov5face.face_detector import YoloDetector + + +def init_detection_model(model_name, half=False, device='cuda'): + if 'retinaface' in model_name: + model = init_retinaface_model(model_name, half, device) + elif 'YOLOv5' in model_name: + model = init_yolov5face_model(model_name, device) + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + return model + + +def init_retinaface_model(model_name, half=False, device='cuda'): + if model_name == 'retinaface_resnet50': + model = RetinaFace(network_name='resnet50', half=half) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth' + elif model_name == 'retinaface_mobile0.25': + model = RetinaFace(network_name='mobile0.25', half=half) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='../../models/facedetection', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + # remove unnecessary 'module.' 
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + + return model + + +def init_yolov5face_model(model_name, device='cuda'): + current_dir = str(pathlib.Path(__file__).parent.resolve()) + if model_name == 'YOLOv5l': + model = YoloDetector(config_name=current_dir+'/yolov5face/models/yolov5l.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' + elif model_name == 'YOLOv5n': + model = YoloDetector(config_name=current_dir+'/yolov5face/models/yolov5n.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='../../models/facedetection', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.detector.load_state_dict(load_net, strict=True) + model.detector.eval() + model.detector = model.detector.to(device).float() + + for m in model.detector.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True # pytorch 1.7.0 compatibility + elif isinstance(m, Conv): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + + return model + + +# Download from Google Drive +# def init_yolov5face_model(model_name, device='cuda'): +# if model_name == 'YOLOv5l': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) +# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} +# elif model_name == 'YOLOv5n': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) +# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} +# else: +# raise NotImplementedError(f'{model_name} is not implemented.') + +# model_path = os.path.join('../../models/facedetection', list(f_id.keys())[0]) +# if not os.path.exists(model_path): +# download_pretrained_models(file_ids=f_id, save_path_root='../../models/facedetection') + +# load_net = torch.load(model_path, map_location=lambda storage, loc: storage) +# model.detector.load_state_dict(load_net, strict=True) +# model.detector.eval() +# model.detector = model.detector.to(device).float() + +# for m in model.detector.modules(): +# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: +# m.inplace = True # pytorch 1.7.0 compatibility +# elif isinstance(m, Conv): +# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + +# return model \ No newline at end of file diff --git a/r_facelib/detection/__pycache__/__init__.cpython-311.pyc b/r_facelib/detection/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f559c8bc30073970c816e62fd740b1af2966eac0 Binary files /dev/null and b/r_facelib/detection/__pycache__/__init__.cpython-311.pyc differ diff --git a/r_facelib/detection/__pycache__/align_trans.cpython-311.pyc b/r_facelib/detection/__pycache__/align_trans.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ae8c7946eab04f4568c98e782a4b4ca7b6783fd Binary files /dev/null and b/r_facelib/detection/__pycache__/align_trans.cpython-311.pyc differ diff --git a/r_facelib/detection/__pycache__/matlab_cp2tform.cpython-311.pyc 
b/r_facelib/detection/__pycache__/matlab_cp2tform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e5c7d4a1a6633d3595ba0702272fa43eecde807 Binary files /dev/null and b/r_facelib/detection/__pycache__/matlab_cp2tform.cpython-311.pyc differ diff --git a/r_facelib/detection/align_trans.py b/r_facelib/detection/align_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..0b7374ab8a813f7dcb5313d7ad7b586b739424d7 --- /dev/null +++ b/r_facelib/detection/align_trans.py @@ -0,0 +1,219 @@ +import cv2 +import numpy as np + +from .matlab_cp2tform import get_similarity_transform_for_cv2 + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], + [33.54930115, 92.3655014], [62.72990036, 92.20410156]] + +DEFAULT_CROP_SIZE = (96, 112) + + +class FaceWarpException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): + """ + Function: + ---------- + get reference 5 key points according to crop settings: + 0. Set default crop_size: + if default_square: + crop_size = (112, 112) + else: + crop_size = (96, 112) + 1. Pad the crop_size by inner_padding_factor in each side; + 2. Resize crop_size into (output_size - outer_padding*2), + pad into output_size with outer_padding; + 3. Output reference_5point; + Parameters: + ---------- + @output_size: (w, h) or None + size of aligned face image + @inner_padding_factor: (w_factor, h_factor) + padding factor for inner (w, h) + @outer_padding: (w_pad, h_pad) + each row is a pair of coordinates (x, y) + @default_square: True or False + if True: + default crop_size = (112, 112) + else: + default crop_size = (96, 112); + !!! 
make sure, if output_size is not None: + (output_size - outer_padding) + = some_scale * (default crop_size * (1.0 + + inner_padding_factor)) + Returns: + ---------- + @reference_5point: 5x2 np.array + each row is a pair of transformed coordinates (x, y) + """ + + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): + + return tmp_5pts + + if (inner_padding_factor == 0 and outer_padding == (0, 0)): + if output_size is None: + return tmp_5pts + else: + raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException('Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + """ + Function: + ---------- + get affine transform matrix 'tfm' from src_pts to dst_pts + Parameters: + ---------- + @src_pts: Kx2 np.array + source points matrix, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points matrix, each row is a pair of coordinates (x, y) + Returns: + ---------- + @tfm: 2x3 np.array + transform matrix from src_pts to dst_pts + """ + + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'): + """ + Function: + ---------- + apply affine transform 'trans' to uv + Parameters: + ---------- + @src_img: 3x3 np.array + input 
image + @facial_pts: could be + 1)a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + @reference_pts: could be + 1) a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + or + 3) None + if None, use default reference facial points + @crop_size: (w, h) + output face image size + @align_type: transform type, could be one of + 1) 'similarity': use similarity transform + 2) 'cv2_affine': use the first 3 points to do affine transform, + by calling cv2.getAffineTransform() + 3) 'affine': use all points to do affine transform + Returns: + ---------- + @face_img: output face image with size (w, h) = @crop_size + """ + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException('facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + else: + tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) + + return face_img diff --git a/r_facelib/detection/matlab_cp2tform.py b/r_facelib/detection/matlab_cp2tform.py new file mode 100644 index 0000000000000000000000000000000000000000..b1014a826ab19ed66e8fdc8a4f7dd9d3f5c08393 --- /dev/null +++ b/r_facelib/detection/matlab_cp2tform.py @@ -0,0 +1,317 @@ +import numpy as np +from numpy.linalg import inv, lstsq +from numpy.linalg import matrix_rank as rank +from numpy.linalg import norm + + +class MatlabCp2tormException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + options = {'K': 2} + + K = 
options['K'] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U, rcond=-1) + r = np.squeeze(r) + else: + raise Exception('cp2tform:twoUniquePointsReq') + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) + T = inv(Tinv) + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + options = {'K': 2} + + # uv = np.array(uv) + # xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 
np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/r_facelib/detection/retinaface/__pycache__/retinaface.cpython-311.pyc b/r_facelib/detection/retinaface/__pycache__/retinaface.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..554aee09fd4813e82515a91fc82d6baae8a924f6 Binary files /dev/null and b/r_facelib/detection/retinaface/__pycache__/retinaface.cpython-311.pyc differ diff --git a/r_facelib/detection/retinaface/__pycache__/retinaface_net.cpython-311.pyc b/r_facelib/detection/retinaface/__pycache__/retinaface_net.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f740546cdeea645ffbf6c0f4ec899b99ee8a9c Binary files /dev/null and b/r_facelib/detection/retinaface/__pycache__/retinaface_net.cpython-311.pyc differ diff --git a/r_facelib/detection/retinaface/__pycache__/retinaface_utils.cpython-311.pyc b/r_facelib/detection/retinaface/__pycache__/retinaface_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44c2ae4c6995ad2f5511a67ba9cb4daa59df5375 Binary files /dev/null and b/r_facelib/detection/retinaface/__pycache__/retinaface_utils.cpython-311.pyc differ diff --git a/r_facelib/detection/retinaface/retinaface.py b/r_facelib/detection/retinaface/retinaface.py new file mode 100644 index 
0000000000000000000000000000000000000000..5d9770a178f8a53628a3c3ebc40474455da05335 --- /dev/null +++ b/r_facelib/detection/retinaface/retinaface.py @@ -0,0 +1,389 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter + +from modules import shared + +from r_facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face +from r_facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head +from r_facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, + py_cpu_nms) + +#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif torch.backends.mps.is_available(): + device = torch.device('mps') +# elif hasattr(torch,'dml'): +# device = torch.device('dml') +elif hasattr(torch,'dml') or hasattr(torch,'privateuseone'): # AMD + if shared.cmd_opts is not None: # A1111 + if shared.cmd_opts.device_id is not None: + device = torch.device(f'privateuseone:{shared.cmd_opts.device_id}') + else: + device = torch.device('privateuseone:0') + else: + device = torch.device('privateuseone:0') +else: + device = torch.device('cpu') + + +def generate_config(network_name): + + cfg_mnet = { + 'name': 'mobilenet0.25', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 640, + 'return_layers': { + 'stage1': 1, + 'stage2': 2, + 'stage3': 3 + }, + 'in_channel': 32, + 'out_channel': 64 + } + + cfg_re50 = { + 'name': 'Resnet50', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 24, + 'ngpu': 4, + 'epoch': 100, + 'decay1': 70, + 'decay2': 90, + 'image_size': 840, + 'return_layers': { + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + 'in_channel': 256, + 'out_channel': 256 + } + + if network_name == 'mobile0.25': + return cfg_mnet + elif network_name == 'resnet50': + return cfg_re50 + else: + raise NotImplementedError(f'network_name={network_name}') + + +class RetinaFace(nn.Module): + + def __init__(self, network_name='resnet50', half=False, phase='test'): + super(RetinaFace, self).__init__() + self.half_inference = half + cfg = generate_config(network_name) + self.backbone = cfg['name'] + + self.model_name = f'retinaface_{network_name}' + self.cfg = cfg + self.phase = phase + self.target_size, self.max_size = 1600, 2150 + self.resize, self.scale, self.scale1 = 1., None, None + self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device) + self.reference = get_reference_facial_points(default_square=True) + # Build network. 
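+        # Backbone selection follows cfg['name']: the bundled MobileNetV1 (0.25x) or
+        # torchvision's ResNet-50. IntermediateLayerGetter exposes the three stages named
+        # in cfg['return_layers']; their widths (in_channel x 2/4/8) feed the FPN, each FPN
+        # level is refined by an SSH module, and the class/bbox/landmark heads sit on top.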
+ backbone = None + if cfg['name'] == 'mobilenet0.25': + backbone = MobileNetV1() + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + elif cfg['name'] == 'Resnet50': + import torchvision.models as models + backbone = models.resnet50(pretrained=False) + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) + + self.to(device) + self.eval() + if self.half_inference: + self.half() + + def forward(self, inputs): + self.to(device) + out = self.body(inputs) + + if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50': + out = list(out.values()) + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) + classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) + tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] + ldm_regressions = (torch.cat(tmp, dim=1)) + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output + + def __detect_faces(self, inputs): + # get scale + height, width = inputs.shape[2:] + self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device) + tmp = [width, height, width, height, width, height, width, height, width, height] + self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) + + # forawrd + inputs = inputs.to(device) + if self.half_inference: + inputs = inputs.half() + loc, conf, landmarks = self(inputs) + + # get priorbox + priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) + priors = priorbox.forward().to(device) + + return loc, conf, landmarks, priors + + # single image detection + def transform(self, image, use_origin_size): + # convert to opencv format + if isinstance(image, Image.Image): + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = image.astype(np.float32) + + # testing scale + im_size_min = np.min(image.shape[0:2]) + im_size_max = np.max(image.shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + + # convert to torch.tensor format + # image -= (104, 117, 123) + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).unsqueeze(0) + + return image, resize + + def detect_faces( + self, + image, + conf_threshold=0.8, + nms_threshold=0.4, + use_origin_size=True, + ): + """ + 
Params: + imgs: BGR image + """ + image, self.resize = self.transform(image, use_origin_size) + image = image.to(device) + if self.half_inference: + image = image.half() + image = image - self.mean_tensor + + loc, conf, landmarks, priors = self.__detect_faces(image) + + boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance']) + boxes = boxes * self.scale / self.resize + boxes = boxes.cpu().numpy() + + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance']) + landmarks = landmarks * self.scale1 / self.resize + landmarks = landmarks.cpu().numpy() + + # ignore low scores + inds = np.where(scores > conf_threshold)[0] + boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] + + # sort + order = scores.argsort()[::-1] + boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] + + # do NMS + bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] + # self.t['forward_pass'].toc() + # print(self.t['forward_pass'].average_time) + # import sys + # sys.stdout.flush() + return np.concatenate((bounding_boxes, landmarks), axis=1) + + def __align_multi(self, image, boxes, landmarks, limit=None): + + if len(boxes) < 1: + return [], [] + + if limit: + boxes = boxes[:limit] + landmarks = landmarks[:limit] + + faces = [] + for landmark in landmarks: + facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] + + warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112)) + faces.append(warped_face) + + return np.concatenate((boxes, landmarks), axis=1), faces + + def align_multi(self, img, conf_threshold=0.8, limit=None): + + rlt = self.detect_faces(img, conf_threshold=conf_threshold) + boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] + + return self.__align_multi(img, boxes, landmarks, limit) + + # batched detection + def batched_transform(self, frames, use_origin_size): + """ + Arguments: + frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], + type=np.float32, BGR format). + use_origin_size: whether to use origin size. 
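+            Returns:
+                frames: torch.Tensor of shape [n, c, h, w], BGR order, still un-normalized
+                    (the mean tensor is subtracted later in batched_detect_faces).
+                resize: the scale factor that was applied (1 when use_origin_size is True).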
+ """ + from_PIL = True if isinstance(frames[0], Image.Image) else False + + # convert to opencv format + if from_PIL: + frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames] + frames = np.asarray(frames, dtype=np.float32) + + # testing scale + im_size_min = np.min(frames[0].shape[0:2]) + im_size_max = np.max(frames[0].shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + if not from_PIL: + frames = F.interpolate(frames, scale_factor=resize) + else: + frames = [ + cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + for frame in frames + ] + + # convert to torch.tensor format + if not from_PIL: + frames = frames.transpose(1, 2).transpose(1, 3).contiguous() + else: + frames = frames.transpose((0, 3, 1, 2)) + frames = torch.from_numpy(frames) + + return frames, resize + + def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True): + """ + Arguments: + frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], + type=np.uint8, BGR format). + conf_threshold: confidence threshold. + nms_threshold: nms threshold. + use_origin_size: whether to use origin size. + Returns: + final_bounding_boxes: list of np.array ([n_boxes, 5], + type=np.float32). + final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). + """ + # self.t['forward_pass'].tic() + frames, self.resize = self.batched_transform(frames, use_origin_size) + frames = frames.to(device) + frames = frames - self.mean_tensor + + b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) + + final_bounding_boxes, final_landmarks = [], [] + + # decode + priors = priors.unsqueeze(0) + b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize + b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize + b_conf = b_conf[:, :, 1] + + # index for selection + b_indice = b_conf > conf_threshold + + # concat + b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() + + for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice): + + # ignore low scores + pred, landm = pred[inds, :], landm[inds, :] + if pred.shape[0] == 0: + final_bounding_boxes.append(np.array([], dtype=np.float32)) + final_landmarks.append(np.array([], dtype=np.float32)) + continue + + # sort + # order = score.argsort(descending=True) + # box, landm, score = box[order], landm[order], score[order] + + # to CPU + bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() + + # NMS + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] + + # append + final_bounding_boxes.append(bounding_boxes) + final_landmarks.append(landmarks) + # self.t['forward_pass'].toc(average=True) + # self.batch_time += self.t['forward_pass'].diff + # self.total_frame += len(frames) + # print(self.batch_time / self.total_frame) + + return final_bounding_boxes, final_landmarks diff --git a/r_facelib/detection/retinaface/retinaface_net.py b/r_facelib/detection/retinaface/retinaface_net.py new file mode 100644 index 0000000000000000000000000000000000000000..c52535eecc33d93d8ba03f56538b623145f65ee3 --- /dev/null +++ b/r_facelib/detection/retinaface/retinaface_net.py @@ -0,0 +1,196 @@ 
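+# Network building blocks for RetinaFace: conv/bn helpers, the SSH context module,
+# the FPN neck, a MobileNetV1 backbone, and the class/bbox/landmark prediction heads.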
+import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + # input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + 
conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + # x = self.model(x) + x = x.view(-1, 256) + x = self.fc(x) + return x + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + +def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + +def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead diff --git a/r_facelib/detection/retinaface/retinaface_utils.py b/r_facelib/detection/retinaface/retinaface_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f19e320943261086316b54be33ea5d3ca4fc861b --- /dev/null +++ b/r_facelib/detection/retinaface/retinaface_utils.py @@ -0,0 +1,421 @@ +import numpy as np +import torch +import torchvision +from itertools import product as product +from math import ceil + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS 
baseline.""" + keep = torchvision.ops.nms( + boxes=torch.Tensor(dets[:, :4]), + scores=torch.Tensor(dets[:, 4]), + iou_threshold=thresh, + ) + + return list(keep) + + +def point_form(boxes): + """ Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + ( + boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:] / 2), + 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + (boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy + boxes[:, 2:] - boxes[:, :2], + 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. 
+ E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i) + + +def matrix_iof(a, b): + """ + return iof of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when matching boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + landms: (tensor) Ground truth landms, Shape [num_obj, 10]. + loc_t: (tensor) Tensor to be filled w/ encoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + landm_t: (tensor) Tensor to be filled w/ encoded landm targets. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location 2)confidence + 3)landm preds. 
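+        Note: in this implementation the encoded targets are written in place into
+        loc_t, conf_t and landm_t at batch position idx; the function itself returns None.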
+ """ + # jaccard index + overlaps = jaccard(truths, point_form(priors)) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = encode(matches, priors, variances) + + matches_landm = landms[best_truth_idx] + landm = encode_landm(matches_landm, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + landm_t[idx] = landm + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + encoded landm (tensor), Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, :, 2:]) + # g_cxcy /= priors[:, :, 2:] + g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) + # return target for smooth_l1_loss + return g_cxcy + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + tmp = ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ) + landms = torch.cat(tmp, dim=1) + return landms + + +def batched_decode(b_loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + b_loc (tensor): location predictions for loc layers, + Shape: [num_batches,num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + boxes = ( + priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]), + ) + boxes = torch.cat(boxes, dim=2) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +def batched_decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_batches,num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + landms = ( + priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:], + ) + landms = torch.cat(landms, dim=2) + return landms + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. 
after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/r_facelib/detection/yolov5face/__init__.py b/r_facelib/detection/yolov5face/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/r_facelib/detection/yolov5face/__pycache__/__init__.cpython-311.pyc b/r_facelib/detection/yolov5face/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a56986ff3ca989f82dc8f1e6a6c8bc03fd42be58 Binary files /dev/null and b/r_facelib/detection/yolov5face/__pycache__/__init__.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/__pycache__/face_detector.cpython-311.pyc b/r_facelib/detection/yolov5face/__pycache__/face_detector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ede0fe43cb89aa9aa5f03a7e65bf2ad08968c4a Binary files /dev/null and b/r_facelib/detection/yolov5face/__pycache__/face_detector.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/face_detector.py b/r_facelib/detection/yolov5face/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6d8e3c3d92ddb5ea65decbff8aced90d24ef6f --- /dev/null +++ b/r_facelib/detection/yolov5face/face_detector.py @@ -0,0 +1,141 @@ +import copy +from pathlib import Path + +import cv2 +import numpy as np +import torch +from torch import torch_version + +from r_facelib.detection.yolov5face.models.common import Conv +from r_facelib.detection.yolov5face.models.yolo import Model +from r_facelib.detection.yolov5face.utils.datasets import letterbox +from r_facelib.detection.yolov5face.utils.general import ( + check_img_size, + non_max_suppression_face, + scale_coords, + scale_coords_landmarks, +) + +print(f"Torch version: {torch.__version__}") +IS_HIGH_VERSION = torch_version.__version__ >= "1.9.0" + +def isListempty(inList): + if isinstance(inList, list): # Is a list + return all(map(isListempty, inList)) + return False # Not a list + +class YoloDetector: + def __init__( + self, + config_name, + min_face=10, + target_size=None, + device='cuda', + ): + """ + config_name: name of .yaml config with network configuration from models/ folder. + min_face : minimal face size in pixels. + target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. + None for original resolution. + """ + self._class_path = Path(__file__).parent.absolute() + self.target_size = target_size + self.min_face = min_face + self.detector = Model(cfg=config_name) + self.device = device + + + def _preprocess(self, imgs): + """ + Preprocessing image before passing through the network. Resize and conversion to torch tensor. 
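+        Returns a float tensor of shape [n, 3, H, W] on self.device, scaled to [0, 1].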
+ """ + pp_imgs = [] + for img in imgs: + h0, w0 = img.shape[:2] # orig hw + if self.target_size: + r = self.target_size / min(h0, w0) # resize image to img_size + if r < 1: + img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) + + imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size + img = letterbox(img, new_shape=imgsz)[0] + pp_imgs.append(img) + pp_imgs = np.array(pp_imgs) + pp_imgs = pp_imgs.transpose(0, 3, 1, 2) + pp_imgs = torch.from_numpy(pp_imgs).to(self.device) + pp_imgs = pp_imgs.float() # uint8 to fp16/32 + return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 + + def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): + """ + Postprocessing of raw pytorch model output. + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). + """ + bboxes = [[] for _ in range(len(origimgs))] + landmarks = [[] for _ in range(len(origimgs))] + + pred = non_max_suppression_face(pred, conf_thres, iou_thres) + + for image_id, origimg in enumerate(origimgs): + img_shape = origimg.shape + image_height, image_width = img_shape[:2] + gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh + gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks + det = pred[image_id].cpu() + scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() + scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() + + for j in range(det.size()[0]): + box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() + box = list( + map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) + ) + if box[3] - box[1] < self.min_face: + continue + lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() + lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) + lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] + bboxes[image_id].append(box) + landmarks[image_id].append(lm) + return bboxes, landmarks + + def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): + """ + Get bbox coordinates and keypoints of faces on original image. + Params: + imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) + conf_thres: confidence threshold for each prediction + iou_thres: threshold for NMS (filter of intersecting bboxes) + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
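+            Note: in this variant the detections are returned as a single np.array of
+            shape [n_faces, 15]: x1, y1, x2, y2, a padding column (a copy of x1), then the
+            10 landmark coordinates; None is returned when no faces are found.
+            Example (illustrative only; the config name and image path are placeholders):
+                detector = YoloDetector('yolov5n.yaml', target_size=480, device='cpu')
+                result = detector.detect_faces(cv2.imread('photo.jpg'), conf_thres=0.7)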
+ """ + # Pass input images through face detector + images = imgs if isinstance(imgs, list) else [imgs] + images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] + origimgs = copy.deepcopy(images) + + images = self._preprocess(images) + + if IS_HIGH_VERSION: + with torch.inference_mode(): # for pytorch>=1.9 + pred = self.detector(images)[0] + else: + with torch.no_grad(): # for pytorch<1.9 + pred = self.detector(images)[0] + + bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) + + # return bboxes, points + if not isListempty(points): + bboxes = np.array(bboxes).reshape(-1,4) + points = np.array(points).reshape(-1,10) + padding = bboxes[:,0].reshape(-1,1) + return np.concatenate((bboxes, padding, points), axis=1) + else: + return None + + def __call__(self, *args): + return self.predict(*args) diff --git a/r_facelib/detection/yolov5face/models/__init__.py b/r_facelib/detection/yolov5face/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/r_facelib/detection/yolov5face/models/__pycache__/__init__.cpython-311.pyc b/r_facelib/detection/yolov5face/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78e9cf86e6a0d53e6121855717593cbf2817153e Binary files /dev/null and b/r_facelib/detection/yolov5face/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/models/__pycache__/common.cpython-311.pyc b/r_facelib/detection/yolov5face/models/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9db75c3a9ffdaa8e27f3ab7ef27b248cea9b321 Binary files /dev/null and b/r_facelib/detection/yolov5face/models/__pycache__/common.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/models/__pycache__/experimental.cpython-311.pyc b/r_facelib/detection/yolov5face/models/__pycache__/experimental.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b800c410257c1ca893f9aa9b3b413b43d82f291 Binary files /dev/null and b/r_facelib/detection/yolov5face/models/__pycache__/experimental.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/models/__pycache__/yolo.cpython-311.pyc b/r_facelib/detection/yolov5face/models/__pycache__/yolo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78d46dd5ad785bab28103bfe9c7f84517e359b5d Binary files /dev/null and b/r_facelib/detection/yolov5face/models/__pycache__/yolo.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/models/common.py b/r_facelib/detection/yolov5face/models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..96894d520dcb2b213ba73cc5c13e8c3a7e4add59 --- /dev/null +++ b/r_facelib/detection/yolov5face/models/common.py @@ -0,0 +1,299 @@ +# This file contains modules common to various models + +import math + +import numpy as np +import torch +from torch import nn + +from r_facelib.detection.yolov5face.utils.datasets import letterbox +from r_facelib.detection.yolov5face.utils.general import ( + make_divisible, + non_max_suppression, + scale_coords, + xyxy2xywh, +) + + +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = torch.div(num_channels, 
groups, rounding_mode="trunc") + + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + return x.view(batchsize, -1, height, width) + + +def DWConv(c1, c2, k=1, s=1, act=True): + # Depthwise convolution + return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + + +class Conv(nn.Module): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class StemBlock(nn.Module): + def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True): + super().__init__() + self.stem_1 = Conv(c1, c2, k, s, p, g, act) + self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0) + self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1) + self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) + self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0) + + def forward(self, x): + stem_1_out = self.stem_1(x) + stem_2a_out = self.stem_2a(stem_1_out) + stem_2b_out = self.stem_2b(stem_2a_out) + stem_2p_out = self.stem_2p(stem_1_out) + return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1)) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.LeakyReLU(0.1, inplace=True) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class ShuffleV2Block(nn.Module): + def __init__(self, inp, oup, stride): + super().__init__() + + if not 1 <= stride <= 3: + raise ValueError("illegal stride value") + self.stride = stride + + branch_features = oup // 2 + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv(inp, inp, kernel_size=3, 
stride=self.stride, padding=1), + nn.BatchNorm2d(inp), + nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(branch_features), + nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + out = channel_shuffle(out, 2) + return out + + +class SPP(nn.Module): + # Spatial pyramid pooling layer used in YOLOv3-SPP + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + + +class Concat(nn.Module): + # Concatenate a list of tensors along dimension + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class NMS(nn.Module): + # Non-Maximum Suppression (NMS) module + conf = 0.25 # confidence threshold + iou = 0.45 # IoU threshold + classes = None # (optional list) filter by class + + def forward(self, x): + return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + + +class AutoShape(nn.Module): + # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS + img_size = 640 # inference size (pixels) + conf = 0.25 # NMS confidence threshold + iou = 0.45 # NMS IoU threshold + classes = None # (optional list) filter by class + + def __init__(self, model): + super().__init__() + self.model = model.eval() + + def autoshape(self): + print("autoShape already enabled, skipping... ") # model already converted to model.autoshape() + return self + + def forward(self, imgs, size=640, augment=False, profile=False): + # Inference from various sources. For height=720, width=1280, RGB images example inputs are: + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) + # PIL: = Image.open('image.jpg') # HWC x(720,1280,3) + # numpy: = np.zeros((720,1280,3)) # HWC + # torch: = torch.zeros(16,3,720,1280) # BCHW + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] 
# list of images + + p = next(self.model.parameters()) # for device and type + if isinstance(imgs, torch.Tensor): # torch + return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference + + # Pre-process + n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images + shape0, shape1 = [], [] # image and inference shapes + for i, im in enumerate(imgs): + im = np.array(im) # to numpy + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input + s = im.shape[:2] # HWC + shape0.append(s) # image shape + g = size / max(s) # gain + shape1.append([y * g for y in s]) + imgs[i] = im # update + shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + x = np.stack(x, 0) if n > 1 else x[0][None] # stack + x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW + x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32 + + # Inference + with torch.no_grad(): + y = self.model(x, augment, profile)[0] # forward + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + + # Post-process + for i in range(n): + scale_coords(shape1, y[i][:, :4], shape0[i]) + + return Detections(imgs, y, self.names) + + +class Detections: + # detections class for YOLOv5 inference results + def __init__(self, imgs, pred, names=None): + super().__init__() + d = pred[0].device # device + gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations + self.imgs = imgs # list of images as numpy arrays + self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) + self.names = names # class names + self.xyxy = pred # xyxy pixels + self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels + self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized + self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized + self.n = len(self.pred) + + def __len__(self): + return self.n + + def tolist(self): + # return a list of Detections objects, i.e. 
'for result in results.tolist():' + x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)] + for d in x: + for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]: + setattr(d, k, getattr(d, k)[0]) # pop out of list + return x diff --git a/r_facelib/detection/yolov5face/models/experimental.py b/r_facelib/detection/yolov5face/models/experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf7aea88b75bd7fb938f4013ca79160b2c2b18c --- /dev/null +++ b/r_facelib/detection/yolov5face/models/experimental.py @@ -0,0 +1,45 @@ +# # This file contains experimental modules + +import numpy as np +import torch +from torch import nn + +from r_facelib.detection.yolov5face.models.common import Conv + + +class CrossConv(nn.Module): + # Cross Convolution Downsample + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): + # ch_in, ch_out, kernel, stride, groups, expansion, shortcut + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, k), (1, s)) + self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class MixConv2d(nn.Module): + # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 + def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): + super().__init__() + groups = len(k) + if equal_ch: # equal c_ per group + i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices + c_ = [(i == g).sum() for g in range(groups)] # intermediate channels + else: # equal weight.numel() per group + b = [c2] + [0] * groups + a = np.eye(groups + 1, groups, k=-1) + a -= np.roll(a, 1, axis=1) + a *= np.array(k) ** 2 + a[0] = 1 + c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b + + self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) diff --git a/r_facelib/detection/yolov5face/models/yolo.py b/r_facelib/detection/yolov5face/models/yolo.py new file mode 100644 index 0000000000000000000000000000000000000000..02479dc1f9c429e618e938794aba0e948b14d303 --- /dev/null +++ b/r_facelib/detection/yolov5face/models/yolo.py @@ -0,0 +1,235 @@ +import math +from copy import deepcopy +from pathlib import Path + +import torch +import yaml # for torch hub +from torch import nn + +from r_facelib.detection.yolov5face.models.common import ( + C3, + NMS, + SPP, + AutoShape, + Bottleneck, + BottleneckCSP, + Concat, + Conv, + DWConv, + Focus, + ShuffleV2Block, + StemBlock, +) +from r_facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d +from r_facelib.detection.yolov5face.utils.autoanchor import check_anchor_order +from r_facelib.detection.yolov5face.utils.general import make_divisible +from r_facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn + + +class Detect(nn.Module): + stride = None # strides computed during build + export = False # onnx export + + def __init__(self, nc=80, anchors=(), ch=()): # detection layer + super().__init__() + self.nc = nc # number of classes + self.no = nc + 5 + 10 # number of outputs per anchor + + self.nl = len(anchors) # number of detection layers + self.na = len(anchors[0]) // 2 # number of anchors + self.grid = [torch.zeros(1)] * self.nl # init grid + a = 
torch.tensor(anchors).float().view(self.nl, -1, 2) + self.register_buffer("anchors", a) # shape(nl,na,2) + self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + + def forward(self, x): + z = [] # inference output + if self.export: + for i in range(self.nl): + x[i] = self.m[i](x[i]) + return x + for i in range(self.nl): + x[i] = self.m[i](x[i]) # conv + bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + + if not self.training: # inference + if self.grid[i].shape[2:4] != x[i].shape[2:4]: + self.grid[i] = self._make_grid(nx, ny).to(x[i].device) + + y = torch.full_like(x[i], 0) + y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid() + y[..., 5:15] = x[i][..., 5:15] + + y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + + y[..., 5:7] = ( + y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x1 y1 + y[..., 7:9] = ( + y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x2 y2 + y[..., 9:11] = ( + y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x3 y3 + y[..., 11:13] = ( + y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x4 y4 + y[..., 13:15] = ( + y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x5 y5 + + z.append(y.view(bs, -1, self.no)) + + return x if self.training else (torch.cat(z, 1), x) + + @staticmethod + def _make_grid(nx=20, ny=20): + # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10 + yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() + + +class Model(nn.Module): + def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes + super().__init__() + self.yaml_file = Path(cfg).name + with Path(cfg).open(encoding="utf8") as f: + self.yaml = yaml.safe_load(f) # model dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + self.yaml["nc"] = nc # override yaml value + + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist + self.names = [str(i) for i in range(self.yaml["nc"])] # default names + + # Build strides, anchors + m = self.model[-1] # Detect() + if isinstance(m, Detect): + s = 128 # 2x min stride + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward + m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) + self.stride = m.stride + self._initialize_biases() # only run once + + def forward(self, x): + return self.forward_once(x) # single-scale inference, train + + def forward_once(self, x): + y = [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + + return x + + def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency + # 
https://arxiv.org/abs/1708.02002 section 3.3 + m = self.model[-1] # Detect() module + for mi, s in zip(m.m, m.stride): # from + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls + mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def _print_biases(self): + m = self.model[-1] # Detect() module + for mi in m.m: # from + b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) + print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) + + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + print("Fusing layers... ") + for m in self.model.modules(): + if isinstance(m, Conv) and hasattr(m, "bn"): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, "bn") # remove batchnorm + m.forward = m.fuseforward # update forward + elif type(m) is nn.Upsample: + m.recompute_scale_factor = None # torch 1.11.0 compatibility + return self + + def nms(self, mode=True): # add or remove NMS module + present = isinstance(self.model[-1], NMS) # last layer is NMS + if mode and not present: + print("Adding NMS... ") + m = NMS() # module + m.f = -1 # from + m.i = self.model[-1].i + 1 # index + self.model.add_module(name=str(m.i), module=m) # add + self.eval() + elif not mode and present: + print("Removing NMS... ") + self.model = self.model[:-1] # remove + return self + + def autoshape(self): # add autoShape module + print("Adding autoShape... ") + m = AutoShape(self) # wrap model + copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes + return m + + +def parse_model(d, ch): # model_dict, input_channels(3) + anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"] + na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors + no = na * (nc + 5) # number of outputs = anchors * (classes + 5) + + layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = eval(m) if isinstance(m, str) else m # eval strings + for j, a in enumerate(args): + try: + args[j] = eval(a) if isinstance(a, str) else a # eval strings + except: + pass + + n = max(round(n * gd), 1) if n > 1 else n # depth gain + if m in [ + Conv, + Bottleneck, + SPP, + DWConv, + MixConv2d, + Focus, + CrossConv, + BottleneckCSP, + C3, + ShuffleV2Block, + StemBlock, + ]: + c1, c2 = ch[f], args[0] + + c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 + + args = [c1, c2, *args[1:]] + if m in [BottleneckCSP, C3]: + args.insert(2, n) + n = 1 + elif m is nn.BatchNorm2d: + args = [ch[f]] + elif m is Concat: + c2 = sum(ch[-1 if x == -1 else x + 1] for x in f) + elif m is Detect: + args.append([ch[x + 1] for x in f]) + if isinstance(args[1], int): # number of anchors + args[1] = [list(range(args[1] * 2))] * len(f) + else: + c2 = ch[f] + + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module + t = str(m)[8:-2].replace("__main__.", "") # module type + np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params + save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist + layers.append(m_) + ch.append(c2) + return nn.Sequential(*layers), 
sorted(save) diff --git a/r_facelib/detection/yolov5face/models/yolov5l.yaml b/r_facelib/detection/yolov5face/models/yolov5l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98a9e2cc51b7fb244b27f7797b652031a937421b --- /dev/null +++ b/r_facelib/detection/yolov5face/models/yolov5l.yaml @@ -0,0 +1,47 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 + [-1, 1, SPP, [1024, [3,5,7]]], + [-1, 3, C3, [1024, False]], # 8 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 5], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 12 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 3], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 16 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 13], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 19 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 9], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 22 (P5/32-large) + + [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] \ No newline at end of file diff --git a/r_facelib/detection/yolov5face/models/yolov5n.yaml b/r_facelib/detection/yolov5face/models/yolov5n.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a03fb09ba0c8ee86631d6dc6211b513b1a36435 --- /dev/null +++ b/r_facelib/detection/yolov5face/models/yolov5n.yaml @@ -0,0 +1,45 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 + [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 + [-1, 3, ShuffleV2Block, [128, 1]], # 2 + [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 + [-1, 7, ShuffleV2Block, [256, 1]], # 4 + [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 + [-1, 3, ShuffleV2Block, [512, 1]], # 6 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P4 + [-1, 1, C3, [128, False]], # 10 + + [-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], # cat backbone P3 + [-1, 1, C3, [128, False]], # 14 (P3/8-small) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 11], 1, Concat, [1]], # cat head P4 + [-1, 1, C3, [128, False]], # 17 (P4/16-medium) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 7], 1, Concat, [1]], # cat head P5 + [-1, 1, C3, [128, False]], # 20 (P5/32-large) + + [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/r_facelib/detection/yolov5face/utils/__init__.py b/r_facelib/detection/yolov5face/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/r_facelib/detection/yolov5face/utils/__pycache__/__init__.cpython-311.pyc b/r_facelib/detection/yolov5face/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd76919f993128b38ae3f8e14b3ad4a34d522462 Binary files /dev/null and b/r_facelib/detection/yolov5face/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/utils/__pycache__/autoanchor.cpython-311.pyc b/r_facelib/detection/yolov5face/utils/__pycache__/autoanchor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ea76a138d4db6b9ec343adf53a52a6b43456d0 Binary files /dev/null and b/r_facelib/detection/yolov5face/utils/__pycache__/autoanchor.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/utils/__pycache__/datasets.cpython-311.pyc b/r_facelib/detection/yolov5face/utils/__pycache__/datasets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eef452f3303b1f205bdb60fd36ed616e280f9d8 Binary files /dev/null and b/r_facelib/detection/yolov5face/utils/__pycache__/datasets.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/utils/__pycache__/general.cpython-311.pyc b/r_facelib/detection/yolov5face/utils/__pycache__/general.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b09d1f309a91b9caed50a9b04a939d7f7edc090 Binary files /dev/null and b/r_facelib/detection/yolov5face/utils/__pycache__/general.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/utils/__pycache__/torch_utils.cpython-311.pyc b/r_facelib/detection/yolov5face/utils/__pycache__/torch_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82bdd2e7494193067e7565212e43e127de1b3713 Binary files /dev/null and b/r_facelib/detection/yolov5face/utils/__pycache__/torch_utils.cpython-311.pyc differ diff --git a/r_facelib/detection/yolov5face/utils/autoanchor.py b/r_facelib/detection/yolov5face/utils/autoanchor.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0de894f5bf448a1442e7c479fe2ae80cb9acb3 --- /dev/null +++ b/r_facelib/detection/yolov5face/utils/autoanchor.py @@ -0,0 +1,12 @@ +# Auto-anchor utils + + +def check_anchor_order(m): + # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + a = m.anchor_grid.prod(-1).view(-1) # anchor area + da = a[-1] - a[0] # delta a + ds = m.stride[-1] - m.stride[0] # delta s + if da.sign() != ds.sign(): # same order + print("Reversing anchor order") + m.anchors[:] = m.anchors.flip(0) + m.anchor_grid[:] = m.anchor_grid.flip(0) diff --git a/r_facelib/detection/yolov5face/utils/datasets.py b/r_facelib/detection/yolov5face/utils/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..a72609b420e37d94104ec244cd7cac899bc4d29a --- /dev/null +++ b/r_facelib/detection/yolov5face/utils/datasets.py @@ -0,0 +1,35 @@ +import cv2 +import numpy as np + + +def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): + # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 + shape = img.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better test mAP) + r = min(r, 1.0) + + # Compute padding + ratio 
= r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding + elif scale_fill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return img, ratio, (dw, dh) diff --git a/r_facelib/detection/yolov5face/utils/extract_ckpt.py b/r_facelib/detection/yolov5face/utils/extract_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..e6bde00b034586d68ac95c1e819530e03261bbba --- /dev/null +++ b/r_facelib/detection/yolov5face/utils/extract_ckpt.py @@ -0,0 +1,5 @@ +import torch +import sys +sys.path.insert(0,'./facelib/detection/yolov5face') +model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] +torch.save(model.state_dict(),'../../models/facedetection') \ No newline at end of file diff --git a/r_facelib/detection/yolov5face/utils/general.py b/r_facelib/detection/yolov5face/utils/general.py new file mode 100644 index 0000000000000000000000000000000000000000..618d2f31ad818dadf7ee99bfeb57976a12af1088 --- /dev/null +++ b/r_facelib/detection/yolov5face/utils/general.py @@ -0,0 +1,271 @@ +import math +import time + +import numpy as np +import torch +import torchvision + + +def check_img_size(img_size, s=32): + # Verify img_size is a multiple of stride s + new_size = make_divisible(img_size, int(s)) # ceil gs-multiple + # if new_size != img_size: + # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}") + return new_size + + +def make_divisible(x, divisor): + # Returns x evenly divisible by divisor + return math.ceil(x / divisor) * divisor + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain 
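For orientation, a minimal sketch (not part of the patch) of how letterbox() and scale_coords() are typically combined around a YOLOv5-face forward pass. The loaded `model` (a Model instance in eval mode) and the 640 px input size are assumptions for illustration only:

    import cv2
    import numpy as np
    import torch

    from r_facelib.detection.yolov5face.utils.datasets import letterbox
    from r_facelib.detection.yolov5face.utils.general import (
        non_max_suppression_face, scale_coords)

    img0 = cv2.imread("face.jpg")                           # original BGR image, HxWx3
    img, _, _ = letterbox(img0, new_shape=640, auto=False)  # pad to 640x640, aspect ratio kept
    blob = torch.from_numpy(
        np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))).float() / 255.0  # BGR->RGB, HWC->CHW

    with torch.no_grad():
        pred = model(blob.unsqueeze(0))[0]                  # assumed pre-loaded YOLOv5-face model
    det = non_max_suppression_face(pred, conf_thres=0.25, iou_thres=0.45)[0]
    if det.shape[0]:
        # map xyxy boxes from the 640x640 letterboxed frame back to img0 coordinates;
        # the 10 landmark columns (5:15) go through scale_coords_landmarks() the same way
        det[:, :4] = scale_coords(img.shape[:2], det[:, :4], img0.shape[:2]).round()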
+ clip_coords(coords, img0_shape) + return coords + + +def clip_coords(boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) + + +def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 15 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label = labels[xi] + v = torch.zeros((len(label), nc + 15), device=x.device) + v[:, :4] = label[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, landmarks, cls) + if multi_label: + i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 15:].max(1, keepdim=True) + x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # If none remain process next image + n = x.shape[0] # number of boxes + if not n: + continue + + # Batched NMS + c = x[:, 15:16] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # 
merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + break # time limit exceeded + + return output + + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label_id = labels[xi] + v = torch.zeros((len(label_id), nc + 5), device=x.device) + v[:, :4] = label_id[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + + x = x[x[:, 4].argsort(descending=True)] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f"WARNING: NMS time limit {time_limit}s exceeded") + break # time limit exceeded + + return output + + +def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding + coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding + coords[:, :10] /= gain + coords[:, 0].clamp_(0, img0_shape[1]) # x1 + coords[:, 
1].clamp_(0, img0_shape[0]) # y1 + coords[:, 2].clamp_(0, img0_shape[1]) # x2 + coords[:, 3].clamp_(0, img0_shape[0]) # y2 + coords[:, 4].clamp_(0, img0_shape[1]) # x3 + coords[:, 5].clamp_(0, img0_shape[0]) # y3 + coords[:, 6].clamp_(0, img0_shape[1]) # x4 + coords[:, 7].clamp_(0, img0_shape[0]) # y4 + coords[:, 8].clamp_(0, img0_shape[1]) # x5 + coords[:, 9].clamp_(0, img0_shape[0]) # y5 + return coords diff --git a/r_facelib/detection/yolov5face/utils/torch_utils.py b/r_facelib/detection/yolov5face/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f70296232c59f291f7dfb4f75ef23f71de6c5dd6 --- /dev/null +++ b/r_facelib/detection/yolov5face/utils/torch_utils.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + + +def fuse_conv_and_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] 
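As a quick numerical sanity check, a sketch (randomly initialised layers; assumptions only, not part of the patch) showing that the fuse_conv_and_bn() above reproduces a Conv2d -> BatchNorm2d pair in eval mode. Note the fused layer always carries a bias, even when the original convolution had none:

    import torch
    from torch import nn

    from r_facelib.detection.yolov5face.utils.torch_utils import fuse_conv_and_bn

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    # give the BN layer non-trivial parameters/statistics so the comparison is meaningful
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    conv.eval()
    bn.eval()                                   # eval mode: use running statistics

    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        ref = bn(conv(x))
        fused = fuse_conv_and_bn(conv, bn)
        out = fused(x)
    print(torch.allclose(ref, out, atol=1e-5))  # expected: True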
+ for k, v in b.__dict__.items(): + if (include and k not in include) or k.startswith("_") or k in exclude: + continue + + setattr(a, k, v) diff --git a/r_facelib/parsing/__init__.py b/r_facelib/parsing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5aaa285f1efff228a04d6f092899883b3474c92 --- /dev/null +++ b/r_facelib/parsing/__init__.py @@ -0,0 +1,23 @@ +import torch + +from r_facelib.utils import load_file_from_url +from .bisenet import BiSeNet +from .parsenet import ParseNet + + +def init_parsing_model(model_name='bisenet', half=False, device='cuda'): + if model_name == 'bisenet': + model = BiSeNet(num_class=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' + elif model_name == 'parsenet': + model = ParseNet(in_size=512, out_size=512, parsing_ch=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='../../models/facedetection', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/r_facelib/parsing/__pycache__/__init__.cpython-311.pyc b/r_facelib/parsing/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8656bf5080f4b3033d7e8c1e9df2631743787d06 Binary files /dev/null and b/r_facelib/parsing/__pycache__/__init__.cpython-311.pyc differ diff --git a/r_facelib/parsing/__pycache__/bisenet.cpython-311.pyc b/r_facelib/parsing/__pycache__/bisenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5640f4c45372e2402abca051842132ad89ef1b5 Binary files /dev/null and b/r_facelib/parsing/__pycache__/bisenet.cpython-311.pyc differ diff --git a/r_facelib/parsing/__pycache__/parsenet.cpython-311.pyc b/r_facelib/parsing/__pycache__/parsenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d04e747272be5fbed87f7fe0d96c80eb4db777f5 Binary files /dev/null and b/r_facelib/parsing/__pycache__/parsenet.cpython-311.pyc differ diff --git a/r_facelib/parsing/__pycache__/resnet.cpython-311.pyc b/r_facelib/parsing/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db52f645e9c78d75587a4882185eedcfcef09fff Binary files /dev/null and b/r_facelib/parsing/__pycache__/resnet.cpython-311.pyc differ diff --git a/r_facelib/parsing/bisenet.py b/r_facelib/parsing/bisenet.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7a084c26abef2d3e71a6bbba8c2d14313e2892 --- /dev/null +++ b/r_facelib/parsing/bisenet.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import ResNet18 + + +class ConvBNReLU(nn.Module): + + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_chan) + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + +class BiSeNetOutput(nn.Module): + + def __init__(self, in_chan, mid_chan, num_class): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out 
= nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) + + def forward(self, x): + feat = self.conv(x) + out = self.conv_out(feat) + return out, feat + + +class AttentionRefinementModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + +class ContextPath(nn.Module): + + def __init__(self): + super(ContextPath, self).__init__() + self.resnet = ResNet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + def forward(self, x): + feat8, feat16, feat32 = self.resnet(x) + h8, w8 = feat8.size()[2:] + h16, w16 = feat16.size()[2:] + h32, w32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (h32, w32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + +class FeatureFusionModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) + self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + +class BiSeNet(nn.Module): + + def __init__(self, num_class): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, num_class) + self.conv_out16 = BiSeNetOutput(128, 64, num_class) + self.conv_out32 = BiSeNetOutput(128, 64, num_class) + + def forward(self, x, return_feat=False): + h, w = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature + feat_sp = feat_res8 # replace spatial path feature with res3b1 feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + out, feat = self.conv_out(feat_fuse) + out16, feat16 = self.conv_out16(feat_cp8) + out32, feat32 = self.conv_out32(feat_cp16) + + out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) + out16 = F.interpolate(out16, (h, w), 
mode='bilinear', align_corners=True) + out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) + + if return_feat: + feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) + feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) + feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) + return out, out16, out32, feat, feat16, feat32 + else: + return out, out16, out32 diff --git a/r_facelib/parsing/parsenet.py b/r_facelib/parsing/parsenet.py new file mode 100644 index 0000000000000000000000000000000000000000..2e80921b41b0a2c452830ecf7a79456b8f431597 --- /dev/null +++ b/r_facelib/parsing/parsenet.py @@ -0,0 +1,194 @@ +"""Modified from https://github.com/chaofengc/PSFRGAN +""" +import numpy as np +import torch.nn as nn +from torch.nn import functional as F + + +class NormLayer(nn.Module): + """Normalization Layers. + + Args: + channels: input channels, for batch norm and instance norm. + input_size: input shape without batch size, for layer norm. + """ + + def __init__(self, channels, normalize_shape=None, norm_type='bn'): + super(NormLayer, self).__init__() + norm_type = norm_type.lower() + self.norm_type = norm_type + if norm_type == 'bn': + self.norm = nn.BatchNorm2d(channels, affine=True) + elif norm_type == 'in': + self.norm = nn.InstanceNorm2d(channels, affine=False) + elif norm_type == 'gn': + self.norm = nn.GroupNorm(32, channels, affine=True) + elif norm_type == 'pixel': + self.norm = lambda x: F.normalize(x, p=2, dim=1) + elif norm_type == 'layer': + self.norm = nn.LayerNorm(normalize_shape) + elif norm_type == 'none': + self.norm = lambda x: x * 1.0 + else: + assert 1 == 0, f'Norm type {norm_type} not support.' + + def forward(self, x, ref=None): + if self.norm_type == 'spade': + return self.norm(x, ref) + else: + return self.norm(x) + + +class ReluLayer(nn.Module): + """Relu Layer. + + Args: + relu type: type of relu layer, candidates are + - ReLU + - LeakyReLU: default relu slope 0.2 + - PRelu + - SELU + - none: direct pass + """ + + def __init__(self, channels, relu_type='relu'): + super(ReluLayer, self).__init__() + relu_type = relu_type.lower() + if relu_type == 'relu': + self.func = nn.ReLU(True) + elif relu_type == 'leakyrelu': + self.func = nn.LeakyReLU(0.2, inplace=True) + elif relu_type == 'prelu': + self.func = nn.PReLU(channels) + elif relu_type == 'selu': + self.func = nn.SELU(True) + elif relu_type == 'none': + self.func = lambda x: x * 1.0 + else: + assert 1 == 0, f'Relu type {relu_type} not support.' + + def forward(self, x): + return self.func(x) + + +class ConvLayer(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + scale='none', + norm_type='none', + relu_type='none', + use_pad=True, + bias=True): + super(ConvLayer, self).__init__() + self.use_pad = use_pad + self.norm_type = norm_type + if norm_type in ['bn']: + bias = False + + stride = 2 if scale == 'down' else 1 + + self.scale_func = lambda x: x + if scale == 'up': + self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') + + self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) 
/ 2))) + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) + + self.relu = ReluLayer(out_channels, relu_type) + self.norm = NormLayer(out_channels, norm_type=norm_type) + + def forward(self, x): + out = self.scale_func(x) + if self.use_pad: + out = self.reflection_pad(out) + out = self.conv2d(out) + out = self.norm(out) + out = self.relu(out) + return out + + +class ResidualBlock(nn.Module): + """ + Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html + """ + + def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): + super(ResidualBlock, self).__init__() + + if scale == 'none' and c_in == c_out: + self.shortcut_func = lambda x: x + else: + self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) + + scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} + scale_conf = scale_config_dict[scale] + + self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) + self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') + + def forward(self, x): + identity = self.shortcut_func(x) + + res = self.conv1(x) + res = self.conv2(res) + return identity + res + + +class ParseNet(nn.Module): + + def __init__(self, + in_size=128, + out_size=128, + min_feat_size=32, + base_ch=64, + parsing_ch=19, + res_depth=10, + relu_type='LeakyReLU', + norm_type='bn', + ch_range=[32, 256]): + super().__init__() + self.res_depth = res_depth + act_args = {'norm_type': norm_type, 'relu_type': relu_type} + min_ch, max_ch = ch_range + + ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 + min_feat_size = min(in_size, min_feat_size) + + down_steps = int(np.log2(in_size // min_feat_size)) + up_steps = int(np.log2(out_size // min_feat_size)) + + # =============== define encoder-body-decoder ==================== + self.encoder = [] + self.encoder.append(ConvLayer(3, base_ch, 3, 1)) + head_ch = base_ch + for i in range(down_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) + self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) + head_ch = head_ch * 2 + + self.body = [] + for i in range(res_depth): + self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) + + self.decoder = [] + for i in range(up_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) + self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) + head_ch = head_ch // 2 + + self.encoder = nn.Sequential(*self.encoder) + self.body = nn.Sequential(*self.body) + self.decoder = nn.Sequential(*self.decoder) + self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) + self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) + + def forward(self, x): + feat = self.encoder(x) + x = feat + self.body(feat) + x = self.decoder(x) + out_img = self.out_img_conv(x) + out_mask = self.out_mask_conv(x) + return out_mask, out_img diff --git a/r_facelib/parsing/resnet.py b/r_facelib/parsing/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e7cc283df2869ca4ccae8020c39f21a6480c01c2 --- /dev/null +++ b/r_facelib/parsing/resnet.py @@ -0,0 +1,69 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, 
self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum - 1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class ResNet18(nn.Module): + + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 diff --git a/r_facelib/utils/__init__.py b/r_facelib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3397bda92eb2c6a9dcd7b5921eb8c18fc7d97cca --- /dev/null +++ b/r_facelib/utils/__init__.py @@ -0,0 +1,7 @@ +from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back +from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir + +__all__ = [ + 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', + 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' +] diff --git a/r_facelib/utils/__pycache__/__init__.cpython-311.pyc b/r_facelib/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c37c26b5979b8d8851d6926fcf0fa1c838b121f Binary files /dev/null and b/r_facelib/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/r_facelib/utils/__pycache__/face_restoration_helper.cpython-311.pyc b/r_facelib/utils/__pycache__/face_restoration_helper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56b6e3c56765dd977015472e971068e2b711d330 Binary files /dev/null and b/r_facelib/utils/__pycache__/face_restoration_helper.cpython-311.pyc differ diff --git a/r_facelib/utils/__pycache__/face_utils.cpython-311.pyc b/r_facelib/utils/__pycache__/face_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78279bc2e3b3c7b17f2ea6ea6e152a5f0172fcba Binary files /dev/null and b/r_facelib/utils/__pycache__/face_utils.cpython-311.pyc differ diff --git a/r_facelib/utils/__pycache__/misc.cpython-311.pyc b/r_facelib/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b345aa2f976cc6f23949feb7e919b042ba788918 Binary files /dev/null and b/r_facelib/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/r_facelib/utils/face_restoration_helper.py b/r_facelib/utils/face_restoration_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1db75c9df0aa2607fc6f377d35dc183a5a77b4d1 --- /dev/null +++ b/r_facelib/utils/face_restoration_helper.py @@ -0,0 +1,455 @@ +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from r_facelib.detection import init_detection_model +from r_facelib.parsing import init_parsing_model +from r_facelib.utils.misc import img2tensor, imwrite + + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper(object): + """Helper for the face restoration pipeline (base class).""" + + def __init__(self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + template_3points=False, + pad_blur=False, + use_parse=False, + device=None): + self.template_3points = template_3points # improve robustness + self.upscale_factor = upscale_factor + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1' + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + + if self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + # facexlib + self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], + [201.26117, 371.41043], [313.08905, 371.15118]]) + + # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 + # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], + # [198.22603, 372.82502], [313.91018, 372.75659]]) + + + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = 
[] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + + # init face detection model + self.face_det = init_detection_model(det_model, half=False, device=self.device) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # BGRA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + + if min(self.input_img.shape[:2])<512: + f = 512.0/min(self.input_img.shape[:2]) + self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR) + + def get_face_landmarks_5(self, + only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None): + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = resize / min(h, w) + scale = max(1, scale) # always scale up + h, w = int(h * scale), int(w * scale) + interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR + input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) + + with torch.no_grad(): + bboxes = self.face_det.detect_faces(input_img) + + if bboxes is None or bboxes.shape[0] == 0: + return 0 + else: + bboxes = bboxes / scale + + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + 
# y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(self.all_landmarks_5) + + def align_warp_face(self, save_cropped_path=None, border_mode='constant'): + """Align and warp faces with face template. 
+ """ + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == 'constant': + border_mode = cv2.BORDER_CONSTANT + elif border_mode == 'reflect101': + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == 'reflect': + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f'{path}_{idx:02d}.{self.save_ext}' + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + + def add_restored_face(self, face): + self.restored_faces.append(face) + + + def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + if upsample_img is None: + # simply resize the background + # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR) + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.') + + inv_mask_borders = [] + for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): + if face_upsampler is not None: + restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0] + inverse_affine /= self.upscale_factor + inverse_affine[:, 2] *= self.upscale_factor + face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor) + else: + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + face_size = self.face_size + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + # if draw_box or not self.use_parse: # use square parse maps + # mask = np.ones(face_size, dtype=np.float32) + # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # # remove the black borders + # inv_mask_erosion = cv2.erode( + # 
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + # pasted_face = inv_mask_erosion[:, :, None] * inv_restored + # total_face_area = np.sum(inv_mask_erosion) # // 3 + # # add border + # if draw_box: + # h, w = face_size + # mask_border = np.ones((h, w, 3), dtype=np.float32) + # border = int(1400/np.sqrt(total_face_area)) + # mask_border[border:h-border, border:w-border,:] = 0 + # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + # inv_mask_borders.append(inv_mask_border) + # if not self.use_parse: + # # compute the fusion edge based on the area of face + # w_edge = int(total_face_area**0.5) // 20 + # erosion_radius = w_edge * 2 + # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + # blur_size = w_edge * 2 + # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + # if len(upsample_img.shape) == 2: # upsample_img is gray image + # upsample_img = upsample_img[:, :, None] + # inv_soft_mask = inv_soft_mask[:, :, None] + + # always use square mask + mask = np.ones(face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + # add border + if draw_box: + h, w = face_size + mask_border = np.ones((h, w, 3), dtype=np.float32) + border = int(1400/np.sqrt(total_face_area)) + mask_border[border:h-border, border:w-border,:] = 0 + inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + inv_mask_borders.append(inv_mask_border) + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + if len(upsample_img.shape) == 2: # upsample_img is gray image + upsample_img = upsample_img[:, :, None] + inv_soft_mask = inv_soft_mask[:, :, None] + + # parse mask + if self.use_parse: + # inference + face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) + face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True) + normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + face_input = torch.unsqueeze(face_input, 0).to(self.device) + with torch.no_grad(): + out = self.face_parse(face_input)[0] + out = out.argmax(dim=1).squeeze().cpu().numpy() + + parse_mask = np.zeros(out.shape) + MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] + for idx, color in enumerate(MASK_COLORMAP): + parse_mask[out == idx] = color + # blur the mask + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + # remove the black borders + thres = 10 + parse_mask[:thres, :] = 0 + parse_mask[-thres:, :] = 0 + parse_mask[:, :thres] = 0 + parse_mask[:, -thres:] = 0 + parse_mask = parse_mask / 255. 
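Stepping back from the parsing branch, the helper class above is meant to be driven in a fixed call sequence; a minimal usage sketch follows (restore_fn stands in for any face enhancer returning a 512x512 BGR image and is an assumption; detector and parser weights are fetched on first use):

    import cv2

    from r_facelib.utils.face_restoration_helper import FaceRestoreHelper

    helper = FaceRestoreHelper(upscale_factor=1, face_size=512, use_parse=True,
                               det_model='retinaface_resnet50', device='cpu')
    helper.read_image(cv2.imread('input.jpg'))
    num_faces = helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
    helper.align_warp_face()                       # fills helper.cropped_faces (512x512 BGR crops)

    for cropped_face in helper.cropped_faces:
        restored = restore_fn(cropped_face)        # assumed restorer; must return a 512x512 BGR image
        helper.add_restored_face(restored)

    helper.get_inverse_affine()
    result = helper.paste_faces_to_input_image()   # full-resolution image with faces pasted back
    helper.clean_all()                             # reset internal lists before the next image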
+ + parse_mask = cv2.resize(parse_mask, face_size) + parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3) + inv_soft_parse_mask = parse_mask[:, :, None] + # pasted_face = inv_restored + fuse_mask = (inv_soft_parse_mask 256: # 16-bit image + upsample_img = upsample_img.astype(np.uint16) + else: + upsample_img = upsample_img.astype(np.uint8) + + # draw bounding box + if draw_box: + # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) + img_color = np.ones([*upsample_img.shape], dtype=np.float32) + img_color[:,:,0] = 0 + img_color[:,:,1] = 255 + img_color[:,:,2] = 0 + for inv_mask_border in inv_mask_borders: + upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img + # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img + + if save_path is not None: + path = os.path.splitext(save_path)[0] + save_path = f'{path}.{self.save_ext}' + imwrite(upsample_img, save_path) + return upsample_img + + def clean_all(self): + self.all_landmarks_5 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] + self.det_faces = [] + self.pad_input_imgs = [] diff --git a/r_facelib/utils/face_utils.py b/r_facelib/utils/face_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..657ad25eb89b02db7ba66b46256b431897526c90 --- /dev/null +++ b/r_facelib/utils/face_utils.py @@ -0,0 +1,248 @@ +import cv2 +import numpy as np +import torch + + +def compute_increased_bbox(bbox, increase_area, preserve_aspect=True): + left, top, right, bot = bbox + width = right - left + height = bot - top + + if preserve_aspect: + width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) + height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) + else: + width_increase = height_increase = increase_area + left = int(left - width_increase * width) + top = int(top - height_increase * height) + right = int(right + width_increase * width) + bot = int(bot + height_increase * height) + return (left, top, right, bot) + + +def get_valid_bboxes(bboxes, h, w): + left = max(bboxes[0], 0) + top = max(bboxes[1], 0) + right = min(bboxes[2], w) + bottom = min(bboxes[3], h) + return (left, top, right, bottom) + + +def align_crop_face_landmarks(img, + landmarks, + output_size, + transform_size=None, + enable_padding=True, + return_inverse_affine=False, + shrink_ratio=(1, 1)): + """Align and crop face with landmarks. + + The output_size and transform_size are based on width. The height is + adjusted based on shrink_ratio_h/shring_ration_w. + + Modified from: + https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py + + Args: + img (Numpy array): Input image. + landmarks (Numpy array): 5 or 68 or 98 landmarks. + output_size (int): Output face size. + transform_size (ing): Transform size. Usually the four time of + output_size. + enable_padding (float): Default: True. + shrink_ratio (float | tuple[float] | list[float]): Shring the whole + face for height and width (crop larger area). Default: (1, 1). + + Returns: + (Numpy array): Cropped face. 
+ """ + lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5 + + if isinstance(shrink_ratio, (float, int)): + shrink_ratio = (shrink_ratio, shrink_ratio) + if transform_size is None: + transform_size = output_size * 4 + + # Parse landmarks + lm = np.array(landmarks) + if lm.shape[0] == 5 and lm_type == 'retinaface_5': + eye_left = lm[0] + eye_right = lm[1] + mouth_avg = (lm[3] + lm[4]) * 0.5 + elif lm.shape[0] == 5 and lm_type == 'dlib_5': + lm_eye_left = lm[2:4] + lm_eye_right = lm[0:2] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = lm[4] + elif lm.shape[0] == 68: + lm_eye_left = lm[36:42] + lm_eye_right = lm[42:48] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[48] + lm[54]) * 0.5 + elif lm.shape[0] == 98: + lm_eye_left = lm[60:68] + lm_eye_right = lm[68:76] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[76] + lm[82]) * 0.5 + + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1 # TODO: you can edit it to get larger rect + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + x *= shrink_ratio[1] # width + y *= shrink_ratio[0] # height + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + + quad_ori = np.copy(quad) + # Shrink, for large face + # TODO: do we really need shrink + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + h, w = img.shape[0:2] + rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink))) + img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) + quad /= shrink + qsize /= shrink + + # Crop + h, w = img.shape[0:2] + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h)) + if crop[2] - crop[0] < w or crop[3] - crop[1] < h: + img = img[crop[1]:crop[3], crop[0]:crop[2], :] + quad -= crop[0:2] + + # Pad + # pad: (width_left, height_top, width_right, height_bottom) + h, w = img.shape[0:2] + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w = img.shape[0:2] + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = 
int(qsize * 0.02) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) + + img = img.astype('float32') + img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = np.clip(img, 0, 255) # float32, [0, 255] + quad += pad[:2] + + # Transform use cv2 + h_ratio = shrink_ratio[0] / shrink_ratio[1] + dst_h, dst_w = int(transform_size * h_ratio), transform_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] + cropped_face = cv2.warpAffine( + img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray + + if output_size < transform_size: + cropped_face = cv2.resize( + cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR) + + if return_inverse_affine: + dst_h, dst_w = int(output_size * h_ratio), output_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D( + quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0] + inverse_affine = cv2.invertAffineTransform(affine_matrix) + else: + inverse_affine = None + return cropped_face, inverse_affine + + +def paste_face_back(img, face, inverse_affine): + h, w = img.shape[0:2] + face_h, face_w = face.shape[0:2] + inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) + mask = np.ones((face_h, face_w, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) + # remove the black borders + inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img + # float32, [0, 255] + return img + + +if __name__ == '__main__': + import os + + from custom_nodes.facerestore.facelib.detection import init_detection_model + from custom_nodes.facerestore.facelib.utils.face_restoration_helper import get_largest_face + + img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png' + img_name = os.splitext(os.path.basename(img_path))[0] + + # initialize model + det_net = init_detection_model('retinaface_resnet50', half=False) + img_ori = cv2.imread(img_path) + h, w = img_ori.shape[0:2] + # if larger than 800, scale it + scale = max(h / 800, w / 800) + if scale > 1: + img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR) + + with torch.no_grad(): + bboxes = det_net.detect_faces(img, 0.97) + if scale > 1: + bboxes *= scale # the score is incorrect + bboxes = get_largest_face(bboxes, h, w)[0] + + landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)]) + + cropped_face, inverse_affine = align_crop_face_landmarks( + 
img_ori, + landmarks, + output_size=512, + transform_size=None, + enable_padding=True, + return_inverse_affine=True, + shrink_ratio=(1, 1)) + + cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face) + img = paste_face_back(img_ori, cropped_face, inverse_affine) + cv2.imwrite(f'tmp/{img_name}_back.png', img) diff --git a/r_facelib/utils/misc.py b/r_facelib/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea7c653b509615681305713a96f62514d3be7b6 --- /dev/null +++ b/r_facelib/utils/misc.py @@ -0,0 +1,143 @@ +import cv2 +import os +import os.path as osp +import torch +from torch.hub import download_url_to_file, get_dir +from urllib.parse import urlparse +# from basicsr.utils.download_util import download_file_from_google_drive +#import gdown + + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def download_pretrained_models(file_ids, save_path_root): + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + file_url = 'https://drive.google.com/uc?id='+file_id + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') + if user_response.lower() == 'y': + print(f'Covering {file_name} to {save_path}') + print("skipping gdown in facelib/utils/misc.py "+file_url) + #gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == 'n': + print(f'Skipping {file_name}') + else: + raise ValueError('Wrong input. Only accepts Y/N.') + else: + print(f'Downloading {file_name} to {save_path}') + print("skipping gdown in facelib/utils/misc.py "+file_url) + #gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. 
+ """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """ + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/__pycache__/__init__.cpython-311.pyc b/scripts/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d619af31170ef34aca7bfcd40961932bf4f02c2 Binary files /dev/null and b/scripts/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/__pycache__/reactor_faceswap.cpython-311.pyc b/scripts/__pycache__/reactor_faceswap.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a89f908208337e256626b43ffc774223ecd77b2 Binary files /dev/null and b/scripts/__pycache__/reactor_faceswap.cpython-311.pyc differ diff --git a/scripts/__pycache__/reactor_logger.cpython-311.pyc b/scripts/__pycache__/reactor_logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..299cb6525dbed4fb4d029650f4d2c0832146b195 Binary files /dev/null and b/scripts/__pycache__/reactor_logger.cpython-311.pyc differ diff --git a/scripts/__pycache__/reactor_swapper.cpython-311.pyc 
b/scripts/__pycache__/reactor_swapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f61341b70dee66d41408a43b3284a6bfacac726 Binary files /dev/null and b/scripts/__pycache__/reactor_swapper.cpython-311.pyc differ diff --git a/scripts/__pycache__/reactor_version.cpython-311.pyc b/scripts/__pycache__/reactor_version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab6193c8cf78aee03f35d42c972bb0b3b803cf6a Binary files /dev/null and b/scripts/__pycache__/reactor_version.cpython-311.pyc differ diff --git a/scripts/r_archs/__init__.py b/scripts/r_archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/r_archs/__pycache__/__init__.cpython-311.pyc b/scripts/r_archs/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c101b5e0fbf4e8538bd22f551f60b5575d6b1f55 Binary files /dev/null and b/scripts/r_archs/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/r_archs/__pycache__/codeformer_arch.cpython-311.pyc b/scripts/r_archs/__pycache__/codeformer_arch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e367f50f7850f99eb3453e47e41fba7e459016dc Binary files /dev/null and b/scripts/r_archs/__pycache__/codeformer_arch.cpython-311.pyc differ diff --git a/scripts/r_archs/__pycache__/vqgan_arch.cpython-311.pyc b/scripts/r_archs/__pycache__/vqgan_arch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54ff0bb7f0b29d5fea05ac63dbc2903d4a960694 Binary files /dev/null and b/scripts/r_archs/__pycache__/vqgan_arch.cpython-311.pyc differ diff --git a/scripts/r_archs/codeformer_arch.py b/scripts/r_archs/codeformer_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..588ef696b2451ca076bb66ac081bde5819ed4247 --- /dev/null +++ b/scripts/r_archs/codeformer_arch.py @@ -0,0 +1,278 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from scripts.r_archs.vqgan_arch import * +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import ARCH_REGISTRY + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. 
+ """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, 
out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator']): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits 
doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat + \ No newline at end of file diff --git a/scripts/r_archs/vqgan_arch.py b/scripts/r_archs/vqgan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..50b371228997fd7a3d40ce45c0fd0453467a9f5a --- /dev/null +++ b/scripts/r_archs/vqgan_arch.py @@ -0,0 +1,437 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import ARCH_REGISTRY + + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, 
self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "min_encoding_scores": min_encoding_scores, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + 
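VectorQuantizer.forward above implements the classic nearest-codebook lookup with a straight-through gradient; the compact standalone sketch below shows that core computation with shapes and names assumed for illustration (it is not the class API).

import torch

def quantize_nearest(z_flat: torch.Tensor, codebook: torch.Tensor, beta: float = 0.25):
    """z_flat: (N, D) flattened latents; codebook: (K, D) embedding table."""
    # squared distances ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z @ e^T
    d = (z_flat ** 2).sum(dim=1, keepdim=True) + (codebook ** 2).sum(dim=1) \
        - 2 * z_flat @ codebook.t()
    idx = d.argmin(dim=1)                               # nearest code per latent
    z_q = codebook[idx]                                 # quantized latents, (N, D)
    # codebook + commitment terms, mirroring the loss in VectorQuantizer.forward
    loss = ((z_q.detach() - z_flat) ** 2).mean() + beta * ((z_q - z_flat.detach()) ** 2).mean()
    z_q = z_flat + (z_q - z_flat).detach()              # straight-through estimator
    return z_q, loss, idx

# Decoding stored indices later is just an embedding lookup, which is what
# get_codebook_feat() above does before reshaping back to B x C x H x W.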
def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + 
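For orientation, the Encoder walks down a resolution/channel schedule that the Generator mirrors back up; a quick sanity sketch is shown below, with the values taken from the CodeFormer constructor above (img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8], attention at 16x16) and treated as assumptions.

def schedule(img_size=512, nf=64, ch_mult=(1, 2, 2, 4, 4, 8), attn_resolutions=(16,)):
    """Return [(resolution, channels, has_attention), ...] for each encoder stage."""
    res = img_size
    stages = []
    for i, mult in enumerate(ch_mult):
        stages.append((res, nf * mult, res in attn_resolutions))
        if i != len(ch_mult) - 1:  # downsample after every stage except the last
            res //= 2
    return stages

print(schedule())
# [(512, 64, False), (256, 128, False), (128, 128, False),
#  (64, 256, False), (32, 256, False), (16, 512, True)]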
+ for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, 
stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) + \ No newline at end of file diff --git a/scripts/r_faceboost/__init__.py b/scripts/r_faceboost/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/r_faceboost/__pycache__/__init__.cpython-311.pyc b/scripts/r_faceboost/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7951312bd8513efb9cc7fc25ae613054d57ea2f8 Binary files /dev/null and b/scripts/r_faceboost/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/r_faceboost/__pycache__/restorer.cpython-311.pyc b/scripts/r_faceboost/__pycache__/restorer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..357b8466acb333250bbb87b1f1f411f88c54c5c6 Binary files /dev/null and b/scripts/r_faceboost/__pycache__/restorer.cpython-311.pyc differ diff --git a/scripts/r_faceboost/__pycache__/swapper.cpython-311.pyc b/scripts/r_faceboost/__pycache__/swapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfef63a94e5ba2a4ee97a475f43c2a6c865ffb86 Binary files /dev/null and b/scripts/r_faceboost/__pycache__/swapper.cpython-311.pyc differ diff --git a/scripts/r_faceboost/restorer.py b/scripts/r_faceboost/restorer.py new file mode 100644 index 0000000000000000000000000000000000000000..60c01652d18e8e3c4ac37f076d15b4a35deaf66f --- /dev/null +++ b/scripts/r_faceboost/restorer.py @@ -0,0 +1,130 @@ +import sys +import cv2 +import numpy as np +import torch +from torchvision.transforms.functional import normalize + +try: + import torch.cuda as cuda +except: + cuda = None + +import comfy.utils +import folder_paths +import comfy.model_management as model_management + +from scripts.reactor_logger import logger +from r_basicsr.utils.registry import ARCH_REGISTRY +from r_chainner import model_loading +from reactor_utils import ( + tensor2img, + img2tensor, + set_ort_session, + prepare_cropped_face, + normalize_cropped_face +) + + +if cuda is not None: + if cuda.is_available(): + providers = ["CUDAExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] +else: + providers = ["CPUExecutionProvider"] + + +def get_restored_face(cropped_face, + face_restore_model, + face_restore_visibility, + codeformer_weight, + interpolation: str = "Bicubic"): + + if interpolation == "Bicubic": + interpolate = cv2.INTER_CUBIC + elif interpolation == "Bilinear": + interpolate = cv2.INTER_LINEAR + elif interpolation == "Nearest": + interpolate = cv2.INTER_NEAREST + elif interpolation == "Lanczos": + interpolate = cv2.INTER_LANCZOS4 + + face_size = 512 + if "1024" in face_restore_model.lower(): + face_size = 1024 + elif "2048" in face_restore_model.lower(): + face_size = 2048 + + scale = face_size / cropped_face.shape[0] + + logger.status(f"Boosting the Face with {face_restore_model} | Face Size is set to {face_size} with Scale Factor = {scale} and '{interpolation}' interpolation") + + cropped_face = cv2.resize(cropped_face, (face_size, face_size), interpolation=interpolate) + + # For 
upscaling the base 128px face, I found bicubic interpolation to be the best compromise targeting antialiasing + # and detail preservation. Nearest is predictably unusable, Linear produces too much aliasing, and Lanczos produces + # too many hallucinations and artifacts/fringing. + + model_path = folder_paths.get_full_path("facerestore_models", face_restore_model) + device = model_management.get_torch_device() + + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + + with torch.no_grad(): + + if ".onnx" in face_restore_model: # ONNX models + + ort_session = set_ort_session(model_path, providers=providers) + ort_session_inputs = {} + facerestore_model = ort_session + + for ort_session_input in ort_session.get_inputs(): + if ort_session_input.name == "input": + cropped_face_prep = prepare_cropped_face(cropped_face) + ort_session_inputs[ort_session_input.name] = cropped_face_prep + if ort_session_input.name == "weight": + weight = np.array([1], dtype=np.double) + ort_session_inputs[ort_session_input.name] = weight + + output = ort_session.run(None, ort_session_inputs)[0][0] + restored_face = normalize_cropped_face(output) + + else: # PTH models + + if "codeformer" in face_restore_model.lower(): + codeformer_net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], + ).to(device) + checkpoint = torch.load(model_path)["params_ema"] + codeformer_net.load_state_dict(checkpoint) + facerestore_model = codeformer_net.eval() + else: + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + facerestore_model = model_loading.load_state_dict(sd).eval() + facerestore_model.to(device) + + output = facerestore_model(cropped_face_t, w=codeformer_weight)[ + 0] if "codeformer" in face_restore_model.lower() else facerestore_model(cropped_face_t)[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + + del output + torch.cuda.empty_cache() + + except Exception as error: + + print(f"\tFailed inference: {error}", file=sys.stderr) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + if face_restore_visibility < 1: + restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility + + restored_face = restored_face.astype("uint8") + return restored_face, scale diff --git a/scripts/r_faceboost/swapper.py b/scripts/r_faceboost/swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e0467cf29333e504f53ad437d5e42e22ad227eb6 --- /dev/null +++ b/scripts/r_faceboost/swapper.py @@ -0,0 +1,42 @@ +import cv2 +import numpy as np + +# The following code is almost entirely copied from INSwapper; the only change here is that we want to use Lanczos +# interpolation for the warpAffine call. Now that the face has been restored, Lanczos represents a good compromise +# whether the restored face needs to be upscaled or downscaled. 
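The interpolation selection in get_restored_face() above is an if/elif chain over the option strings. A hypothetical helper (the names are mine, not part of the module) that mirrors it with a lookup table is sketched below; a dict raises KeyError on an unrecognized name instead of leaving the flag variable unbound as the chain would.

import cv2

INTERPOLATION_FLAGS = {
    "Nearest": cv2.INTER_NEAREST,
    "Bilinear": cv2.INTER_LINEAR,
    "Bicubic": cv2.INTER_CUBIC,
    "Lanczos": cv2.INTER_LANCZOS4,
}

def resize_face(face_bgr, face_size: int = 512, interpolation: str = "Bicubic"):
    """Resize a cropped face (H x W x 3, BGR) to the restorer's expected input size."""
    flag = INTERPOLATION_FLAGS[interpolation]
    return cv2.resize(face_bgr, (face_size, face_size), interpolation=flag)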
+def in_swap(img, bgr_fake, M): + target_img = img + IM = cv2.invertAffineTransform(M) + img_white = np.full((bgr_fake.shape[0], bgr_fake.shape[1]), 255, dtype=np.float32) + + # Note the use of bicubic here; this is functionally the only change from the source code + bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0, flags=cv2.INTER_CUBIC) + + img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) + img_white[img_white > 20] = 255 + img_mask = img_white + mask_h_inds, mask_w_inds = np.where(img_mask == 255) + mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) + mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) + mask_size = int(np.sqrt(mask_h * mask_w)) + k = max(mask_size // 10, 10) + # k = max(mask_size//20, 6) + # k = 6 + kernel = np.ones((k, k), np.uint8) + img_mask = cv2.erode(img_mask, kernel, iterations=1) + kernel = np.ones((2, 2), np.uint8) + k = max(mask_size // 20, 5) + # k = 3 + # k = 3 + kernel_size = (k, k) + blur_size = tuple(2 * i + 1 for i in kernel_size) + img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) + k = 5 + kernel_size = (k, k) + blur_size = tuple(2 * i + 1 for i in kernel_size) + img_mask /= 255 + # img_mask = fake_diff + img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1]) + fake_merged = img_mask * bgr_fake + (1 - img_mask) * target_img.astype(np.float32) + fake_merged = fake_merged.astype(np.uint8) + return fake_merged diff --git a/scripts/r_masking/__init__.py b/scripts/r_masking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/r_masking/__pycache__/__init__.cpython-311.pyc b/scripts/r_masking/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa2b356fa293654b57f5924cf5d632359b61271 Binary files /dev/null and b/scripts/r_masking/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/r_masking/__pycache__/core.cpython-311.pyc b/scripts/r_masking/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c1dd455bf07880153d4e0ab26686478eba5b36d Binary files /dev/null and b/scripts/r_masking/__pycache__/core.cpython-311.pyc differ diff --git a/scripts/r_masking/__pycache__/segs.cpython-311.pyc b/scripts/r_masking/__pycache__/segs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2e3cceffecb2a94b9fc963df883a8a7c93ac22b Binary files /dev/null and b/scripts/r_masking/__pycache__/segs.cpython-311.pyc differ diff --git a/scripts/r_masking/__pycache__/subcore.cpython-311.pyc b/scripts/r_masking/__pycache__/subcore.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..444482e586e035f61f9b1aa56910c1320801d9f6 Binary files /dev/null and b/scripts/r_masking/__pycache__/subcore.cpython-311.pyc differ diff --git a/scripts/r_masking/core.py b/scripts/r_masking/core.py new file mode 100644 index 0000000000000000000000000000000000000000..36862e1be7a9a7b666fb555a06b4d47ef3fbcb95 --- /dev/null +++ b/scripts/r_masking/core.py @@ -0,0 +1,647 @@ +import numpy as np +import cv2 +import torch +import torchvision.transforms.functional as TF + +import sys as _sys +from keyword import iskeyword as _iskeyword +from operator import itemgetter as _itemgetter + +from segment_anything import SamPredictor + +from comfy import model_management + + 
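The block that follows is a vendored copy of collections.namedtuple, apparently so the SEG record defined further below behaves identically across Python versions (it provides _make, _replace, _asdict and __match_args__ on its own). For orientation only, the sketch below shows roughly what using that record looks like, with the stdlib factory and placeholder field values standing in as assumptions.

from collections import namedtuple  # stdlib equivalent, shown only for illustration

SEG = namedtuple(
    "SEG",
    ["cropped_image", "cropped_mask", "confidence", "crop_region", "bbox",
     "label", "control_net_wrapper"],
    defaults=[None],  # only the last field, control_net_wrapper, is optional
)

seg = SEG(cropped_image=None, cropped_mask=None, confidence=0.92,
          crop_region=[10, 10, 200, 200], bbox=[32, 32, 180, 180], label="face")
print(seg.confidence)               # 0.92
print(seg._replace(label="face0"))  # namedtuples are immutable; _replace returns a copy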
+################################################################################ +### namedtuple +################################################################################ + +try: + from _collections import _tuplegetter +except ImportError: + _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc) + +def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None): + """Returns a new subclass of tuple with named fields. + + >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point.__doc__ # docstring for the new class + 'Point(x, y)' + >>> p = Point(11, y=22) # instantiate with positional args or keywords + >>> p[0] + p[1] # indexable like a plain tuple + 33 + >>> x, y = p # unpack like a regular tuple + >>> x, y + (11, 22) + >>> p.x + p.y # fields also accessible by name + 33 + >>> d = p._asdict() # convert to a dictionary + >>> d['x'] + 11 + >>> Point(**d) # convert from a dictionary + Point(x=11, y=22) + >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields + Point(x=100, y=22) + + """ + + # Validate the field names. At the user's option, either generate an error + # message or automatically replace the field name with a valid name. + if isinstance(field_names, str): + field_names = field_names.replace(',', ' ').split() + field_names = list(map(str, field_names)) + typename = _sys.intern(str(typename)) + + if rename: + seen = set() + for index, name in enumerate(field_names): + if (not name.isidentifier() + or _iskeyword(name) + or name.startswith('_') + or name in seen): + field_names[index] = f'_{index}' + seen.add(name) + + for name in [typename] + field_names: + if type(name) is not str: + raise TypeError('Type names and field names must be strings') + if not name.isidentifier(): + raise ValueError('Type names and field names must be valid ' + f'identifiers: {name!r}') + if _iskeyword(name): + raise ValueError('Type names and field names cannot be a ' + f'keyword: {name!r}') + + seen = set() + for name in field_names: + if name.startswith('_') and not rename: + raise ValueError('Field names cannot start with an underscore: ' + f'{name!r}') + if name in seen: + raise ValueError(f'Encountered duplicate field name: {name!r}') + seen.add(name) + + field_defaults = {} + if defaults is not None: + defaults = tuple(defaults) + if len(defaults) > len(field_names): + raise TypeError('Got more default values than field names') + field_defaults = dict(reversed(list(zip(reversed(field_names), + reversed(defaults))))) + + # Variables used in the methods and docstrings + field_names = tuple(map(_sys.intern, field_names)) + num_fields = len(field_names) + arg_list = ', '.join(field_names) + if num_fields == 1: + arg_list += ',' + repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')' + tuple_new = tuple.__new__ + _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip + + # Create all the named tuple methods to be added to the class namespace + + namespace = { + '_tuple_new': tuple_new, + '__builtins__': {}, + '__name__': f'namedtuple_{typename}', + } + code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))' + __new__ = eval(code, namespace) + __new__.__name__ = '__new__' + __new__.__doc__ = f'Create new instance of {typename}({arg_list})' + if defaults is not None: + __new__.__defaults__ = defaults + + @classmethod + def _make(cls, iterable): + result = tuple_new(cls, iterable) + if _len(result) != num_fields: + raise TypeError(f'Expected {num_fields} arguments, got {len(result)}') + return result + 
+ _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence ' + 'or iterable') + + def _replace(self, /, **kwds): + result = self._make(_map(kwds.pop, field_names, self)) + if kwds: + raise ValueError(f'Got unexpected field names: {list(kwds)!r}') + return result + + _replace.__doc__ = (f'Return a new {typename} object replacing specified ' + 'fields with new values') + + def __repr__(self): + 'Return a nicely formatted representation string' + return self.__class__.__name__ + repr_fmt % self + + def _asdict(self): + 'Return a new dict which maps field names to their values.' + return _dict(_zip(self._fields, self)) + + def __getnewargs__(self): + 'Return self as a plain tuple. Used by copy and pickle.' + return _tuple(self) + + # Modify function metadata to help with introspection and debugging + for method in ( + __new__, + _make.__func__, + _replace, + __repr__, + _asdict, + __getnewargs__, + ): + method.__qualname__ = f'{typename}.{method.__name__}' + + # Build-up the class namespace dictionary + # and use type() to build the result class + class_namespace = { + '__doc__': f'{typename}({arg_list})', + '__slots__': (), + '_fields': field_names, + '_field_defaults': field_defaults, + '__new__': __new__, + '_make': _make, + '_replace': _replace, + '__repr__': __repr__, + '_asdict': _asdict, + '__getnewargs__': __getnewargs__, + '__match_args__': field_names, + } + for index, name in enumerate(field_names): + doc = _sys.intern(f'Alias for field number {index}') + class_namespace[name] = _tuplegetter(index, doc) + + result = type(typename, (tuple,), class_namespace) + + # For pickling to work, the __module__ variable needs to be set to the frame + # where the named tuple is created. Bypass this step in environments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython), or where the user has + # specified a particular module. 
+ if module is None: + try: + module = _sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + if module is not None: + result.__module__ = module + + return result + + +SEG = namedtuple("SEG", + ['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'], + defaults=[None]) + +def crop_ndarray4(npimg, crop_region): + x1 = crop_region[0] + y1 = crop_region[1] + x2 = crop_region[2] + y2 = crop_region[3] + + cropped = npimg[:, y1:y2, x1:x2, :] + + return cropped + +crop_tensor4 = crop_ndarray4 + +def crop_ndarray2(npimg, crop_region): + x1 = crop_region[0] + y1 = crop_region[1] + x2 = crop_region[2] + y2 = crop_region[3] + + cropped = npimg[y1:y2, x1:x2] + + return cropped + +def crop_image(image, crop_region): + return crop_tensor4(image, crop_region) + +def normalize_region(limit, startp, size): + if startp < 0: + new_endp = min(limit, size) + new_startp = 0 + elif startp + size > limit: + new_startp = max(0, limit - size) + new_endp = limit + else: + new_startp = startp + new_endp = min(limit, startp+size) + + return int(new_startp), int(new_endp) + +def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None): + x1 = bbox[0] + y1 = bbox[1] + x2 = bbox[2] + y2 = bbox[3] + + bbox_w = x2 - x1 + bbox_h = y2 - y1 + + crop_w = bbox_w * crop_factor + crop_h = bbox_h * crop_factor + + if crop_min_size is not None: + crop_w = max(crop_min_size, crop_w) + crop_h = max(crop_min_size, crop_h) + + kernel_x = x1 + bbox_w / 2 + kernel_y = y1 + bbox_h / 2 + + new_x1 = int(kernel_x - crop_w / 2) + new_y1 = int(kernel_y - crop_h / 2) + + # make sure position in (w,h) + new_x1, new_x2 = normalize_region(w, new_x1, crop_w) + new_y1, new_y2 = normalize_region(h, new_y1, crop_h) + + return [new_x1, new_y1, new_x2, new_y2] + +def create_segmasks(results): + bboxs = results[1] + segms = results[2] + confidence = results[3] + + results = [] + for i in range(len(segms)): + item = (bboxs[i], segms[i].astype(np.float32), confidence[i]) + results.append(item) + return results + +def dilate_masks(segmasks, dilation_factor, iter=1): + if dilation_factor == 0: + return segmasks + + dilated_masks = [] + kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) + + kernel = cv2.UMat(kernel) + + for i in range(len(segmasks)): + cv2_mask = segmasks[i][1] + + cv2_mask = cv2.UMat(cv2_mask) + + if dilation_factor > 0: + dilated_mask = cv2.dilate(cv2_mask, kernel, iter) + else: + dilated_mask = cv2.erode(cv2_mask, kernel, iter) + + dilated_mask = dilated_mask.get() + + item = (segmasks[i][0], dilated_mask, segmasks[i][2]) + dilated_masks.append(item) + + return dilated_masks + +def is_same_device(a, b): + a_device = torch.device(a) if isinstance(a, str) else a + b_device = torch.device(b) if isinstance(b, str) else b + return a_device.type == b_device.type and a_device.index == b_device.index + +class SafeToGPU: + def __init__(self, size): + self.size = size + + def to_device(self, obj, device): + if is_same_device(device, 'cpu'): + obj.to(device) + else: + if is_same_device(obj.device, 'cpu'): # cpu to gpu + model_management.free_memory(self.size * 1.3, device) + if model_management.get_free_memory(device) > self.size * 1.3: + try: + obj.to(device) + except: + print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]") + else: + print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. 
[2]") + +def center_of_bbox(bbox): + w, h = bbox[2] - bbox[0], bbox[3] - bbox[1] + return bbox[0] + w/2, bbox[1] + h/2 + +def sam_predict(predictor, points, plabs, bbox, threshold): + point_coords = None if not points else np.array(points) + point_labels = None if not plabs else np.array(plabs) + + box = np.array([bbox]) if bbox is not None else None + + cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box) + + total_masks = [] + + selected = False + max_score = 0 + max_mask = None + for idx in range(len(scores)): + if scores[idx] > max_score: + max_score = scores[idx] + max_mask = cur_masks[idx] + + if scores[idx] >= threshold: + selected = True + total_masks.append(cur_masks[idx]) + else: + pass + + if not selected and max_mask is not None: + total_masks.append(max_mask) + + return total_masks + +def make_2d_mask(mask): + if len(mask.shape) == 4: + return mask.squeeze(0).squeeze(0) + + elif len(mask.shape) == 3: + return mask.squeeze(0) + + return mask + +def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative): + mask = make_2d_mask(mask) + + points = [] + plabs = [] + + # minimum sampling step >= 3 + y_step = max(3, int(mask.shape[0] / 20)) + x_step = max(3, int(mask.shape[1] / 20)) + + for i in range(0, len(mask), y_step): + for j in range(0, len(mask[i]), x_step): + if mask[i][j] > threshold: + points.append((x + j, y + i)) + plabs.append(1) + elif use_negative and mask[i][j] == 0: + points.append((x + j, y + i)) + plabs.append(0) + + return points, plabs + +def gen_negative_hints(w, h, x1, y1, x2, y2): + npoints = [] + nplabs = [] + + # minimum sampling step >= 3 + y_step = max(3, int(w / 20)) + x_step = max(3, int(h / 20)) + + for i in range(10, h - 10, y_step): + for j in range(10, w - 10, x_step): + if not (x1 - 10 <= j and j <= x2 + 10 and y1 - 10 <= i and i <= y2 + 10): + npoints.append((j, i)) + nplabs.append(0) + + return npoints, nplabs + +def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative, + mask_hint_use_negative): + [x1, y1, x2, y2] = dilated_bbox + + points = [] + plabs = [] + if detection_hint == "center-1": + points.append(center) + plabs = [1] # 1 = foreground point, 0 = background point + + elif detection_hint == "horizontal-2": + gap = (x2 - x1) / 3 + points.append((x1 + gap, center[1])) + points.append((x1 + gap * 2, center[1])) + plabs = [1, 1] + + elif detection_hint == "vertical-2": + gap = (y2 - y1) / 3 + points.append((center[0], y1 + gap)) + points.append((center[0], y1 + gap * 2)) + plabs = [1, 1] + + elif detection_hint == "rect-4": + x_gap = (x2 - x1) / 3 + y_gap = (y2 - y1) / 3 + points.append((x1 + x_gap, center[1])) + points.append((x1 + x_gap * 2, center[1])) + points.append((center[0], y1 + y_gap)) + points.append((center[0], y1 + y_gap * 2)) + plabs = [1, 1, 1, 1] + + elif detection_hint == "diamond-4": + x_gap = (x2 - x1) / 3 + y_gap = (y2 - y1) / 3 + points.append((x1 + x_gap, y1 + y_gap)) + points.append((x1 + x_gap * 2, y1 + y_gap)) + points.append((x1 + x_gap, y1 + y_gap * 2)) + points.append((x1 + x_gap * 2, y1 + y_gap * 2)) + plabs = [1, 1, 1, 1] + + elif detection_hint == "mask-point-bbox": + center = center_of_bbox(seg.bbox) + points.append(center) + plabs = [1] + + elif detection_hint == "mask-area": + points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1], + seg.cropped_mask, + mask_hint_threshold, use_small_negative) + + if mask_hint_use_negative == "Outter": + npoints, 
nplabs = gen_negative_hints(image.shape[0], image.shape[1], + seg.crop_region[0], seg.crop_region[1], + seg.crop_region[2], seg.crop_region[3]) + + points += npoints + plabs += nplabs + + return points, plabs + +def combine_masks2(masks): + if len(masks) == 0: + return None + else: + initial_cv2_mask = np.array(masks[0]).astype(np.uint8) + combined_cv2_mask = initial_cv2_mask + + for i in range(1, len(masks)): + cv2_mask = np.array(masks[i]).astype(np.uint8) + + if combined_cv2_mask.shape == cv2_mask.shape: + combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask) + else: + # do nothing - incompatible mask + pass + + mask = torch.from_numpy(combined_cv2_mask) + return mask + +def dilate_mask(mask, dilation_factor, iter=1): + if dilation_factor == 0: + return make_2d_mask(mask) + + mask = make_2d_mask(mask) + + kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) + + mask = cv2.UMat(mask) + kernel = cv2.UMat(kernel) + + if dilation_factor > 0: + result = cv2.dilate(mask, kernel, iter) + else: + result = cv2.erode(mask, kernel, iter) + + return result.get() + +def convert_and_stack_masks(masks): + if len(masks) == 0: + return None + + mask_tensors = [] + for mask in masks: + mask_array = np.array(mask, dtype=np.uint8) + mask_tensor = torch.from_numpy(mask_array) + mask_tensors.append(mask_tensor) + + stacked_masks = torch.stack(mask_tensors, dim=0) + stacked_masks = stacked_masks.unsqueeze(1) + + return stacked_masks + +def merge_and_stack_masks(stacked_masks, group_size): + if stacked_masks is None: + return None + + num_masks = stacked_masks.size(0) + merged_masks = [] + + for i in range(0, num_masks, group_size): + subset_masks = stacked_masks[i:i + group_size] + merged_mask = torch.any(subset_masks, dim=0) + merged_masks.append(merged_mask) + + if len(merged_masks) > 0: + merged_masks = torch.stack(merged_masks, dim=0) + + return merged_masks + +def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation, + threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative): + if sam_model.is_auto_mode: + device = model_management.get_torch_device() + sam_model.safe_to.to_device(sam_model, device=device) + + try: + predictor = SamPredictor(sam_model) + image = np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + predictor.set_image(image, "RGB") + + total_masks = [] + + use_small_negative = mask_hint_use_negative == "Small" + + # seg_shape = segs[0] + segs = segs[1] + if detection_hint == "mask-points": + points = [] + plabs = [] + + for i in range(len(segs)): + bbox = segs[i].bbox + center = center_of_bbox(bbox) + points.append(center) + + # small point is background, big point is foreground + if use_small_negative and bbox[2] - bbox[0] < 10: + plabs.append(0) + else: + plabs.append(1) + + detected_masks = sam_predict(predictor, points, plabs, None, threshold) + total_masks += detected_masks + + else: + for i in range(len(segs)): + bbox = segs[i].bbox + center = center_of_bbox(bbox) + x1 = max(bbox[0] - bbox_expansion, 0) + y1 = max(bbox[1] - bbox_expansion, 0) + x2 = min(bbox[2] + bbox_expansion, image.shape[1]) + y2 = min(bbox[3] + bbox_expansion, image.shape[0]) + + dilated_bbox = [x1, y1, x2, y2] + + points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox, + mask_hint_threshold, use_small_negative, + mask_hint_use_negative) + + detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold) + + total_masks += detected_masks + + # merge every collected masks + mask = combine_masks2(total_masks) + + finally: + if sam_model.is_auto_mode: + sam_model.cpu() + + pass + + mask_working_device = torch.device("cpu") + + if mask is not None: + mask = mask.float() + mask = dilate_mask(mask.cpu().numpy(), dilation) + mask = torch.from_numpy(mask) + mask = mask.to(device=mask_working_device) + else: + # Extracting batch, height and width + height, width, _ = image.shape + mask = torch.zeros( + (height, width), dtype=torch.float32, device=mask_working_device + ) # empty mask + + stacked_masks = convert_and_stack_masks(total_masks) + + return (mask, merge_and_stack_masks(stacked_masks, group_size=3)) + +def tensor2mask(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t + if size[3] == 1: + return t[:,:,:,0] + elif size[3] == 4: + # Not sure what the right thing to do here is. 
Going to try to be a little smart and use alpha unless all alpha is 1 in case we'll fallback to RGB behavior + if torch.min(t[:, :, :, 3]).item() != 1.: + return t[:,:,:,3] + return TF.rgb_to_grayscale(tensor2rgb(t).permute(0,3,1,2), num_output_channels=1)[:,0,:,:] + +def tensor2rgb(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t.unsqueeze(3).repeat(1, 1, 1, 3) + if size[3] == 1: + return t.repeat(1, 1, 1, 3) + elif size[3] == 4: + return t[:, :, :, :3] + else: + return t + +def tensor2rgba(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t.unsqueeze(3).repeat(1, 1, 1, 4) + elif size[3] == 1: + return t.repeat(1, 1, 1, 4) + elif size[3] == 3: + alpha_tensor = torch.ones((size[0], size[1], size[2], 1)) + return torch.cat((t, alpha_tensor), dim=3) + else: + return t diff --git a/scripts/r_masking/segs.py b/scripts/r_masking/segs.py new file mode 100644 index 0000000000000000000000000000000000000000..60c22d799a9faea580dd336abfd21248e4be2d3a --- /dev/null +++ b/scripts/r_masking/segs.py @@ -0,0 +1,22 @@ +def filter(segs, labels): + labels = set([label.strip() for label in labels]) + + if 'all' in labels: + return (segs, (segs[0], []), ) + else: + res_segs = [] + remained_segs = [] + + for x in segs[1]: + if x.label in labels: + res_segs.append(x) + elif 'eyes' in labels and x.label in ['left_eye', 'right_eye']: + res_segs.append(x) + elif 'eyebrows' in labels and x.label in ['left_eyebrow', 'right_eyebrow']: + res_segs.append(x) + elif 'pupils' in labels and x.label in ['left_pupil', 'right_pupil']: + res_segs.append(x) + else: + remained_segs.append(x) + + return ((segs[0], res_segs), (segs[0], remained_segs), ) diff --git a/scripts/r_masking/subcore.py b/scripts/r_masking/subcore.py new file mode 100644 index 0000000000000000000000000000000000000000..cf7bf7d73b4f84df8dce2401e7e0d0ab78b17066 --- /dev/null +++ b/scripts/r_masking/subcore.py @@ -0,0 +1,117 @@ +import numpy as np +import cv2 +from PIL import Image + +import scripts.r_masking.core as core +from reactor_utils import tensor_to_pil + +try: + from ultralytics import YOLO +except Exception as e: + print(e) + + +def load_yolo(model_path: str): + try: + return YOLO(model_path) + except ModuleNotFoundError: + # https://github.com/ultralytics/ultralytics/issues/3856 + YOLO("yolov8n.pt") + return YOLO(model_path) + +def inference_bbox( + model, + image: Image.Image, + confidence: float = 0.3, + device: str = "", +): + pred = model(image, conf=confidence, device=device) + + bboxes = pred[0].boxes.xyxy.cpu().numpy() + cv2_image = np.array(image) + if len(cv2_image.shape) == 3: + cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing + else: + # Handle the grayscale image here + # For example, you might want to convert it to a 3-channel grayscale image for consistency: + cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_GRAY2BGR) + cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY) + + segms = [] + for x0, y0, x1, y1 in bboxes: + cv2_mask = np.zeros(cv2_gray.shape, np.uint8) + cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1) + cv2_mask_bool = cv2_mask.astype(bool) + segms.append(cv2_mask_bool) + + n, m = bboxes.shape + if n == 0: + return [[], [], [], []] + + results = [[], [], [], []] + for i in range(len(bboxes)): + results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())]) + results[1].append(bboxes[i]) + results[2].append(segms[i]) + results[3].append(pred[0].boxes[i].conf.cpu().numpy()) + + return results + + 
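# Editor's illustrative sketch (not part of the patch). inference_bbox()
# returns four parallel lists -- labels, xyxy boxes, boolean rectangle masks
# and confidences -- which UltraBBoxDetector below turns into SEG items.
# Assumptions: a detector file "face_yolov8m.pt" and an image "photo.jpg"
# exist locally; both names are placeholders.
def _example_inference_bbox():
    from PIL import Image

    model = load_yolo("face_yolov8m.pt")  # any ultralytics-compatible bbox model
    labels, boxes, masks, confs = inference_bbox(model, Image.open("photo.jpg"), confidence=0.5)
    for label, box, conf in zip(labels, boxes, confs):
        # each box is a float32 [x1, y1, x2, y2] array, each conf a 1-element array
        print(label, box.tolist(), float(conf))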
+class UltraBBoxDetector: + bbox_model = None + + def __init__(self, bbox_model): + self.bbox_model = bbox_model + + def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None): + drop_size = max(drop_size, 1) + detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold) + segmasks = core.create_segmasks(detected_results) + + if dilation > 0: + segmasks = core.dilate_masks(segmasks, dilation) + + items = [] + h = image.shape[1] + w = image.shape[2] + + for x, label in zip(segmasks, detected_results[0]): + item_bbox = x[0] + item_mask = x[1] + + y1, x1, y2, x2 = item_bbox + + if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue + crop_region = core.make_crop_region(w, h, item_bbox, crop_factor) + + if detailer_hook is not None: + crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region) + + cropped_image = core.crop_image(image, crop_region) + cropped_mask = core.crop_ndarray2(item_mask, crop_region) + confidence = x[2] + # bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h) + + item = core.SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None) + + items.append(item) + + shape = image.shape[1], image.shape[2] + segs = shape, items + + if detailer_hook is not None and hasattr(detailer_hook, "post_detection"): + segs = detailer_hook.post_detection(segs) + + return segs + + def detect_combined(self, image, threshold, dilation): + detected_results = inference_bbox(self.bbox_model, core.tensor2pil(image), threshold) + segmasks = core.create_segmasks(detected_results) + if dilation > 0: + segmasks = core.dilate_masks(segmasks, dilation) + + return core.combine_masks(segmasks) + + def setAux(self, x): + pass diff --git a/scripts/reactor_faceswap.py b/scripts/reactor_faceswap.py new file mode 100644 index 0000000000000000000000000000000000000000..4766d5f431578a39d824e3d880a3c6955f760b88 --- /dev/null +++ b/scripts/reactor_faceswap.py @@ -0,0 +1,193 @@ +import os, glob + +from PIL import Image + +import modules.scripts as scripts +# from modules.upscaler import Upscaler, UpscalerData +from modules import scripts, scripts_postprocessing +from modules.processing import ( + StableDiffusionProcessing, + StableDiffusionProcessingImg2Img, +) +from modules.shared import state +from scripts.reactor_logger import logger +from scripts.reactor_swapper import ( + swap_face, + swap_face_many, + get_current_faces_model, + analyze_faces, + half_det_size, + providers +) +import folder_paths +import comfy.model_management as model_management + + +def get_models(): + swappers = [ + "insightface", + "reswapper" + ] + models_list = [] + for folder in swappers: + models_folder = folder + "/*" + models_path = os.path.join(folder_paths.models_dir,models_folder) + models = glob.glob(models_path) + models = [x for x in models if x.endswith(".onnx") or x.endswith(".pth")] + models_list.extend(models) + return models_list + + +class FaceSwapScript(scripts.Script): + + def process( + self, + p: StableDiffusionProcessing, + img, + enable, + source_faces_index, + faces_index, + model, + swap_in_source, + swap_in_generated, + gender_source, + gender_target, + face_model, + faces_order, + face_boost_enabled, + face_restore_model, + face_restore_visibility, + codeformer_weight, + interpolation, + ): + self.enable = enable + if self.enable: + + self.source = img + self.swap_in_generated = swap_in_generated + self.gender_source = gender_source + self.gender_target 
= gender_target + self.model = model + self.face_model = face_model + self.faces_order = faces_order + self.face_boost_enabled = face_boost_enabled + self.face_restore_model = face_restore_model + self.face_restore_visibility = face_restore_visibility + self.codeformer_weight = codeformer_weight + self.interpolation = interpolation + self.source_faces_index = [ + int(x) for x in source_faces_index.strip(",").split(",") if x.isnumeric() + ] + self.faces_index = [ + int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() + ] + if len(self.source_faces_index) == 0: + self.source_faces_index = [0] + if len(self.faces_index) == 0: + self.faces_index = [0] + + if self.gender_source is None or self.gender_source == "no": + self.gender_source = 0 + elif self.gender_source == "female": + self.gender_source = 1 + elif self.gender_source == "male": + self.gender_source = 2 + + if self.gender_target is None or self.gender_target == "no": + self.gender_target = 0 + elif self.gender_target == "female": + self.gender_target = 1 + elif self.gender_target == "male": + self.gender_target = 2 + + # if self.source is not None: + if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source: + logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index) + + if len(p.init_images) == 1: + + result = swap_face( + self.source, + p.init_images[0], + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + gender_source=self.gender_source, + gender_target=self.gender_target, + face_model=self.face_model, + faces_order=self.faces_order, + face_boost_enabled=self.face_boost_enabled, + face_restore_model=self.face_restore_model, + face_restore_visibility=self.face_restore_visibility, + codeformer_weight=self.codeformer_weight, + interpolation=self.interpolation, + ) + p.init_images[0] = result + + # for i in range(len(p.init_images)): + # if state.interrupted or model_management.processing_interrupted(): + # logger.status("Interrupted by User") + # break + # if len(p.init_images) > 1: + # logger.status(f"Swap in %s", i) + # result = swap_face( + # self.source, + # p.init_images[i], + # source_faces_index=self.source_faces_index, + # faces_index=self.faces_index, + # model=self.model, + # gender_source=self.gender_source, + # gender_target=self.gender_target, + # face_model=self.face_model, + # ) + # p.init_images[i] = result + + elif len(p.init_images) > 1: + result = swap_face_many( + self.source, + p.init_images, + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + gender_source=self.gender_source, + gender_target=self.gender_target, + face_model=self.face_model, + faces_order=self.faces_order, + face_boost_enabled=self.face_boost_enabled, + face_restore_model=self.face_restore_model, + face_restore_visibility=self.face_restore_visibility, + codeformer_weight=self.codeformer_weight, + interpolation=self.interpolation, + ) + p.init_images = result + + logger.status("--Done!--") + # else: + # logger.error(f"Please provide a source face") + + def postprocess_batch(self, p, *args, **kwargs): + if self.enable: + images = kwargs["images"] + + def postprocess_image(self, p, script_pp: scripts.PostprocessImageArgs, *args): + if self.enable and self.swap_in_generated: + if self.source is not None: + logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index) + image: Image.Image = script_pp.image + result = swap_face( 
+ self.source, + image, + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + upscale_options=self.upscale_options, + gender_source=self.gender_source, + gender_target=self.gender_target, + ) + try: + pp = scripts_postprocessing.PostprocessedImage(result) + pp.info = {} + p.extra_generation_params.update(pp.info) + script_pp.image = pp.image + except: + logger.error(f"Cannot create a result image") diff --git a/scripts/reactor_logger.py b/scripts/reactor_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f64e43370c344f27c6ce2748acb17e2218e6223f --- /dev/null +++ b/scripts/reactor_logger.py @@ -0,0 +1,47 @@ +import logging +import copy +import sys + +from modules import shared +from reactor_utils import addLoggingLevel + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "STATUS": "\033[38;5;173m", # Calm ORANGE + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +# Create a new logger +logger = logging.getLogger("ReActor") +logger.propagate = False + +# Add Custom Level +# logging.addLevelName(logging.INFO, "STATUS") +addLoggingLevel("STATUS", logging.INFO + 5) + +# Add handler if we don't have one. +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s",datefmt="%H:%M:%S") + ) + logger.addHandler(handler) + +# Configure logger +loglevel_string = getattr(shared.cmd_opts, "reactor_loglevel", "INFO") +loglevel = getattr(logging, loglevel_string.upper(), "info") +logger.setLevel(loglevel) diff --git a/scripts/reactor_swapper.py b/scripts/reactor_swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..3405a8e0d0c018b59a5a876a1ddfb7449858e0d9 --- /dev/null +++ b/scripts/reactor_swapper.py @@ -0,0 +1,576 @@ +import os +import shutil +from typing import List, Union + +import cv2 +import numpy as np +from PIL import Image + +import insightface +from insightface.app.common import Face +# try: +# import torch.cuda as cuda +# except: +# cuda = None +import torch + +import folder_paths +import comfy.model_management as model_management +from modules.shared import state + +from scripts.reactor_logger import logger +from reactor_utils import ( + move_path, + get_image_md5hash, +) +from scripts.r_faceboost import swapper, restorer + +import warnings + +np.warnings = warnings +np.warnings.filterwarnings('ignore') + +# PROVIDERS +try: + if torch.cuda.is_available(): + providers = ["CUDAExecutionProvider"] + elif torch.backends.mps.is_available(): + providers = ["CoreMLExecutionProvider"] + elif hasattr(torch,'dml') or hasattr(torch,'privateuseone'): + providers = ["ROCMExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] +except Exception as e: + logger.debug(f"ExecutionProviderError: {e}.\nEP is set to CPU.") + providers = ["CPUExecutionProvider"] +# if cuda is not None: +# if cuda.is_available(): +# providers = ["CUDAExecutionProvider"] +# else: +# providers = ["CPUExecutionProvider"] +# else: +# providers = ["CPUExecutionProvider"] + 
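# Editor's illustrative sketch (not part of the patch): the `providers` list
# chosen above is what getAnalysisModel()/getFaceSwapModel() below hand to
# insightface, and ultimately to onnxruntime. A quick way to compare the
# requested providers with what onnxruntime actually offers on a machine
# (onnxruntime is already required by insightface):
def _debug_execution_providers():  # hypothetical helper, shown for illustration only
    import onnxruntime as ort

    print("Available ONNX Runtime EPs:", ort.get_available_providers())
    print("ReActor will request:", providers)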
+models_path_old = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models") +insightface_path_old = os.path.join(models_path_old, "insightface") +insightface_models_path_old = os.path.join(insightface_path_old, "models") + +models_path = folder_paths.models_dir +insightface_path = os.path.join(models_path, "insightface") +insightface_models_path = os.path.join(insightface_path, "models") +reswapper_path = os.path.join(models_path, "reswapper") + +if os.path.exists(models_path_old): + move_path(insightface_models_path_old, insightface_models_path) + move_path(insightface_path_old, insightface_path) + move_path(models_path_old, models_path) +if os.path.exists(insightface_path) and os.path.exists(insightface_path_old): + shutil.rmtree(insightface_path_old) + shutil.rmtree(models_path_old) + + +FS_MODEL = None +CURRENT_FS_MODEL_PATH = None + +ANALYSIS_MODELS = { + "640": None, + "320": None, +} + +SOURCE_FACES = None +SOURCE_IMAGE_HASH = None +TARGET_FACES = None +TARGET_IMAGE_HASH = None +TARGET_FACES_LIST = [] +TARGET_IMAGE_LIST_HASH = [] + +def unload_model(model): + if model is not None: + # check if model has unload method + # if "unload" in model: + # model.unload() + # if "model_unload" in model: + # model.model_unload() + del model + return None + +def unload_all_models(): + global FS_MODEL, CURRENT_FS_MODEL_PATH + FS_MODEL = unload_model(FS_MODEL) + ANALYSIS_MODELS["320"] = unload_model(ANALYSIS_MODELS["320"]) + ANALYSIS_MODELS["640"] = unload_model(ANALYSIS_MODELS["640"]) + +def get_current_faces_model(): + global SOURCE_FACES + return SOURCE_FACES + +def getAnalysisModel(det_size = (640, 640)): + global ANALYSIS_MODELS + ANALYSIS_MODEL = ANALYSIS_MODELS[str(det_size[0])] + if ANALYSIS_MODEL is None: + ANALYSIS_MODEL = insightface.app.FaceAnalysis( + name="buffalo_l", providers=providers, root=insightface_path + ) + ANALYSIS_MODEL.prepare(ctx_id=0, det_size=det_size) + ANALYSIS_MODELS[str(det_size[0])] = ANALYSIS_MODEL + return ANALYSIS_MODEL + +def getFaceSwapModel(model_path: str): + global FS_MODEL, CURRENT_FS_MODEL_PATH + if FS_MODEL is None or CURRENT_FS_MODEL_PATH is None or CURRENT_FS_MODEL_PATH != model_path: + CURRENT_FS_MODEL_PATH = model_path + FS_MODEL = unload_model(FS_MODEL) + FS_MODEL = insightface.model_zoo.get_model(model_path, providers=providers) + + return FS_MODEL + + +def sort_by_order(face, order: str): + if order == "left-right": + return sorted(face, key=lambda x: x.bbox[0]) + if order == "right-left": + return sorted(face, key=lambda x: x.bbox[0], reverse = True) + if order == "top-bottom": + return sorted(face, key=lambda x: x.bbox[1]) + if order == "bottom-top": + return sorted(face, key=lambda x: x.bbox[1], reverse = True) + if order == "small-large": + return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])) + # if order == "large-small": + # return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True) + # by default "large-small": + return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True) + +def get_face_gender( + face, + face_index, + gender_condition, + operated: str, + order: str, +): + gender = [ + x.sex + for x in face + ] + gender.reverse() + # If index is outside of bounds, return None, avoid exception + if face_index >= len(gender): + logger.status("Requested face index (%s) is out of bounds (max available index is %s)", face_index, len(gender)) + return None, 0 + face_gender = gender[face_index] + logger.status("%s Face %s: 
Detected Gender -%s-", operated, face_index, face_gender) + if (gender_condition == 1 and face_gender == "F") or (gender_condition == 2 and face_gender == "M"): + logger.status("OK - Detected Gender matches Condition") + try: + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 0 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0 + except IndexError: + return None, 0 + else: + logger.status("WRONG - Detected Gender doesn't match Condition") + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 1 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 1 + +def half_det_size(det_size): + logger.status("Trying to halve 'det_size' parameter") + return (det_size[0] // 2, det_size[1] // 2) + +def analyze_faces(img_data: np.ndarray, det_size=(640, 640)): + face_analyser = getAnalysisModel(det_size) + faces = face_analyser.get(img_data) + + # Try halving det_size if no faces are found + if len(faces) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return analyze_faces(img_data, det_size_half) + + return faces + +def get_face_single(img_data: np.ndarray, face, face_index=0, det_size=(640, 640), gender_source=0, gender_target=0, order="large-small"): + + buffalo_path = os.path.join(insightface_models_path, "buffalo_l.zip") + if os.path.exists(buffalo_path): + os.remove(buffalo_path) + + if gender_source != 0: + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + return get_face_gender(face,face_index,gender_source,"Source", order) + + if gender_target != 0: + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + return get_face_gender(face,face_index,gender_target,"Target", order) + + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + + try: + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 0 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0 + except IndexError: + return None, 0 + + +def swap_face( + source_img: Union[Image.Image, None], + target_img: Image.Image, + model: Union[str, None] = None, + source_faces_index: List[int] = [0], + faces_index: List[int] = [0], + gender_source: int = 0, + gender_target: int = 0, + face_model: Union[Face, None] = None, + faces_order: List = ["large-small", "large-small"], + face_boost_enabled: bool = False, + face_restore_model = None, + face_restore_visibility: int = 1, + codeformer_weight: float = 0.5, + interpolation: str = "Bicubic", +): + global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH + result_image = target_img + + if model is not None: + + if isinstance(source_img, str): # source_img is a base64 string + import base64, io + if 'base64,' in source_img: # check if the base64 string has a data URL scheme + # split the base64 string to get the actual base64 encoded image data + base64_data = source_img.split('base64,')[-1] + # decode base64 string to bytes + img_bytes = base64.b64decode(base64_data) + else: 
+ # if no data URL scheme, just decode + img_bytes = base64.b64decode(source_img) + + source_img = Image.open(io.BytesIO(img_bytes)) + + target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) + + if source_img is not None: + + source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) + + source_image_md5hash = get_image_md5hash(source_img) + + if SOURCE_IMAGE_HASH is None: + SOURCE_IMAGE_HASH = source_image_md5hash + source_image_same = False + else: + source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False + if not source_image_same: + SOURCE_IMAGE_HASH = source_image_md5hash + + logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH) + logger.info("Source Image the Same? %s", source_image_same) + + if SOURCE_FACES is None or not source_image_same: + logger.status("Analyzing Source Image...") + source_faces = analyze_faces(source_img) + SOURCE_FACES = source_faces + elif source_image_same: + logger.status("Using Hashed Source Face(s) Model...") + source_faces = SOURCE_FACES + + elif face_model is not None: + + source_faces_index = [0] + logger.status("Using Loaded Source Face Model...") + source_face_model = [face_model] + source_faces = source_face_model + + else: + logger.error("Cannot detect any Source") + + if source_faces is not None: + + target_image_md5hash = get_image_md5hash(target_img) + + if TARGET_IMAGE_HASH is None: + TARGET_IMAGE_HASH = target_image_md5hash + target_image_same = False + else: + target_image_same = True if TARGET_IMAGE_HASH == target_image_md5hash else False + if not target_image_same: + TARGET_IMAGE_HASH = target_image_md5hash + + logger.info("Target Image MD5 Hash = %s", TARGET_IMAGE_HASH) + logger.info("Target Image the Same? %s", target_image_same) + + if TARGET_FACES is None or not target_image_same: + logger.status("Analyzing Target Image...") + target_faces = analyze_faces(target_img) + TARGET_FACES = target_faces + elif target_image_same: + logger.status("Using Hashed Target Face(s) Model...") + target_faces = TARGET_FACES + + # No use in trying to swap faces if no faces are found, enhancement + if len(target_faces) == 0: + logger.status("Cannot detect any Target, skipping swapping...") + return result_image + + if source_img is not None: + # separated management of wrong_gender between source and target, enhancement + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1]) + else: + # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]] + source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]] + src_wrong_gender = 0 + + if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index): + logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.') + elif source_face is not None: + result = target_img + if "inswapper" in model: + model_path = model_path = os.path.join(insightface_path, model) + elif "reswapper" in model: + model_path = model_path = os.path.join(reswapper_path, model) + face_swapper = getFaceSwapModel(model_path) + + source_face_idx = 0 + + for face_num in faces_index: + # No use in trying to swap faces if no further faces are found, enhancement + if face_num >= len(target_faces): + logger.status("Checked all existing target faces, skipping swapping...") + break + + if 
len(source_faces_index) > 1 and source_face_idx > 0: + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1]) + source_face_idx += 1 + + if source_face is not None and src_wrong_gender == 0: + target_face, wrong_gender = get_face_single(target_img, target_faces, face_index=face_num, gender_target=gender_target, order=faces_order[0]) + if target_face is not None and wrong_gender == 0: + logger.status(f"Swapping...") + if face_boost_enabled: + logger.status(f"Face Boost is enabled") + bgr_fake, M = face_swapper.get(result, target_face, source_face, paste_back=False) + bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation) + M *= scale + result = swapper.in_swap(target_img, bgr_fake, M) + else: + # logger.status(f"Swapping as-is") + result = face_swapper.get(result, target_face, source_face) + elif wrong_gender == 1: + wrong_gender = 0 + # Keep searching for other faces if wrong gender is detected, enhancement + #if source_face_idx == len(source_faces_index): + # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + # return result_image + logger.status("Wrong target gender detected") + continue + else: + logger.status(f"No target face found for {face_num}") + elif src_wrong_gender == 1: + src_wrong_gender = 0 + # Keep searching for other faces if wrong gender is detected, enhancement + #if source_face_idx == len(source_faces_index): + # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + # return result_image + logger.status("Wrong source gender detected") + continue + else: + logger.status(f"No source face found for face number {source_face_idx}.") + + result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + + else: + logger.status("No source face(s) in the provided Index") + else: + logger.status("No source face(s) found") + return result_image + +def swap_face_many( + source_img: Union[Image.Image, None], + target_imgs: List[Image.Image], + model: Union[str, None] = None, + source_faces_index: List[int] = [0], + faces_index: List[int] = [0], + gender_source: int = 0, + gender_target: int = 0, + face_model: Union[Face, None] = None, + faces_order: List = ["large-small", "large-small"], + face_boost_enabled: bool = False, + face_restore_model = None, + face_restore_visibility: int = 1, + codeformer_weight: float = 0.5, + interpolation: str = "Bicubic", +): + global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH, TARGET_FACES_LIST, TARGET_IMAGE_LIST_HASH + result_images = target_imgs + + if model is not None: + + if isinstance(source_img, str): # source_img is a base64 string + import base64, io + if 'base64,' in source_img: # check if the base64 string has a data URL scheme + # split the base64 string to get the actual base64 encoded image data + base64_data = source_img.split('base64,')[-1] + # decode base64 string to bytes + img_bytes = base64.b64decode(base64_data) + else: + # if no data URL scheme, just decode + img_bytes = base64.b64decode(source_img) + + source_img = Image.open(io.BytesIO(img_bytes)) + + target_imgs = [cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) for target_img in target_imgs] + + if source_img is not None: + + source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) + + source_image_md5hash = get_image_md5hash(source_img) + + if SOURCE_IMAGE_HASH is None: + SOURCE_IMAGE_HASH = 
source_image_md5hash + source_image_same = False + else: + source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False + if not source_image_same: + SOURCE_IMAGE_HASH = source_image_md5hash + + logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH) + logger.info("Source Image the Same? %s", source_image_same) + + if SOURCE_FACES is None or not source_image_same: + logger.status("Analyzing Source Image...") + source_faces = analyze_faces(source_img) + SOURCE_FACES = source_faces + elif source_image_same: + logger.status("Using Hashed Source Face(s) Model...") + source_faces = SOURCE_FACES + + elif face_model is not None: + + source_faces_index = [0] + logger.status("Using Loaded Source Face Model...") + source_face_model = [face_model] + source_faces = source_face_model + + else: + logger.error("Cannot detect any Source") + + if source_faces is not None: + + target_faces = [] + for i, target_img in enumerate(target_imgs): + if state.interrupted or model_management.processing_interrupted(): + logger.status("Interrupted by User") + break + + target_image_md5hash = get_image_md5hash(target_img) + if len(TARGET_IMAGE_LIST_HASH) == 0: + TARGET_IMAGE_LIST_HASH = [target_image_md5hash] + target_image_same = False + elif len(TARGET_IMAGE_LIST_HASH) == i: + TARGET_IMAGE_LIST_HASH.append(target_image_md5hash) + target_image_same = False + else: + target_image_same = True if TARGET_IMAGE_LIST_HASH[i] == target_image_md5hash else False + if not target_image_same: + TARGET_IMAGE_LIST_HASH[i] = target_image_md5hash + + logger.info("(Image %s) Target Image MD5 Hash = %s", i, TARGET_IMAGE_LIST_HASH[i]) + logger.info("(Image %s) Target Image the Same? %s", i, target_image_same) + + if len(TARGET_FACES_LIST) == 0: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST = [target_face] + elif len(TARGET_FACES_LIST) == i and not target_image_same: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST.append(target_face) + elif len(TARGET_FACES_LIST) != i and not target_image_same: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST[i] = target_face + elif target_image_same: + logger.status("(Image %s) Using Hashed Target Face(s) Model...", i) + target_face = TARGET_FACES_LIST[i] + + + # logger.status(f"Analyzing Target Image {i}...") + # target_face = analyze_faces(target_img) + if target_face is not None: + target_faces.append(target_face) + + # No use in trying to swap faces if no faces are found, enhancement + if len(target_faces) == 0: + logger.status("Cannot detect any Target, skipping swapping...") + return result_images + + if source_img is not None: + # separated management of wrong_gender between source and target, enhancement + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1]) + else: + # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]] + source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]] + src_wrong_gender = 0 + + if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index): + logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.') + elif source_face is not None: 
+ results = target_imgs + model_path = model_path = os.path.join(insightface_path, model) + face_swapper = getFaceSwapModel(model_path) + + source_face_idx = 0 + + for face_num in faces_index: + # No use in trying to swap faces if no further faces are found, enhancement + if face_num >= len(target_faces): + logger.status("Checked all existing target faces, skipping swapping...") + break + + if len(source_faces_index) > 1 and source_face_idx > 0: + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1]) + source_face_idx += 1 + + if source_face is not None and src_wrong_gender == 0: + # Reading results to make current face swap on a previous face result + for i, (target_img, target_face) in enumerate(zip(results, target_faces)): + target_face_single, wrong_gender = get_face_single(target_img, target_face, face_index=face_num, gender_target=gender_target, order=faces_order[0]) + if target_face_single is not None and wrong_gender == 0: + result = target_img + logger.status(f"Swapping {i}...") + if face_boost_enabled: + logger.status(f"Face Boost is enabled") + bgr_fake, M = face_swapper.get(target_img, target_face_single, source_face, paste_back=False) + bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation) + M *= scale + result = swapper.in_swap(target_img, bgr_fake, M) + else: + # logger.status(f"Swapping as-is") + result = face_swapper.get(target_img, target_face_single, source_face) + results[i] = result + elif wrong_gender == 1: + wrong_gender = 0 + logger.status("Wrong target gender detected") + continue + else: + logger.status(f"No target face found for {face_num}") + elif src_wrong_gender == 1: + src_wrong_gender = 0 + logger.status("Wrong source gender detected") + continue + else: + logger.status(f"No source face found for face number {source_face_idx}.") + + result_images = [Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) for result in results] + + else: + logger.status("No source face(s) in the provided Index") + else: + logger.status("No source face(s) found") + return result_images diff --git a/scripts/reactor_version.py b/scripts/reactor_version.py new file mode 100644 index 0000000000000000000000000000000000000000..e7866f7ca938d4751e6c65c66eb64e8ca4e0e8d7 --- /dev/null +++ b/scripts/reactor_version.py @@ -0,0 +1,13 @@ +app_title = "ReActor Node for ComfyUI" +version_flag = "v0.5.2-b1" + +COLORS = { + "CYAN": "\033[0;36m", # CYAN + "ORANGE": "\033[38;5;173m", # Calm ORANGE + "GREEN": "\033[0;32m", # GREEN + "YELLOW": "\033[0;33m", # YELLOW + "RED": "\033[0;91m", # RED + "0": "\033[0m", # RESET COLOR +} + +print(f"{COLORS['YELLOW']}[ReActor]{COLORS['0']} - {COLORS['ORANGE']}STATUS{COLORS['0']} - {COLORS['GREEN']}Running {version_flag} in ComfyUI{COLORS['0']}")
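Editor's illustrative sketch (not part of the patch): swap_face() in scripts/reactor_swapper.py is the core entry point the node wraps. A minimal standalone call might look like the sketch below, assuming it runs inside a ComfyUI install where this extension is active (so folder_paths and comfy are importable) and an inswapper_128.onnx model has already been downloaded into models/insightface; the image file names are placeholders.

from PIL import Image
from scripts.reactor_swapper import swap_face

source = Image.open("source_face.png").convert("RGB")   # placeholder path
target = Image.open("target_photo.png").convert("RGB")  # placeholder path

result = swap_face(
    source,
    target,
    model="inswapper_128.onnx",                  # resolved against models/insightface
    source_faces_index=[0],                      # face to take from the source image
    faces_index=[0],                             # face to replace in the target image
    faces_order=["large-small", "large-small"],  # [target order, source order]
)
result.save("swapped.png")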