import os

import facer
import torch
from PIL import Image
from torch.nn import functional as F
from torchvision import transforms

from eval.tools.face_utils.face import tight_warp_face
from eval.tools.face_utils.face_recg import Backbone
from src.utils.data_utils import split_grid


def expand_bounding_box(x_min, y_min, x_max, y_max, factor=1.3):
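    """Symmetrically scale a bounding box about its center by `factor`.

    Returns the expanded (x_min, y_min, x_max, y_max) as floats; the result
    is not clamped to the image bounds.
    """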
    x_center = (x_min + x_max) / 2
    y_center = (y_min + y_max) / 2

    width = x_max - x_min
    height = y_max - y_min

    new_width = factor * width
    new_height = factor * height

    x_min_new = x_center - new_width / 2
    x_max_new = x_center + new_width / 2
    y_min_new = y_center - new_height / 2
    y_max_new = y_center + new_height / 2

    return x_min_new, y_min_new, x_max_new, y_max_new


class FaceID:
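    """Face identity similarity scorer.

    Uses facer's RetinaFace detector to locate faces and an IR-SE-50
    recognition backbone (weights taken from the FACE_ID_MODEL_PATH
    environment variable) to embed them.
    """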

    def __init__(self, device):
        self.device = torch.device(device)
        self.T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.detector = facer.face_detector("retinaface/resnet50", device=device)

        face_model_path = os.getenv("FACE_ID_MODEL_PATH")
        if face_model_path is None:
            raise ValueError("FACE_ID_MODEL_PATH environment variable is not set")
        self.model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.model.load_state_dict(torch.load(face_model_path, map_location=device))
        self.model.to(device)
        self.model.eval()

    def detect(self, image, expand_scale=1.3):
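        """Detect faces in a uint8 image tensor of shape (B, C, H, W) and
        return their bounding boxes, each expanded by `expand_scale`."""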
        with torch.no_grad():
            faces = self.detector(image.to(self.device))
        bboxes = faces['rects'].detach().cpu().tolist()
        bboxes = [expand_bounding_box(*x, factor=expand_scale) for x in bboxes]
        return bboxes

    def __call__(self, image_x, image_y, normalize=False):
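        """Score the identity similarity of the faces in image_x and image_y.

        Returns the cosine similarity of the two face embeddings scaled by
        100, or 0 if a face cannot be detected and warped in either image.
        Note that cosine similarity is scale-invariant, so `normalize` does
        not change the returned value.
        """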
        try:
            warp_x = tight_warp_face(image_x, self.detector)['cropped_face_masked']
            warp_y = tight_warp_face(image_y, self.detector)['cropped_face_masked']
        except Exception:
            # Detection/warping failed on at least one image.
            return 0

        if warp_x is None or warp_y is None:
            return 0

        with torch.no_grad():
            feature_x = self.model(self.T(warp_x).unsqueeze(0).to(self.device))[0]
            feature_y = self.model(self.T(warp_y).unsqueeze(0).to(self.device))[0]

        if normalize:
            feature_x = feature_x / feature_x.norm(p=2, dim=-1, keepdim=True)
            feature_y = feature_y / feature_y.norm(p=2, dim=-1, keepdim=True)

        return F.cosine_similarity(feature_x, feature_y, dim=0).item() * 100


if __name__ == "__main__":
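    # Minimal smoke test: detect faces in a generated image, crop each one,
    # and score it against a reference photo.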
    from src.train.data.data_utils import pil2tensor

    faceid = FaceID("cuda")
    real_image_path = "assets/bengio_bengio.png"

    gen_image_path = "data/tmp/GCG/florence-sam_phrase-grounding_S3L-two-20k_v41_wds/00000/qwen10M_00000-00010_00000_000000008_vis.png"
    if "eval/ckpt" in gen_image_path:
        # Checkpoint outputs are saved as image grids; split them into
        # individual samples.
        gen_images = split_grid(Image.open(gen_image_path))
    else:
        gen_images = [Image.open(gen_image_path)]

    for i, gen_img in enumerate(gen_images):
        # facer expects a uint8 tensor in [0, 255] of shape (B, C, H, W).
        img_tensor = (pil2tensor(gen_img).unsqueeze(0) * 255).to(torch.uint8)
        bboxes = faceid.detect(img_tensor)
        for j, bbox in enumerate(bboxes):
            face_img = gen_img.crop(bbox)
            # Index by both grid cell and face so crops are not overwritten.
            face_img.save(f"tmp_{i}_{j}.png")
            print(faceid(real_image_path, face_img))