|
from numpy.core.numeric import require |
|
from numpy.lib.function_base import quantile |
|
import torch |
|
import torch.nn.functional as F |
|
import copy |
|
import numpy as np |
|
|
|
import os |
|
import sys |
|
import cv2 |
|
import argparse |
|
import tqdm |
|
from utils.commons.multiprocess_utils import multiprocess_run_tqdm |
|
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker |
|
|
|
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel |
|
import pickle |
|
|
|
# Module-level parametric (BFM) face model read by fit_3dmm_for_a_image.
# Constructed and moved to CUDA at import time; the __main__ section below
# rebinds it with the command-line keypoint mode and device.
face_model = ParametricFaceModel(
    bfm_folder='deep_3drecon/BFM',
    camera_distance=10,
    focal=1015,
    keypoint_mode='mediapipe',
)
face_model.to("cuda")

# 68 indices into the 468-point MediaPipe mesh, one per point of the 68-point
# landmark layout (per the name). Rows follow the 68-point index ranges.
index_lm68_from_lm468 = (
    [127, 234, 93, 132, 58, 136, 150, 176, 152, 400, 379, 365, 288, 361, 323, 454, 356]  # 0-16
    + [70, 63, 105, 66, 107, 336, 296, 334, 293, 300]  # 17-26
    + [168, 197, 5, 4, 75, 97, 2, 326, 305]  # 27-35
    + [33, 160, 158, 133, 153, 144, 362, 385, 387, 263, 373, 380]  # 36-47
    + [61, 40, 37, 0, 267, 270, 291, 321, 314, 17, 84, 91,  # 48-59
       78, 81, 13, 311, 308, 402, 14, 178]  # 60-67
)

# Absolute directory containing this script.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Weights of the L2 regularizers on identity / expression coefficients.
LAMBDA_REG_ID = 0.3
LAMBDA_REG_EXP = 0.05
|
|
|
def save_file(name, content):
    """Pickle ``content`` into the file at path ``name``, overwriting it."""
    with open(name, "wb") as out_file:
        pickle.Pickler(out_file).dump(content)
|
|
|
def load_file(name):
    """Unpickle and return the object stored in the file at path ``name``.

    NOTE: unpickling can execute arbitrary code -- only use on trusted files.
    """
    with open(name, "rb") as in_file:
        return pickle.Unpickler(in_file).load()
|
|
|
def cal_lan_loss_mp(proj_lan, gt_lan):
    """Weighted mean squared landmark error for the MediaPipe 468-point mesh.

    Eye contours get weight 5, inner and outer lip contours weight 2, and the
    indices in ``unmatch_mask`` weight 0 (excluded -- presumably points whose
    model/MediaPipe correspondence is unreliable). Every other point keeps
    weight 1.

    Args:
        proj_lan: projected landmarks; dim 1 is the landmark axis
            (presumably (B, 468, 2) -- confirm with callers).
        gt_lan: target landmarks, broadcastable against ``proj_lan``.

    Returns:
        Scalar tensor: mean over all elements of weight * squared error.
    """
    sq_err = (proj_lan - gt_lan).pow(2)

    unmatch_mask = [93, 127, 132, 234, 323, 356, 361, 454]
    left_eye = [33, 246, 161, 160, 159, 158, 157, 173, 133, 155, 154, 153, 145, 144, 163, 7]
    right_eye = [263, 466, 388, 387, 386, 385, 384, 398, 362, 382, 381, 380, 374, 373, 390, 249]
    inner_lip = [78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
    outer_lip = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291, 375, 321, 405, 314, 17, 84, 181, 91, 146]

    # Applied in order, so a later entry wins for any shared index
    # (the excluded points are written last).
    weight_table = (
        (left_eye + right_eye, 5),
        (inner_lip, 2),
        (outer_lip, 2),
        (unmatch_mask, 0),
    )
    weights = torch.ones_like(sq_err)
    for indices, weight in weight_table:
        weights[:, indices] = weight

    return (sq_err * weights).mean()
|
|
|
def cal_lan_loss(proj_lan, gt_lan):
    """Weighted mean squared landmark error for a 68-point landmark layout.

    Points 36-47, 28-30 and the last 8 points get weight 3; all others weight 1.

    Args:
        proj_lan: projected landmarks, indexed as ``[:, i, :]`` so rank 3 with
            the landmark axis at dim 1 (presumably (B, 68, 2)).
        gt_lan: target landmarks, broadcastable against ``proj_lan``.

    Returns:
        Scalar tensor: mean over all elements of weight * squared error.
    """
    loss = (proj_lan - gt_lan) ** 2

    # (A redundant torch.zeros_like allocation that was immediately
    # overwritten has been removed.)
    weights = torch.ones_like(loss)
    weights[:, 36:48, :] = 3
    weights[:, -8:, :] = 3
    weights[:, 28:31, :] = 3

    return torch.mean(loss * weights)
|
|
|
def set_requires_grad(tensor_list):
    """Set ``requires_grad = True`` on every tensor yielded by ``tensor_list``."""
    for t in tensor_list:
        t.requires_grad = True
|
|
|
def read_video_to_frames(img_name):
    """Decode every frame of the video at ``img_name`` into one RGB array.

    Args:
        img_name: path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        ``np.ndarray`` of the decoded frames stacked along a new axis 0,
        converted from OpenCV's BGR order to RGB.

    Raises:
        ValueError: if no frame could be decoded (missing/unreadable file).
    """
    frames = []
    cap = cv2.VideoCapture(img_name)
    try:
        while cap.isOpened():
            ret, frame_bgr = cap.read()
            # End of stream (or decode failure): stop reading.
            if not ret or frame_bgr is None:
                break
            frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    finally:
        # Always release the decoder handle, even if conversion raises.
        cap.release()
    if not frames:
        raise ValueError(f"no frames could be read from {img_name}")
    return np.stack(frames)
|
|
|
def _save_fit_debug_image(img, img_name, proj_lms, gt_lms, img_scale_factor, img_h):
    """Overlay fitted and target landmarks on ``img`` and write it to disk.

    ``proj_lms`` / ``gt_lms`` are numpy arrays reshapeable to (N, 2) in the
    fitting frame used by ``fit_3dmm_for_a_image`` (divided by
    ``img_scale_factor``, y flipped as ``img_h - y``). Both transforms are
    undone here before drawing. Projected points are drawn as filled dots of
    radius 1, targets as hollow circles of radius 2.
    """
    import imageio

    debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
    debug_dir = os.path.dirname(debug_name)
    if debug_dir:  # os.makedirs("") would raise for a bare file name
        os.makedirs(debug_dir, exist_ok=True)

    def to_pixels(pts):
        pts = np.asarray(pts).reshape(-1, 2) * img_scale_factor
        pts[:, 1] = img_h - pts[:, 1]
        return pts.astype(int)

    # cv2.circle needs a tuple of Python ints as the center point.
    for x, y in to_pixels(proj_lms):
        img = cv2.circle(img, (int(x), int(y)), 1, (0, 0, 255), thickness=-1)
    for x, y in to_pixels(gt_lms):
        img = cv2.circle(img, (int(x), int(y)), 2, (255, 0, 0), thickness=1)
    imageio.imwrite(debug_name, img)
    print(f"debug img saved at {debug_name}")


@torch.enable_grad()
def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
    """Fit 3DMM identity/expression/pose coefficients to one image's 2D landmarks.

    Target landmarks are loaded from a cached ``*_lms.npy`` file (the image
    path with ``/images_512/`` -> ``/lms_2d/``) when present, otherwise they
    are extracted with ``MediapipeLandmarker``. The module-level
    ``face_model`` is then fitted in three Adam stages:

    1. pose only (euler angles + translation), 200 steps;
    2. pose + identity + expression with L2 regularization, 200 steps;
    3. a per-batch refinement of all parameters, 50 steps at lr 0.005.

    Args:
        img_name: path to a square image.
        debug: also write an image with fitted/target landmarks overlaid.
        keypoint_mode: ``'mediapipe'`` uses the 468-point loss; any other
            value uses the 68-point loss on targets remapped through
            ``index_lm68_from_lm468``. The global ``face_model`` must have
            been built with a matching keypoint mode.
        device: torch device for the landmarks and parameters.
        save: ``np.save`` the coefficient dict next to the image.

    Returns:
        dict of numpy arrays ``'id'`` (1, 80), ``'exp'`` (1, 64),
        ``'euler'`` (1, 3), ``'trans'`` (1, 3); or ``None`` if no landmarks
        could be obtained.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_name``.
        AssertionError: if the image is not square.
    """
    img = cv2.imread(img_name)
    if img is None:
        raise FileNotFoundError(f"cv2 could not read image: {img_name}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Bug fix: width previously also read shape[0], making the check a no-op.
    img_h, img_w = img.shape[:2]
    # The landmark rescaling below only uses img_h, so the image must be square.
    assert img_h == img_w, f"expected a square image, got {img_h}x{img_w}: {img_name}"
    num_frames = 1

    # ---- target 2D landmarks: cached .npy if present, else run MediaPipe ----
    lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
    if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
        lms = np.load(lm_name)
    else:
        try:
            landmarker = MediapipeLandmarker()
            lms = landmarker.extract_lm478_from_img_name(img_name)
        except Exception as e:
            print(e)
            return
    if lms is None:
        print("get None lms_2d, please check whether each frame has one head, exiting...")
        return
    lms = lms[:468].reshape([468, 2])
    if keypoint_mode != 'mediapipe':
        # The 68-point loss/model needs 68 targets; comparing all 468 MediaPipe
        # points against 68 projected points cannot broadcast.
        lms = lms[index_lm68_from_lm468]
    lms = torch.FloatTensor(lms).to(device=device)
    # Flip y (image rows grow downward); presumably matches the y-up frame of
    # face_model.compute_for_landmark_fit -- confirm in the face model.
    lms[..., 1] = img_h - lms[..., 1]

    if keypoint_mode == 'mediapipe':
        cal_lan_loss_fn = cal_lan_loss_mp
        out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
    else:
        cal_lan_loss_fn = cal_lan_loss
        out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
    out_dir = os.path.dirname(out_name)
    if out_dir:  # os.makedirs("") would raise for a bare file name
        os.makedirs(out_dir, exist_ok=True)

    id_dim, exp_dim = 80, 64

    # The face model works at resolution 2 * center; rescale targets into it.
    h = face_model.center * 2
    img_scale_factor = img_h / h
    lms /= img_scale_factor
    target = lms.detach()

    id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
    exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
    euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
    trans = lms.new_zeros((num_frames, 3), requires_grad=True)

    optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
    optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)

    # ---- stage 1: rigid pose only (identity/expression stay at zero) ----
    for _ in range(200):
        proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
        loss = cal_lan_loss_fn(proj_geo[:, :, :2], target)
        optimizer_frame.zero_grad()
        loss.backward()
        optimizer_frame.step()

    # ---- stage 2: pose + identity + expression, with L2 regularization ----
    for _ in range(200):
        proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
        loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], target)
        loss_regid = torch.mean(id_para * id_para)
        loss_regexp = torch.mean(exp_para * exp_para)
        loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
        # Also clears gradients that stage 1 accumulated on id/exp.
        optimizer_idexp.zero_grad()
        optimizer_frame.zero_grad()
        loss.backward()
        optimizer_idexp.step()
        optimizer_frame.step()

    # ---- stage 3: per-batch joint refinement with a smaller learning rate ----
    batch_size = 1
    num_batches = (num_frames + batch_size - 1) // batch_size
    for i in range(num_batches):
        # The last batch is shifted back so it stays batch_size long when possible.
        start = max(min(i * batch_size, num_frames - batch_size), 0)
        end = min(start + batch_size, num_frames)
        sel = slice(start, end)

        sel_id_para = id_para.detach()[sel].clone().requires_grad_(True)
        sel_exp_para = exp_para.detach()[sel].clone().requires_grad_(True)
        sel_euler_angle = euler_angle.detach()[sel].clone().requires_grad_(True)
        sel_trans = trans.detach()[sel].clone().requires_grad_(True)
        optimizer_cur_batch = torch.optim.Adam(
            [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)

        for _ in range(50):
            proj_geo = face_model.compute_for_landmark_fit(
                sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
            # A single image supplies the whole target set, hence unsqueeze(0).
            loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], target.unsqueeze(0))
            loss_regid = torch.mean(sel_id_para * sel_id_para)
            loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
            loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
            optimizer_cur_batch.zero_grad()
            loss.backward()
            optimizer_cur_batch.step()
        print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")

        # Bug fix: the old `id_para[sel_ids].data = ...` wrote into a temporary
        # copy made by advanced indexing, so refined values were discarded.
        with torch.no_grad():
            id_para[sel] = sel_id_para
            exp_para[sel] = sel_exp_para
            euler_angle[sel] = sel_euler_angle
            trans[sel] = sel_trans

    coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
                  'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
    if save:
        np.save(out_name, coeff_dict, allow_pickle=True)

    if debug:
        with torch.no_grad():
            proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
        _save_fit_debug_image(
            img, img_name,
            proj_geo[:, :, :2].detach().cpu().numpy(),
            lms.detach().cpu().numpy(),
            img_scale_factor, img_h)
    return coeff_dict
|
|
|
def out_exist_job(vid_name):
    """Return ``vid_name`` if its MediaPipe coefficient file is missing, else None.

    The coefficient path is derived like in ``fit_3dmm_for_a_image``:
    ``/images_512/`` -> ``/coeff_fit_mp/`` and ``.png`` -> ``_coeff_fit_mp.npy``.
    """
    coeff_path = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
    return None if os.path.exists(coeff_path) else vid_name
|
|
|
def get_todo_img_names(img_names):
    """Return the image names whose coefficient output does not exist yet.

    The existence checks run in parallel (16 workers) through
    ``multiprocess_run_tqdm`` with ``out_exist_job``; results are consumed in
    the order that helper yields them.
    """
    results = multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16)
    return [name for _, name in results if name is not None]
|
|
|
|
|
if __name__ == '__main__':
    # Batch driver: collect image paths, shard them across processes, skip
    # already-fitted images (unless --reset) and fit each remaining one.
    import argparse, glob, tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
    parser.add_argument("--ds_name", default='FFHQ')
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--reset", action='store_true')
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--output_log", action='store_true')
    parser.add_argument("--load_names", action="store_true")

    args = parser.parse_args()
    img_dir = args.img_dir
    load_names = args.load_names

    print(f"args {args}")

    if args.ds_name == 'single_img':
        img_names = [img_dir]
    else:
        img_names_path = os.path.join(img_dir, "img_dir.pkl")
        if os.path.exists(img_names_path) and load_names:
            print(f"loading vid names from {img_names_path}")
            img_names = load_file(img_names_path)
        else:
            if args.ds_name == 'FFHQ_MV':
                img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
                img_names1 = glob.glob(img_name_pattern1)
                img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
                img_names2 = glob.glob(img_name_pattern2)
                img_names = sorted(img_names1 + img_names2)
            elif args.ds_name == 'FFHQ':
                img_name_pattern = os.path.join(img_dir, "*.png")
                img_names = sorted(glob.glob(img_name_pattern))
            elif args.ds_name == "PanoHeadGen":
                img_name_patterns = ["ref/*/*.png"]
                img_names = []
                for img_name_pattern in img_name_patterns:
                    img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
                    img_names.extend(glob.glob(img_name_pattern_full))
                img_names = sorted(img_names)
            else:
                # Previously fell through to a NameError on img_names.
                raise ValueError(f"unsupported --ds_name: {args.ds_name}")
            print(f"saving image names to {img_names_path}")
            save_file(img_names_path, img_names)

    # Rebind the module-level face model so fit_3dmm_for_a_image uses the
    # requested keypoint mode and device.
    face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
                                     camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
    face_model.to(torch.device(args.device))

    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1 and process_id >= 0
        num_samples_per_process = len(img_names) // total_process
        start = process_id * num_samples_per_process
        # Bug fix: the last shard must take the remainder. The old check
        # `process_id == total_process` could never be true after the assert,
        # so len(img_names) % total_process images were never processed.
        if process_id == total_process - 1:
            img_names = img_names[start:]
        else:
            img_names = img_names[start: start + num_samples_per_process]
    print(f"image names number (before fileter): {len(img_names)}")

    if not args.reset:
        img_names = get_todo_img_names(img_names)

    print(f"image names number (after fileter): {len(img_names)}")
    for i in tqdm.trange(len(img_names), desc=f"process {process_id}: fitting 3dmm ..."):
        img_name = img_names[i]
        try:
            # keypoint_mode is forwarded so the loss/output path match face_model.
            fit_3dmm_for_a_image(img_name, args.debug, keypoint_mode=args.keypoint_mode, device=args.device)
        except Exception as e:
            print(img_name, e)
        if args.output_log and i % max(int(len(img_names) * 0.003), 1) == 0:
            print(f"process {process_id}: {i + 1} / {len(img_names)} done")
            sys.stdout.flush()
            sys.stderr.flush()

    print(f"process {process_id}: fitting 3dmm all done")
|
|
|
|