import argparse
import os
from copy import deepcopy

import cv2
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

import facer
import facer.transform

def resize_image(image, max_size=1024):
    """Downscale an HWC image so its longer side is at most ``max_size``, preserving aspect ratio."""
    height, width, _ = image.shape
    if width > max_size or height > max_size:
        if width > height:
            new_width = max_size
            new_height = int((height / width) * max_size)
        else:
            new_height = max_size
            new_width = int((width / height) * max_size)
        image = cv2.resize(image, (new_width, new_height))
    return image
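
# Example (a sketch; 'photo.jpg' is a hypothetical path):
#   img = cv2.imread('photo.jpg')      # HWC uint8 array
#   img = resize_image(img)            # longer side capped at 1024 px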


def open_and_resize_image(image_file, max_size=1024, return_type='numpy'):
    """Open ``image_file`` (path, PIL image, or HWC ndarray) and resize it so its
    longer side equals ``max_size`` (smaller images are upscaled), preserving aspect ratio.
    """
    if isinstance(image_file, (str, PIL.Image.Image)):
        img = Image.open(image_file) if isinstance(image_file, str) else image_file
        width, height = img.size
        if width > height:
            new_width = max_size
            new_height = int((height / width) * max_size)
        else:
            new_height = max_size
            new_width = int((width / height) * max_size)
        img = img.resize((new_width, new_height))
        if return_type == 'numpy':
            return np.array(img.convert('RGB'))
        return img
    elif isinstance(image_file, np.ndarray):
        height, width, _ = image_file.shape
        if width > height:
            new_width = max_size
            new_height = int((height / width) * max_size)
        else:
            new_height = max_size
            new_width = int((width / height) * max_size)
        assert return_type == 'numpy', 'ndarray inputs can only be returned as numpy arrays'
        return cv2.resize(image_file, (new_width, new_height))
    else:
        raise TypeError('Unsupported image type: %s' % type(image_file))
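
# Example usage (a sketch; 'portrait.jpg' is a hypothetical path). Every input
# kind comes back with its longer side equal to ``max_size``:
#   arr = open_and_resize_image('portrait.jpg')                     # HWC RGB ndarray
#   pil = open_and_resize_image('portrait.jpg', return_type='pil')  # PIL.Image.Image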


@torch.no_grad()
def loose_warp_face(input_image, face_detector, face_target_shape=(512, 512), scale=1.3, face_parser=None, device=None, croped_face_scale=3, bg_value=0, croped_face_y_offset=0.0):
    """Get the tight/loose warp of the face in the image; only the first detected face is used.

    Args:
        input_image: Image path, PIL.Image.Image, or np.ndarray (dtype=np.uint8;
            assumed BGR channel order, as produced by cv2).
        face_detector: A facer face detector; if None, the input is returned unwarped.
        face_target_shape: Output resolution of the aligned face.
        scale: Scale of the output image w.r.t. the face it contains (larger means a looser crop).
        face_parser: Optional facer face parser used to build a background mask.
        device: Torch device; overridden by the detector's device when a detector is given.
        croped_face_scale: Side of the square context crop as a multiple of the detected face box.
        bg_value: Fill value for background pixels of the masked face.
        croped_face_y_offset: Vertical offset of the context crop center, as a fraction of face height.

    Returns:
        Dict with keys 'cropped_face_masked', 'cropped_face', 'cropped_img', 'align_face'
        (PIL images) and 'cropped_face_mask' (tensor); entries are None when a stage is
        skipped or no face is detected.
    """
    # Canonical five-point landmark template (eye centers, nose tip, mouth
    # corners) on a 112x112 crop, as used for ArcFace-style alignment,
    # normalized to [0, 1].
    _normalized_face_target_pts = torch.tensor([
        [38.2946, 51.6963],
        [73.5318, 51.5014],
        [56.0252, 71.7366],
        [41.5493, 92.3655],
        [70.729904, 92.2041]]) / 112.0

    # Shrink the template towards the crop center: with scale > 1 the face
    # occupies a smaller fraction of the output, i.e. a looser crop.
    target_pts = ((_normalized_face_target_pts -
                   torch.tensor([0.5, 0.5])) / scale
                  + torch.tensor([0.5, 0.5]))
    if face_detector is not None:
        device = next(face_detector.parameters()).device

    if isinstance(input_image, str):
        np_img = open_and_resize_image(input_image)[:, :, :3]
        img_height, img_width = np_img.shape[:2]
        image_tensor_hwc = torch.from_numpy(np_img)
    elif isinstance(input_image, Image.Image):
        image_tensor_hwc = torch.from_numpy(np.array(input_image)[:, :, :3])
        img_height, img_width = image_tensor_hwc.shape[:2]
        assert image_tensor_hwc.dtype == torch.uint8
    else:
        assert isinstance(input_image, np.ndarray), 'Type %s of input_image is unsupported!' % type(input_image)
        assert input_image.dtype == np.uint8, 'dtype %s of input np.ndarray is unsupported!' % input_image.dtype
        # ndarray inputs are assumed to follow the cv2 (BGR) convention; swap
        # channels to RGB to match the path/PIL branches above.
        input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)[:, :, :3]
        input_image = resize_image(input_image)
        image_tensor_hwc = torch.from_numpy(input_image)
        img_height, img_width = image_tensor_hwc.shape[:2]

    image_pt_bchw_255 = facer.hwc2bchw(image_tensor_hwc).to(device)

    res = {'cropped_face_masked': None, 'cropped_face': None, 'cropped_img': None,
           'cropped_face_mask': None, 'align_face': None}

    if face_detector is not None:
        face_data = face_detector(image_pt_bchw_255)

        # No face found: return the dict of Nones.
        if len(face_data) == 0:
            return res

        if face_parser is not None:
            with torch.inference_mode():
                faces = face_parser(image_pt_bchw_255, face_data)
                seg_logits = faces['seg']['logits']
                seg_probs = seg_logits.softmax(dim=1)
                # Per-pixel class indices of the first face, kept as a
                # 1x1xHxW map (label 0 is background).
                seg_probs = seg_probs.argmax(dim=1).unsqueeze(1)[:1]

        # Use only the first detected face.
        face_rects = face_data['rects'][:1]
        x1, y1, x2, y2 = face_rects[0][:4]
        x1 = int(x1.item())
        y1 = int(y1.item())
        x2 = int(x2.item())
        y2 = int(y2.item())
        face_width = x2 - x1
        face_height = y2 - y1
        center_x = int(0.5 * (x1 + x2))
        center_y = int(0.5 * (y1 + y2)) + croped_face_y_offset * face_height
        croped_face_width = face_width * croped_face_scale
        croped_face_height = face_height * croped_face_scale

        # Clamp the enlarged context box to the image bounds, then shrink it
        # to the largest centered square that still fits.
        x1 = max(int(center_x - 0.5 * croped_face_width), 0)
        x2 = min(int(center_x + 0.5 * croped_face_width), img_width - 1)
        y1 = max(int(center_y - 0.5 * croped_face_height), 0)
        y2 = min(int(center_y + 0.5 * croped_face_height), img_height - 1)
        croped_face_height = y2 - y1
        croped_face_width = x2 - x1
        center_x = int(0.5 * (x1 + x2))
        center_y = int(0.5 * (y1 + y2))
        croped_face_len = min(croped_face_height, croped_face_width)
        x1 = int(center_x - 0.5 * croped_face_len)
        y1 = int(center_y - 0.5 * croped_face_len)
        x2 = x1 + croped_face_len
        y2 = y1 + croped_face_len
        croped_image_pt_bchw_255 = image_pt_bchw_255[:, :, y1:y2, x1:x2]
        face_points = face_data['points'][:1]
        batch_inds = face_data['image_ids'][:1]

        # Similarity transform mapping the detected landmarks onto the scaled
        # template, then a sampling grid over the full image.
        matrix_align = facer.transform.get_face_align_matrix(
            face_points, face_target_shape,
            target_pts=(target_pts * torch.tensor(face_target_shape)))

        grid = facer.transform.make_tanh_warp_grid(
            matrix_align, 0.0, face_target_shape, image_pt_bchw_255.shape[2:])
        image = F.grid_sample(
            image_pt_bchw_255.float()[batch_inds],
            grid, 'bilinear', align_corners=False)
        image_align_raw = image.clone()
        image_align_raw = facer.bchw2hwc(image_align_raw).to(torch.uint8).cpu().numpy()
        image_align_raw = Image.fromarray(image_align_raw)
        image_croped = facer.bchw2hwc(croped_image_pt_bchw_255).to(torch.uint8).cpu().numpy()
        image_croped = Image.fromarray(image_croped)
        if face_parser is not None:
            image_no_mask = image.clone()
            # Broadcast the 1-channel label map to the image's channel count so
            # it can be warped with the same grid and used as a boolean mask.
            new_size = list(seg_probs.shape)
            new_size[1] = image.shape[1]
            seg_probs = seg_probs.expand(new_size)
            assert seg_probs.shape[0] == 1 and image.shape[0] == 1, 'mask shape {} != image shape {}'.format(seg_probs.shape, image.shape)
            mask_img = F.grid_sample(seg_probs.float(), grid, 'bilinear', align_corners=False)
            # Label 0 is background: fill it with bg_value and binarize the mask.
            image[mask_img == 0] = bg_value
            mask_img[mask_img != 0] = 1
            assert mask_img.shape[0] == 1
        else:
            image_no_mask = image
            mask_img = None
    else:
        # No detector: pass the (resized) input through unwarped; there is no
        # aligned face, context crop, or mask in this case.
        image = image_pt_bchw_255
        image_no_mask = image_pt_bchw_255
        image_align_raw = None
        image_croped = None
        mask_img = None  # was previously unset here, causing a NameError at the return below

    image = facer.bchw2hwc(image).to(torch.uint8).cpu().numpy()
    image_no_mask = facer.bchw2hwc(image_no_mask).to(torch.uint8).cpu().numpy()

    res.update({'cropped_face_masked': Image.fromarray(image),
                'cropped_face': Image.fromarray(image_no_mask),
                'cropped_img': image_croped,
                'cropped_face_mask': mask_img,
                'align_face': image_align_raw})
    return res


def tight_warp_face(input_image, face_detector, face_parser=None, device=None):
    """Convenience wrapper: the tight 112x112 ArcFace-style crop (scale=1)."""
    return loose_warp_face(input_image, face_detector,
                           face_target_shape=(112, 112), scale=1,
                           face_parser=face_parser, device=device)
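

# A minimal end-to-end usage sketch. The facer model names
# ('retinaface/mobilenet', 'farl/lapa/448') and the image path 'portrait.jpg'
# are illustrative assumptions, not requirements of loose_warp_face;
# substitute your own detector, parser, and input.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    detector = facer.face_detector('retinaface/mobilenet', device=device)
    parser = facer.face_parser('farl/lapa/448', device=device)

    out = loose_warp_face('portrait.jpg', detector, face_parser=parser)
    if out['cropped_face'] is None:
        print('No face detected.')
    else:
        out['align_face'].save('align_face.png')
        out['cropped_face'].save('cropped_face.png')
        out['cropped_face_masked'].save('cropped_face_masked.png')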