File size: 5,197 Bytes
4479f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2a06a2
4479f79
 
bbe8523
4479f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import facer
from PIL import Image
from torchvision import transforms
from eval.tools.face_utils.face import tight_warp_face
from eval.tools.face_utils.face_recg import Backbone
import os
from torch.nn import functional as F
from src.utils.data_utils import pad_to_square, pad_to_target, json_dump, json_load, split_grid


def expand_bounding_box(x_min, y_min, x_max, y_max, factor=1.3):
    # Calculate the center of the bounding box
    x_center = (x_min + x_max) / 2
    y_center = (y_min + y_max) / 2

    # Calculate the width and height of the bounding box
    width = x_max - x_min
    height = y_max - y_min

    # Calculate the new width and height
    new_width = factor * width
    new_height = factor * height

    # Calculate the new bounding box coordinates
    x_min_new = x_center - new_width / 2
    x_max_new = x_center + new_width / 2
    y_min_new = y_center - new_height / 2
    y_max_new = y_center + new_height / 2

    return x_min_new, y_min_new, x_max_new, y_max_new

class FaceID:
    def __init__(self, device):
        self.device = torch.device(device)
        self.T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self.detector = facer.face_detector("retinaface/resnet50", device=device)
        face_model_path = os.getenv("FACE_ID_MODEL_PATH")
        print(f'this is the way {face_model_path}')
        self.model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.model.load_state_dict(torch.load(face_model_path, map_location=device))
        self.model.to(device)
        self.model.eval()

    def detect(self, image, expand_scale=1.3):
        with torch.no_grad():
            faces = self.detector(image.to(self.device))
        bboxes = faces['rects'].detach().cpu().tolist()
        bboxes = [expand_bounding_box(*x, expand_scale) for x in bboxes]
        return bboxes

    def __call__(self, image_x, image_y, normalize=False):
        # NOTE: Only Support One Face Per Image
        
        try:
            warp_x = tight_warp_face(image_x, self.detector)['cropped_face_masked']
            warp_y = tight_warp_face(image_y, self.detector)['cropped_face_masked']
        except:
            # print("[Warning] No face detected!!")
            return 0

        if warp_x is None or warp_y is None:
            # print("[Warning] No face detected!!")
            return 0
        
        feature_x = self.model(self.T(warp_x).unsqueeze(0).to(self.device))[0] # [512]
        feature_y = self.model(self.T(warp_y).unsqueeze(0).to(self.device))[0] # [512]
        
        
        if normalize:
            feature_x = feature_x / feature_x.norm(p=2, dim=-1, keepdim=True)
            feature_y = feature_y / feature_y.norm(p=2, dim=-1, keepdim=True)

        return F.cosine_similarity(feature_x, feature_y, dim=0).item() * 100
    


if __name__ == "__main__":
    # online demo: https://dun.163.com/trial/face/compare
    from src.train.data.data_utils import pil2tensor
    import numpy as np


    faceid = FaceID("cuda")
    real_image_path = "assets/bengio_bengio.png"

    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/1_A man is wearing green headphones standi.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/7_A man is wearing green headphones standi.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/198_A man wearing a black white suit and a w.png"
    gen_image_path = "data/tmp/GCG/florence-sam_phrase-grounding_S3L-two-20k_v41_wds/00000/qwen10M_00000-00010_00000_000000008_vis.png"
    if "eval/ckpt" in gen_image_path:
        gen_images = split_grid(Image.open(gen_image_path))
    else:
        gen_images = [Image.open(gen_image_path)]

    for i, gen_img in enumerate(gen_images):
        # img_tensor = torch.from_numpy(np.array(gen_img)).unsqueeze(0).permute(0, 3, 1, 2)
        img_tensor = (pil2tensor(gen_img).unsqueeze(0) * 255).to(torch.uint8)
        bboxes = faceid.detect(img_tensor)
        for j, bbox in enumerate(bboxes):
            face_img = gen_img.crop(bbox)
            face_img.save(f"tmp_{j}.png")
            print(faceid(real_image_path, face_img))
        # print(faceid.detect(img_tensor))
        # break
        # print(faceid(real_image_path, gen_img))