|
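"""Fit BFM 3DMM coefficients (identity, expression, pose) to talking-head
videos by optimizing against 2D facial landmarks (MediaPipe-468 by default,
or the 68-point convention). Supports a single NeRF-style video or sharded
processing of whole datasets; see the argument parser in `__main__`.
"""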
|
import torch |
|
import torch.nn.functional as F |
|
import copy |
|
import numpy as np |
|
|
|
import random |
|
import pickle |
|
import os |
|
import sys |
|
import cv2 |
|
import argparse |
|
import tqdm |
|
from utils.commons.multiprocess_utils import multiprocess_run_tqdm |
|
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker, read_video_to_frames |
|
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel |
|
from deep_3drecon.secc_renderer import SECC_Renderer |
|
from utils.commons.os_utils import multiprocess_glob |
|
|
|
|
|
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM', |
|
camera_distance=10, focal=1015, keypoint_mode='mediapipe') |
|
face_model.to(torch.device("cuda:0")) |
|
|
|
|
|
|
|
|
def draw_axes(img, pitch, yaw, roll, tx, ty, size=50): |
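    """Overlay the head-pose coordinate axes on `img` at pixel (tx, ty).

    Pitch and roll are negated before building the rotation matrix; the three
    rotated unit axes are drawn from the origin column, scaled by `size` pixels.
    """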
|
|
|
pitch = - pitch |
|
roll = - roll |
|
rotation_matrix = cv2.Rodrigues(np.array([pitch, yaw, roll]))[0].astype(np.float64) |
|
axes_points = np.array([ |
|
[1, 0, 0, 0], |
|
[0, 1, 0, 0], |
|
[0, 0, 1, 0] |
|
], dtype=np.float64) |
|
axes_points = rotation_matrix @ axes_points |
|
axes_points = (axes_points[:2, :] * size).astype(int) |
|
axes_points[0, :] = axes_points[0, :] + tx |
|
axes_points[1, :] = axes_points[1, :] + ty |
|
|
|
new_img = img.copy() |
|
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 0].ravel()), (255, 0, 0), 3) |
|
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 1].ravel()), (0, 255, 0), 3) |
|
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 2].ravel()), (0, 0, 255), 3) |
|
return new_img |
|
|
|
def save_file(name, content): |
|
with open(name, "wb") as f: |
|
pickle.dump(content, f) |
|
|
|
def load_file(name): |
|
with open(name, "rb") as f: |
|
content = pickle.load(f) |
|
return content |
|
|
|
def cal_lap_loss(in_tensor): |
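    """Temporal Laplacian smoothness loss.

    Flattens `in_tensor` to [T, -1], edge-pads one step on each side of the
    time axis, convolves each channel with the 1-D kernel (-0.5, 1.0, -0.5),
    and returns the mean squared response, penalizing frame-to-frame jitter.
    """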
|
|
|
t = in_tensor.shape[0] |
|
in_tensor = in_tensor.reshape([t, -1]).permute(1,0).unsqueeze(1) |
|
in_tensor = torch.cat([in_tensor[:, :, 0:1], in_tensor, in_tensor[:, :, -1:]], dim=-1) |
|
lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1,1,3]).float().to(in_tensor.device) |
|
loss_lap = 0 |
|
|
|
out_tensor = F.conv1d(in_tensor, lap_kernel) |
|
loss_lap += torch.mean(out_tensor**2) |
|
return loss_lap |
|
|
|
def cal_vel_loss(ldm): |
|
|
|
vel = ldm[1:] - ldm[:-1] |
|
return torch.mean(torch.abs(vel)) |
|
|
|
def cal_lan_loss(proj_lan, gt_lan): |
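    """Weighted MSE between projected and ground-truth 68-point landmarks.

    Eyes (36:48), nose bridge (28:31) and the last 8 points (the inner-lip
    region in the 68-point convention) are up-weighted 3x.
    """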
|
|
|
loss = (proj_lan - gt_lan)** 2 |
|
|
|
    weights = torch.ones_like(loss)
|
weights[:, 36:48, :] = 3 |
|
weights[:, -8:, :] = 3 |
|
weights[:, 28:31, :] = 3 |
|
loss = loss * weights |
|
return torch.mean(loss) |
|
|
|
def cal_lan_loss_mp(proj_lan, gt_lan, mean:bool=True): |
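    """Weighted MSE for MediaPipe-468 landmarks.

    Lips are weighted 5x, eyes 3x and the upper eyelids 20x (helps capture
    blinks); the contour indices listed in `unmatch_mask` are excluded with
    zero weight.
    """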
|
|
|
loss = (proj_lan - gt_lan).pow(2) |
|
|
|
unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454] |
|
upper_eye = [161,160,159,158,157] + [388,387,386,385,384] |
|
eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] |
|
inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] |
|
outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] |
|
weights = torch.ones_like(loss) |
|
weights[:, eye] = 3 |
|
weights[:, upper_eye] = 20 |
|
weights[:, inner_lip] = 5 |
|
weights[:, outer_lip] = 5 |
|
weights[:, unmatch_mask] = 0 |
|
loss = loss * weights |
|
if mean: |
|
loss = torch.mean(loss) |
|
return loss |
|
|
|
def cal_acceleration_loss(trans): |
|
vel = trans[1:] - trans[:-1] |
|
acc = vel[1:] - vel[:-1] |
|
return torch.mean(torch.abs(acc)) |
|
|
|
def cal_acceleration_ldm_loss(ldm): |
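    """Mean absolute landmark acceleration, with mouth landmarks down-weighted."""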
|
|
|
vel = ldm[1:] - ldm[:-1] |
|
acc = vel[1:] - vel[:-1] |
|
lip_weight = 0.25 |
|
    acc[:, 48:68] *= lip_weight  # down-weight the mouth landmarks (48..67 in the 68-point convention) along the landmark axis, not the time axis
|
return torch.mean(torch.abs(acc)) |
|
|
|
def set_requires_grad(tensor_list): |
|
for tensor in tensor_list: |
|
tensor.requires_grad = True |
|
|
|
@torch.enable_grad() |
|
def fit_3dmm_for_a_video( |
|
video_name, |
|
nerf=False, |
|
id_mode='global', |
|
debug=False, |
|
keypoint_mode='mediapipe', |
|
large_yaw_threshold=9999999.9, |
|
save=True |
|
):
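    """Fit BFM coefficients to a single video in three stages:

    1. pose-only warm-up: optimize euler angles and translation against the
       landmark loss;
    2. joint optimization of identity/expression and pose over the full
       sequence, with id/exp regularization and Laplacian smoothness terms;
    3. sliding-window refinement in batches of 50 frames at a lower lr.

    Returns the coefficient dict (saved to disk when `save=True`), or False if
    landmark extraction fails. Note `large_yaw_threshold` is accepted but not
    used inside this function.

    Example, assuming the repo's data/raw -> data/processed layout:
        coeff = fit_3dmm_for_a_video('data/raw/videos/May_10s.mp4', nerf=True)
    """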
|
    assert video_name.endswith(".mp4"), "this function only supports video input"
|
if id_mode == 'global': |
|
LAMBDA_REG_ID = 0.2 |
|
LAMBDA_REG_EXP = 0.6 |
|
LAMBDA_REG_LAP = 1.0 |
|
LAMBDA_REG_VEL_ID = 0.0 |
|
LAMBDA_REG_VEL_EXP = 0.0 |
|
else: |
|
LAMBDA_REG_ID = 0.3 |
|
LAMBDA_REG_EXP = 0.05 |
|
LAMBDA_REG_LAP = 1.0 |
|
LAMBDA_REG_VEL_ID = 0.0 |
|
LAMBDA_REG_VEL_EXP = 0.0 |
|
|
|
frames = read_video_to_frames(video_name) |
|
img_h, img_w = frames.shape[1], frames.shape[2] |
|
    assert img_h == img_w, "this pipeline assumes square video frames"
|
num_frames = len(frames) |
|
|
|
if nerf: |
|
lm_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy") |
|
else: |
|
lm_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy") |
|
|
|
if os.path.exists(lm_name): |
|
lms = np.load(lm_name) |
|
else: |
|
print(f"lms_2d file not found, try to extract it from video... {lm_name}") |
|
try: |
|
landmarker = MediapipeLandmarker() |
|
img_lm478, vid_lm478 = landmarker.extract_lm478_from_frames(frames, anti_smooth_factor=20) |
|
lms = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478) |
|
except Exception as e: |
|
print(e) |
|
return False |
|
if lms is None: |
|
print(f"get None lms_2d, please check whether each frame has one head, exiting... {lm_name}") |
|
return False |
|
    lms = lms[:, :468, :]  # keep the 468 face-mesh points, dropping the extra iris landmarks
    lms = torch.FloatTensor(lms).cuda()
    lms[..., 1] = img_h - lms[..., 1]  # flip y to match the 3DMM projection's coordinate convention
|
|
|
if keypoint_mode == 'mediapipe': |
|
|
|
cal_lan_loss_fn = cal_lan_loss_mp |
|
if nerf: |
|
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/coeff_fit_mp.npy") |
|
else: |
|
out_name = video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy") |
|
else: |
|
|
|
cal_lan_loss_fn = cal_lan_loss |
|
if nerf: |
|
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "_coeff_fit_lm68.npy") |
|
else: |
|
out_name = video_name.replace("/video/", "/coeff_fit_lm68/").replace(".mp4", "_coeff_fit_lm68.npy") |
|
    try:
        os.makedirs(os.path.dirname(out_name), exist_ok=True)
    except Exception:
        pass
|
|
|
id_dim, exp_dim = 80, 64 |
|
|
|
|
    # rescale landmarks from the video resolution to the BFM renderer's
    # canonical image size (face_model.center * 2 pixels, square)
    h = w = face_model.center * 2
    img_scale_factor = img_h / h
    lms /= img_scale_factor
|
|
|
if id_mode == 'global': |
|
|
|
id_para = lms.new_zeros((1, id_dim), requires_grad=True) |
|
elif id_mode == 'finegrained': |
|
|
|
id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True) |
|
    else:
        raise NotImplementedError(f"id_mode {id_mode} not supported; expected 'global' or 'finegrained'.")
|
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) |
|
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) |
|
trans = lms.new_zeros((num_frames, 3), requires_grad=True) |
|
|
|
set_requires_grad([id_para, exp_para, euler_angle, trans]) |
|
|
|
    optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
    optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)
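    # Stage 1: pose-only warm-up; only euler_angle and trans are stepped here,
    # identity and expression stay at their zero initialization.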
|
|
|
|
|
for _ in range(200): |
|
if id_mode == 'global': |
|
proj_geo = face_model.compute_for_landmark_fit( |
|
id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans) |
|
else: |
|
proj_geo = face_model.compute_for_landmark_fit( |
|
id_para, exp_para, euler_angle, trans) |
|
loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach()) |
|
loss = loss_lan |
|
optimizer_frame.zero_grad() |
|
loss.backward() |
|
optimizer_frame.step() |
|
|
|
|
|
|
|
|
|
for param_group in optimizer_frame.param_groups: |
|
param_group['lr'] = 0.1 |
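    # Stage 2: jointly optimize identity, expression and pose over the whole
    # sequence, adding id/exp regularization and Laplacian smoothness terms.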
|
|
|
|
|
for _ in range(200): |
|
ret = {} |
|
if id_mode == 'global': |
|
proj_geo = face_model.compute_for_landmark_fit( |
|
id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans, ret) |
|
else: |
|
proj_geo = face_model.compute_for_landmark_fit( |
|
id_para, exp_para, euler_angle, trans, ret) |
|
loss_lan = cal_lan_loss_fn( |
|
proj_geo[:, :, :2], lms.detach()) |
|
|
|
|
|
loss_lap = cal_lap_loss(id_para) + cal_lap_loss(exp_para) + cal_lap_loss(euler_angle) * 0.3 + cal_lap_loss(trans) * 0.3 |
|
|
|
loss_regid = torch.mean(id_para*id_para) |
|
loss_regexp = torch.mean(exp_para * exp_para) |
|
|
|
loss_vel_id = cal_vel_loss(id_para) |
|
loss_vel_exp = cal_vel_loss(exp_para) |
|
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP + loss_lap * LAMBDA_REG_LAP |
|
optimizer_idexp.zero_grad() |
|
optimizer_frame.zero_grad() |
|
loss.backward() |
|
optimizer_idexp.step() |
|
optimizer_frame.step() |
|
|
|
|
|
|
|
|
|
|
|
|
|
    # re-create the parameters as fresh leaf tensors for the batch-refinement
    # stage; in 'global' id mode the identity is frozen from here on
    if id_mode == 'global':
        id_para_ = lms.new_zeros((1, id_dim), requires_grad=False)
|
else: |
|
id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True) |
|
id_para_.data = id_para.data.clone() |
|
id_para = id_para_ |
|
exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True) |
|
exp_para_.data = exp_para.data.clone() |
|
exp_para = exp_para_ |
|
euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True) |
|
euler_angle_.data = euler_angle.data.clone() |
|
euler_angle = euler_angle_ |
|
trans_ = lms.new_zeros((num_frames, 3), requires_grad=True) |
|
trans_.data = trans.data.clone() |
|
trans = trans_ |
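    # Stage 3: refine the sequence in sliding windows of `batch_size` frames
    # with a smaller learning rate.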
|
|
|
batch_size = 50 |
|
|
|
    for i in range((num_frames + batch_size - 1) // batch_size):  # ceil-divide into windows
        if (i+1)*batch_size > num_frames:
            # the final window is shifted back so it still covers batch_size frames
            sel_ids = np.arange(max(num_frames-batch_size, 0), num_frames)
        else:
            sel_ids = np.arange(i*batch_size, (i+1)*batch_size)
|
sel_lms = lms[sel_ids] |
|
|
|
if id_mode == 'global': |
|
sel_id_para = id_para.expand((sel_ids.shape[0], id_dim)) |
|
else: |
|
sel_id_para = id_para.new_zeros((batch_size, id_dim), requires_grad=True) |
|
sel_id_para.data = id_para[sel_ids].clone() |
|
sel_exp_para = exp_para.new_zeros( |
|
(batch_size, exp_dim), requires_grad=True) |
|
sel_exp_para.data = exp_para[sel_ids].clone() |
|
sel_euler_angle = euler_angle.new_zeros( |
|
(batch_size, 3), requires_grad=True) |
|
sel_euler_angle.data = euler_angle[sel_ids].clone() |
|
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True) |
|
sel_trans.data = trans[sel_ids].clone() |
|
|
|
if id_mode == 'global': |
|
set_requires_grad([sel_exp_para, sel_euler_angle, sel_trans]) |
|
optimizer_cur_batch = torch.optim.Adam( |
|
[sel_exp_para, sel_euler_angle, sel_trans], lr=0.005) |
|
else: |
|
set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans]) |
|
optimizer_cur_batch = torch.optim.Adam( |
|
[sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005) |
|
|
|
for j in range(50): |
|
ret = {} |
|
proj_geo = face_model.compute_for_landmark_fit( |
|
sel_id_para, sel_exp_para, sel_euler_angle, sel_trans, ret) |
|
loss_lan = cal_lan_loss_fn( |
|
proj_geo[:, :, :2], lms[sel_ids].detach()) |
|
|
|
|
|
loss_lap = cal_lap_loss(sel_id_para) + cal_lap_loss(sel_exp_para) + cal_lap_loss(sel_euler_angle) * 0.3 + cal_lap_loss(sel_trans) * 0.3 |
|
loss_vel_id = cal_vel_loss(sel_id_para) |
|
loss_vel_exp = cal_vel_loss(sel_exp_para) |
|
log_dict = { |
|
'loss_vel_id': loss_vel_id, |
|
'loss_vel_exp': loss_vel_exp, |
|
'loss_vel_euler': cal_vel_loss(sel_euler_angle), |
|
'loss_vel_trans': cal_vel_loss(sel_trans), |
|
} |
|
loss_regid = torch.mean(sel_id_para*sel_id_para) |
|
loss_regexp = torch.mean(sel_exp_para*sel_exp_para) |
|
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_lap * LAMBDA_REG_LAP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP |
|
|
|
optimizer_cur_batch.zero_grad() |
|
loss.backward() |
|
optimizer_cur_batch.step() |
|
|
|
if debug: |
|
print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},loss_lap_ldm:{loss_lap.item():.4f}") |
|
print("|--------" + ', '.join([f"{k}: {v:.4f}" for k,v in log_dict.items()])) |
|
        # write the refined window back; index into .data so the in-place
        # setitem persists (fancy-indexing the leaf tensor itself would write
        # into a temporary copy and the update would be lost)
        if id_mode != 'global':
            id_para.data[sel_ids] = sel_id_para.data.clone()
        exp_para.data[sel_ids] = sel_exp_para.data.clone()
        euler_angle.data[sel_ids] = sel_euler_angle.data.clone()
        trans.data[sel_ids] = sel_trans.data.clone()
|
|
|
coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(), |
|
'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()} |
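    # Expected shapes: id [1 or T, 80], exp [T, 64], euler [T, 3], trans [T, 3],
    # where T is the number of frames.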
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if debug: |
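        # Render a debug video: fitted landmarks / GT landmarks / SECC render
        # on the top row; xy and xz camera trajectories plus the canonical
        # landmarks on the bottom row.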
|
import imageio |
|
from utils.visualization.vis_cam3d.camera_pose_visualizer import CameraPoseVisualizer |
|
from data_util.face3d_helper import Face3DHelper |
|
from data_gen.utils.process_video.extract_blink import get_eye_area_percent |
|
face3d_helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='mediapipe') |
|
|
|
t = coeff_dict['exp'].shape[0] |
|
if len(coeff_dict['id']) == 1: |
|
coeff_dict['id'] = np.repeat(coeff_dict['id'], t, axis=0) |
|
idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d_np(coeff_dict['id'], coeff_dict['exp']).reshape([t, -1]) |
|
cano_lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy() |
|
cano_lm3d = cano_lm3d.reshape([t, -1, 3]) |
|
WH = 512 |
|
cano_lm3d = (cano_lm3d * WH/2 + WH/2).astype(int) |
|
|
|
with torch.no_grad(): |
|
rot = ParametricFaceModel.compute_rotation(euler_angle) |
|
extrinsic = torch.zeros([rot.shape[0], 4, 4]).to(rot.device) |
|
extrinsic[:, :3,:3] = rot |
|
extrinsic[:, :3, 3] = trans |
|
extrinsic[:, 3, 3] = 1 |
|
extrinsic = extrinsic.cpu().numpy() |
|
|
|
xy_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xy') |
|
xz_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xz') |
|
|
|
if nerf: |
|
debug_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/debug_fit_3dmm.mp4") |
|
else: |
|
debug_name = video_name.replace("/video/", "/coeff_fit_debug/").replace(".mp4", "_debug.mp4") |
|
        try:
            os.makedirs(os.path.dirname(debug_name), exist_ok=True)
        except Exception:
            pass
|
writer = imageio.get_writer(debug_name, fps=25) |
|
if id_mode == 'global': |
|
id_para = id_para.repeat([exp_para.shape[0], 1]) |
|
proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans) |
|
lm68s = proj_geo[:,:,:2].detach().cpu().numpy() |
|
lm68s = lm68s * img_scale_factor |
|
lms = lms * img_scale_factor |
|
lm68s[..., 1] = img_h - lm68s[..., 1] |
|
lms[..., 1] = img_h - lms[..., 1] |
|
lm68s = lm68s.astype(int) |
|
for i in tqdm.trange(min(250, len(frames)), desc=f'rendering debug video to {debug_name}..'): |
|
xy_cam3d_img = xy_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25) |
|
xy_cam3d_img = cv2.resize(xy_cam3d_img, (512,512)) |
|
xz_cam3d_img = xz_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25) |
|
xz_cam3d_img = cv2.resize(xz_cam3d_img, (512,512)) |
|
|
|
img = copy.deepcopy(frames[i]) |
|
img2 = copy.deepcopy(frames[i]) |
|
|
|
img = draw_axes(img, euler_angle[i,0].item(), euler_angle[i,1].item(), euler_angle[i,2].item(), lm68s[i][4][0].item(), lm68s[i, 4][1].item(), size=50) |
|
|
|
gt_lm_color = (255, 0, 0) |
|
|
|
            for lm in lm68s[i]:
                img = cv2.circle(img, tuple(lm), 1, (0, 0, 255), thickness=-1)  # fitted landmarks
            for gt_lm in lms[i]:
                img2 = cv2.circle(img2, tuple(gt_lm.cpu().numpy().astype(int)), 2, gt_lm_color, thickness=1)  # ground-truth landmarks
|
|
|
cano_lm3d_img = np.ones([WH, WH, 3], dtype=np.uint8) * 255 |
|
for j in range(len(cano_lm3d[i])): |
|
x, y, _ = cano_lm3d[i, j] |
|
color = (255,0,0) |
|
cano_lm3d_img = cv2.circle(cano_lm3d_img, center=(x,y), radius=3, color=color, thickness=-1) |
|
cano_lm3d_img = cv2.flip(cano_lm3d_img, 0) |
|
|
|
            # expression-only SECC render (pose zeroed); `secc_renderer` is the
            # module-level renderer created in the __main__ block
            _, secc_img = secc_renderer(id_para[0:1], exp_para[i:i+1], euler_angle[i:i+1]*0, trans[i:i+1]*0)
|
secc_img = (secc_img +1)*127.5 |
|
secc_img = F.interpolate(secc_img, size=(img_h, img_w)) |
|
secc_img = secc_img.permute(0, 2,3,1).int().cpu().numpy()[0] |
|
out_img1 = np.concatenate([img, img2, secc_img], axis=1).astype(np.uint8) |
|
|
out_img2 = np.concatenate([xy_cam3d_img, xz_cam3d_img, cano_lm3d_img], axis=1).astype(np.uint8) |
|
out_img = np.concatenate([out_img1, out_img2], axis=0) |
|
writer.append_data(out_img) |
|
writer.close() |
|
|
|
|
|
|
|
|
|
|
|
if save: |
|
np.save(out_name, coeff_dict, allow_pickle=True) |
|
return coeff_dict |
|
|
|
def out_exist_job(vid_name): |
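    """Return `vid_name` if it still needs fitting (landmarks exist but no coeff file), else None."""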
|
out_name = vid_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") |
|
lms_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy") |
|
if os.path.exists(out_name) or not os.path.exists(lms_name): |
|
return None |
|
else: |
|
return vid_name |
|
|
|
def get_todo_vid_names(vid_names): |
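    """Filter `vid_names` down to the videos that still need processing (checked in parallel)."""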
|
if len(vid_names) == 1: |
|
return vid_names |
|
todo_vid_names = [] |
|
for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16): |
|
if res is not None: |
|
todo_vid_names.append(res) |
|
return todo_vid_names |
|
|
|
|
|
if __name__ == '__main__': |
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument("--vid_dir", default='data/raw/videos/May_10s.mp4') |
|
parser.add_argument("--ds_name", default='nerf') |
|
parser.add_argument("--seed", default=0, type=int) |
|
parser.add_argument("--process_id", default=0, type=int) |
|
parser.add_argument("--total_process", default=1, type=int) |
|
parser.add_argument("--id_mode", default='global', type=str) |
|
parser.add_argument("--keypoint_mode", default='mediapipe', type=str) |
|
parser.add_argument("--large_yaw_threshold", default=9999999.9, type=float) |
|
parser.add_argument("--debug", action='store_true') |
|
parser.add_argument("--reset", action='store_true') |
|
parser.add_argument("--load_names", action="store_true") |
|
|
|
args = parser.parse_args() |
|
vid_dir = args.vid_dir |
|
ds_name = args.ds_name |
|
load_names = args.load_names |
|
|
|
print(f"args {args}") |
|
|
|
if ds_name.lower() == 'nerf': |
|
vid_names = [vid_dir] |
|
out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names] |
|
else: |
|
if ds_name in ['lrs3_trainval']: |
|
vid_name_pattern = os.path.join(vid_dir, "*/*.mp4") |
|
elif ds_name in ['TH1KH_512', 'CelebV-HQ']: |
|
vid_name_pattern = os.path.join(vid_dir, "*.mp4") |
|
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']: |
|
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4") |
|
elif ds_name in ["RAVDESS", 'VFHQ']: |
|
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4") |
|
else: |
|
raise NotImplementedError() |
|
|
|
vid_names_path = os.path.join(vid_dir, "vid_names.pkl") |
|
if os.path.exists(vid_names_path) and load_names: |
|
print(f"loading vid names from {vid_names_path}") |
|
vid_names = load_file(vid_names_path) |
|
else: |
|
vid_names = multiprocess_glob(vid_name_pattern) |
|
vid_names = sorted(vid_names) |
|
print(f"saving vid names to {vid_names_path}") |
|
save_file(vid_names_path, vid_names) |
|
out_names = [video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names] |
|
|
|
print(vid_names[:10]) |
|
random.seed(args.seed) |
|
random.shuffle(vid_names) |
|
|
|
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM', |
|
camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode) |
|
face_model.to(torch.device("cuda:0")) |
|
secc_renderer = SECC_Renderer(512) |
|
secc_renderer.to("cuda:0") |
|
|
|
process_id = args.process_id |
|
total_process = args.total_process |
|
if total_process > 1: |
|
        assert process_id <= total_process - 1
        num_samples_per_process = len(vid_names) // total_process
        if process_id == total_process - 1:  # the last shard also takes the remainder
|
vid_names = vid_names[process_id * num_samples_per_process : ] |
|
else: |
|
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process] |
|
|
|
if not args.reset: |
|
vid_names = get_todo_vid_names(vid_names) |
|
|
|
failed_img_names = [] |
|
for i in tqdm.trange(len(vid_names), desc=f"process {process_id}: fitting 3dmm ..."): |
|
img_name = vid_names[i] |
|
try: |
|
            is_person_specific_data = ds_name == 'nerf'
            # forward the CLI keypoint mode so the loss and output naming match face_model
            success = fit_3dmm_for_a_video(img_name, is_person_specific_data, args.id_mode, args.debug, keypoint_mode=args.keypoint_mode, large_yaw_threshold=args.large_yaw_threshold)
|
if not success: |
|
failed_img_names.append(img_name) |
|
except Exception as e: |
|
print(img_name, e) |
|
failed_img_names.append(img_name) |
|
print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {len(failed_img_names)} / {i + 1} = {len(failed_img_names) / (i + 1):.4f}") |
|
sys.stdout.flush() |
|
print(f"all failed image names: {failed_img_names}") |
|
print(f"All finished!") |