# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
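
"""Face-identity similarity metric.

Faces are detected with facer's RetinaFace detector, cropped/aligned with
tight_warp_face, embedded with an ir_se-50 recognition backbone (weights
loaded from the FACE_ID_MODEL_PATH environment variable), and compared by
cosine similarity scaled by 100.
"""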

import os

import torch
import facer
from PIL import Image
from torch.nn import functional as F
from torchvision import transforms

from eval.tools.face_utils.face import tight_warp_face
from eval.tools.face_utils.face_recg import Backbone
from src.utils.data_utils import pad_to_square, pad_to_target, json_dump, json_load, split_grid


def expand_bounding_box(x_min, y_min, x_max, y_max, factor=1.3):
    """Expand a bounding box around its center by `factor`."""
    # Calculate the center of the bounding box
    x_center = (x_min + x_max) / 2
    y_center = (y_min + y_max) / 2

    # Calculate the width and height of the bounding box
    width = x_max - x_min
    height = y_max - y_min

    # Calculate the new width and height
    new_width = factor * width
    new_height = factor * height

    # Calculate the new bounding box coordinates
    x_min_new = x_center - new_width / 2
    x_max_new = x_center + new_width / 2
    y_min_new = y_center - new_height / 2
    y_max_new = y_center + new_height / 2
    return x_min_new, y_min_new, x_max_new, y_max_new
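
# Worked example (illustrative): expanding the box (10, 20, 30, 40) by the
# default factor of 1.3 grows its 20x20 extent to 26x26 around the center
# (20, 30):
#
#     expand_bounding_box(10, 20, 30, 40, factor=1.3)
#     # -> (7.0, 17.0, 33.0, 43.0)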


class FaceID:
    def __init__(self, device):
        self.device = torch.device(device)
        self.T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.detector = facer.face_detector("retinaface/resnet50", device=device)

        face_model_path = os.getenv("FACE_ID_MODEL_PATH")
        print(f"Loading face recognition weights from {face_model_path}")
        self.model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.model.load_state_dict(torch.load(face_model_path, map_location=device))
        self.model.to(device)
        self.model.eval()

    def detect(self, image, expand_scale=1.3):
        """Return detected face boxes (x_min, y_min, x_max, y_max), each expanded by `expand_scale`."""
        with torch.no_grad():
            faces = self.detector(image.to(self.device))
        bboxes = faces['rects'].detach().cpu().tolist()
        bboxes = [expand_bounding_box(*x, expand_scale) for x in bboxes]
        return bboxes

    def __call__(self, image_x, image_y, normalize=False):
        """Return 100 * cosine similarity between the face embeddings of the two images (0 if no face is found)."""
        # NOTE: only supports one face per image
        try:
            warp_x = tight_warp_face(image_x, self.detector)['cropped_face_masked']
            warp_y = tight_warp_face(image_y, self.detector)['cropped_face_masked']
        except Exception:
            # print("[Warning] No face detected!!")
            return 0
        if warp_x is None or warp_y is None:
            # print("[Warning] No face detected!!")
            return 0

        feature_x = self.model(self.T(warp_x).unsqueeze(0).to(self.device))[0]  # [512]
        feature_y = self.model(self.T(warp_y).unsqueeze(0).to(self.device))[0]  # [512]
        if normalize:
            feature_x = feature_x / feature_x.norm(p=2, dim=-1, keepdim=True)
            feature_y = feature_y / feature_y.norm(p=2, dim=-1, keepdim=True)
        return F.cosine_similarity(feature_x, feature_y, dim=0).item() * 100
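

# Minimal usage sketch (illustrative; the image paths are placeholders and
# FACE_ID_MODEL_PATH must point at a valid recognition checkpoint). As in the
# __main__ demo below, the reference may be given as a path and the candidate
# as a PIL image:
#
#     faceid = FaceID("cuda")
#     score = faceid("reference.png", Image.open("generated.png"))
#     print(f"identity similarity score: {score:.1f}")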


if __name__ == "__main__":
    # online demo: https://dun.163.com/trial/face/compare
    from src.train.data.data_utils import pil2tensor
    import numpy as np

    faceid = FaceID("cuda")
    real_image_path = "assets/bengio_bengio.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/1_A man is wearing green headphones standi.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/7_A man is wearing green headphones standi.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/198_A man wearing a black white suit and a w.png"
    gen_image_path = "data/tmp/GCG/florence-sam_phrase-grounding_S3L-two-20k_v41_wds/00000/qwen10M_00000-00010_00000_000000008_vis.png"

    if "eval/ckpt" in gen_image_path:
        gen_images = split_grid(Image.open(gen_image_path))
    else:
        gen_images = [Image.open(gen_image_path)]

    for i, gen_img in enumerate(gen_images):
        # img_tensor = torch.from_numpy(np.array(gen_img)).unsqueeze(0).permute(0, 3, 1, 2)
        img_tensor = (pil2tensor(gen_img).unsqueeze(0) * 255).to(torch.uint8)
        bboxes = faceid.detect(img_tensor)
        for j, bbox in enumerate(bboxes):
            face_img = gen_img.crop(bbox)
            face_img.save(f"tmp_{j}.png")
            print(faceid(real_image_path, face_img))
        # print(faceid.detect(img_tensor))
        # break
    # print(faceid(real_image_path, gen_img))