# EMAGE/camn_trainer.py
import train
import os
import time
import csv
import sys
import warnings
import random
import numpy as np
import pprint
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from loguru import logger
import smplx
import librosa
from utils import config, logger_tools, other_tools, metric
from utils import rotation_conversions as rc
from dataloaders import data_tools
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from scipy.spatial.transform import Rotation
class CustomTrainer(train.BaseTrainer):
    def __init__(self, args):
        super().__init__(args)
        self.joints = self.train_data.joints
        self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", "dis", "gen", "acc", "transa", "div_reg", "kl"], [False, True, True, False, False, False, False, False, False, False, False, False, False])
        if not self.args.rot6d:  # i.e. "rot6d" not in args.pose_rep
            logger.error(f"this script requires rot6d, but your pose representation is {args.pose_rep}")
        self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank)
        self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank)
    def _load_data(self, dict_data):
        tar_pose = dict_data["pose"].to(self.rank)
        tar_trans = dict_data["trans"].to(self.rank)
        tar_exps = dict_data["facial"].to(self.rank)
        tar_beta = dict_data["beta"].to(self.rank)
        tar_id = dict_data["id"].to(self.rank).long()
        tar_word = dict_data["word"].to(self.rank)
        in_audio = dict_data["audio"].to(self.rank)
        in_emo = dict_data["emo"].to(self.rank)
        # in_sem = dict_data["sem"].to(self.rank)
        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
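        # Convert axis-angle joint rotations to the continuous 6D rotation
        # representation (Zhou et al., CVPR 2019): (bs, n, j*3) axis-angle ->
        # (bs, n, j, 3, 3) rotation matrices -> (bs, n, j*6).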
        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
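        # Seed the generator with the first pre_frames frames of ground truth:
        # channels [0 : j*6+3] carry pose + translation, and the final channel
        # is a 0/1 flag marking which frames are valid seed frames.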
        in_pre_pose_cat = torch.cat([tar_pose[:, 0:self.args.pre_frames], tar_trans[:, :self.args.pre_frames]], dim=2).to(self.rank)
        in_pre_pose = tar_pose.new_zeros((bs, n, j*6+1+3)).to(self.rank)
        in_pre_pose[:, 0:self.args.pre_frames, :-1] = in_pre_pose_cat[:, 0:self.args.pre_frames]
        in_pre_pose[:, 0:self.args.pre_frames, -1] = 1
        return {
            "tar_pose": tar_pose,
            "in_audio": in_audio,
            "in_motion": in_pre_pose,
            "tar_trans": tar_trans,
            "tar_exps": tar_exps,
            "tar_beta": tar_beta,
            "tar_word": tar_word,
            "tar_id": tar_id,
            "in_emo": in_emo,
            # "in_sem": in_sem,
        }
    def _d_training(self, loaded_data):
        bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
        net_out = self.model(in_audio=loaded_data["in_audio"], pre_seq=loaded_data["in_motion"], in_text=loaded_data["tar_word"], in_id=loaded_data["tar_id"], in_emo=loaded_data["in_emo"], in_facial=loaded_data["tar_exps"])
        rec_pose = net_out["rec_pose"][:, :, :j*6]
        # rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3]
        rec_pose = rec_pose.reshape(bs, n, j, 6)
        rec_pose = rc.rotation_6d_to_matrix(rec_pose)
        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
        tar_pose = rc.rotation_6d_to_matrix(loaded_data["tar_pose"].reshape(bs, n, j, 6))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
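        # Vanilla GAN discriminator objective: maximize log D(real) + log(1 - D(fake)),
        # implemented as the negated mean with a small epsilon for numerical
        # stability. The 6D -> matrix -> 6D round trip above re-orthonormalizes
        # the generator output before it is scored.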
        out_d_fake = self.d_model(rec_pose)
        out_d_real = self.d_model(tar_pose)
        d_loss_adv = -torch.mean(torch.log(out_d_real + 1e-8) + torch.log(1 - out_d_fake + 1e-8))
        self.tracker.update_meter("dis", "train", d_loss_adv.item())
        return d_loss_adv
    def _g_training(self, loaded_data, use_adv, mode="train"):
        bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
        net_out = self.model(in_audio=loaded_data["in_audio"], pre_seq=loaded_data["in_motion"], in_text=loaded_data["tar_word"], in_id=loaded_data["tar_id"], in_emo=loaded_data["in_emo"], in_facial=loaded_data["tar_exps"])
        rec_pose = net_out["rec_pose"][:, :, :j*6]
        rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3]
        # print(rec_pose.shape, bs, n, j, loaded_data['in_audio'].shape, loaded_data["in_motion"].shape)
        rec_pose = rec_pose.reshape(bs, n, j, 6)
        rec_pose = rc.rotation_6d_to_matrix(rec_pose)
        tar_pose = rc.rotation_6d_to_matrix(loaded_data["tar_pose"].reshape(bs, n, j, 6))
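        # Reconstruction loss: geodesic distance between predicted and target
        # rotation matrices, scaled by rec_weight.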
        rec_loss = self.rec_loss(tar_pose, rec_pose)
        rec_loss *= self.args.rec_weight
        self.tracker.update_meter("rec", mode, rec_loss.item())
        # rec_loss_vel = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1])
        # self.tracker.update_meter("vel", mode, rec_loss_vel.item())
        # rec_loss_acc = self.vel_loss(rec_pose[:, 2:] - 2*rec_pose[:, 1:-1] + rec_pose[:, :-2], tar_pose[:, 2:] - 2*tar_pose[:, 1:-1] + tar_pose[:, :-2])
        # self.tracker.update_meter("acc", mode, rec_loss_acc.item())
        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
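        # When training on a joint subset (pose_dims < 330, i.e. fewer than the
        # full 55 SMPL-X joints), evaluation expects the full layout: map the
        # subset back through the dataset's joint_mask to a 55-joint sequence.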
        if self.args.pose_dims < 330 and mode != "train":
            rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs, n, j, 6))
            rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs, n, j*3)
            rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0])
            rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, 55, 3))
            rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, 55*6)
            tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
            tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs, n, j*3)
            tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0])
            tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
            tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, 55*6)
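        # Generator adversarial term (train mode only): the non-saturating loss
        # -E[log D(G(x))], again with an epsilon inside the log.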
        if use_adv and mode == "train":
            out_d_fake = self.d_model(rec_pose)
            d_loss_adv = -torch.mean(torch.log(out_d_fake + 1e-8))
            self.tracker.update_meter("gen", mode, d_loss_adv.item())
        else:
            d_loss_adv = 0
        if self.args.train_trans:
            trans_loss = self.vel_loss(rec_trans, loaded_data["tar_trans"])
            trans_loss *= self.args.rec_weight
            self.tracker.update_meter("trans", mode, trans_loss.item())
        else:
            trans_loss = 0
        # trans_loss_vel = self.vel_loss(rec_trans[:, 1:] - rec_trans[:, :-1], loaded_data["tar_trans"][:, 1:] - loaded_data["tar_trans"][:, :-1])
        # self.tracker.update_meter("transv", mode, trans_loss_vel.item())
        # trans_loss_acc = self.vel_loss(rec_trans[:, 2:] - 2*rec_trans[:, 1:-1] + rec_trans[:, :-2], loaded_data["tar_trans"][:, 2:] - 2*loaded_data["tar_trans"][:, 1:-1] + loaded_data["tar_trans"][:, :-2])
        # self.tracker.update_meter("transa", mode, trans_loss_acc.item())
        if mode == "train":
            return d_loss_adv + rec_loss + trans_loss  # + rec_loss_vel + rec_loss_acc + trans_loss_vel + trans_loss_acc
        elif mode == "val":
            return {
                "rec_pose": rec_pose,
                "rec_trans": rec_trans,
                "tar_pose": tar_pose,
            }
        else:
            return {
                "rec_pose": rec_pose,
                "rec_trans": rec_trans,
                "tar_pose": tar_pose,
                "tar_exps": loaded_data["tar_exps"],
                "tar_beta": loaded_data["tar_beta"],
                "tar_trans": loaded_data["tar_trans"],
            }
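    # One epoch of adversarial training: once epoch >= no_adv_epoch, each batch
    # first takes a discriminator step on (real, generated) pairs, then a
    # generator step on the reconstruction + translation + adversarial losses;
    # before that epoch, only the generator objective is optimized.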
    def train(self, epoch):
        use_adv = bool(epoch >= self.args.no_adv_epoch)
        self.model.train()
        self.d_model.train()
        self.tracker.reset()
        t_start = time.time()
        for its, batch_data in enumerate(self.train_loader):
            loaded_data = self._load_data(batch_data)
            t_data = time.time() - t_start
            if use_adv:
                d_loss_final = 0
                self.opt_d.zero_grad()
                d_loss_adv = self._d_training(loaded_data)
                d_loss_final += d_loss_adv
                d_loss_final.backward()
                self.opt_d.step()
            self.opt.zero_grad()
            g_loss_final = 0
            g_loss_final += self._g_training(loaded_data, use_adv, "train")
            g_loss_final.backward()
            self.opt.step()
            mem_cost = torch.cuda.memory_reserved() / 1E9  # memory_cached() was deprecated and renamed in PyTorch 1.4
            lr_g = self.opt.param_groups[0]["lr"]
            lr_d = self.opt_d.param_groups[0]["lr"]
            t_train = time.time() - t_start - t_data
            t_start = time.time()
            if its % self.args.log_period == 0:
                self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=lr_d)
            if self.args.debug:
                if its == 1:
                    break
        self.opt_s.step(epoch)
        self.opt_d_s.step(epoch)
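    # Validation: run the generator on held-out batches, embed predicted and
    # ground-truth motion with the frozen evaluator (eval_copy.map2latent), and
    # report the Frechet distance (FID) between the two sets of latents.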
    def val(self, epoch):
        self.model.eval()
        self.d_model.eval()
        with torch.no_grad():
            for its, batch_data in enumerate(self.val_loader):
                loaded_data = self._load_data(batch_data)
                net_out = self._g_training(loaded_data, False, "val")
                tar_pose = net_out["tar_pose"]
                rec_pose = net_out["rec_pose"]
                n = tar_pose.shape[1]
                if (30 / self.args.pose_fps) != 1:
                    assert 30 % self.args.pose_fps == 0
                    n *= int(30 / self.args.pose_fps)
                    tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                    rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                n = tar_pose.shape[1]
                remain = n % self.args.vae_test_len
                tar_pose = tar_pose[:, :n - remain, :]
                rec_pose = rec_pose[:, :n - remain, :]
                latent_out = self.eval_copy.map2latent(rec_pose).reshape(-1, self.args.vae_length).cpu().numpy()
                latent_ori = self.eval_copy.map2latent(tar_pose).reshape(-1, self.args.vae_length).cpu().numpy()
                if its == 0:
                    latent_out_motion_all = latent_out
                    latent_ori_all = latent_ori
                else:
                    latent_out_motion_all = np.concatenate([latent_out_motion_all, latent_out], axis=0)
                    latent_ori_all = np.concatenate([latent_ori_all, latent_ori], axis=0)
                if self.args.debug:
                    if its == 1:
                        break
            fid_motion = data_tools.FIDCalculator.frechet_distance(latent_out_motion_all, latent_ori_all)
            self.tracker.update_meter("fid", "val", fid_motion)
        self.val_recording(epoch)
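    # Test: generate motion for every held-out sequence, export ground-truth
    # and predicted SMPL-X parameters as .npz files, and accumulate three
    # metrics: FID (latent Frechet distance), BC (audio-beat / motion-beat
    # alignment), and L1Div (L1 diversity of the generated joints).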
    def test(self, epoch):
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path):
            return 0
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        test_seq_list = self.test_data.selected_file
        align = 0
        latent_out = []
        latent_ori = []
        self.model.eval()
        self.smplx.eval()
        self.eval_copy.eval()
        with torch.no_grad():
            for its, batch_data in enumerate(self.test_loader):
                loaded_data = self._load_data(batch_data)
                net_out = self._g_training(loaded_data, False, "test")
                tar_pose = net_out["tar_pose"]
                rec_pose = net_out["rec_pose"]
                tar_exps = net_out["tar_exps"]
                tar_beta = net_out["tar_beta"]
                rec_trans = net_out["rec_trans"]
                tar_trans = net_out["tar_trans"]
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], 55
                if (30 / self.args.pose_fps) != 1:
                    assert 30 % self.args.pose_fps == 0
                    n *= int(30 / self.args.pose_fps)
                    tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                    rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                    tar_beta = torch.nn.functional.interpolate(tar_beta.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                    tar_exps = torch.nn.functional.interpolate(tar_exps.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                    tar_trans = torch.nn.functional.interpolate(tar_trans.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                    rec_trans = torch.nn.functional.interpolate(rec_trans.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode="linear").permute(0, 2, 1)
                # print(rec_pose.shape, tar_pose.shape)
                # rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
                # rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
                # tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
                # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
                remain = n % self.args.vae_test_len
                latent_out.append(self.eval_copy.map2latent(rec_pose[:, :n - remain]).reshape(-1, self.args.vae_length).detach().cpu().numpy())  # bs * n/8 * 240
                latent_ori.append(self.eval_copy.map2latent(tar_pose[:, :n - remain]).reshape(-1, self.args.vae_length).detach().cpu().numpy())
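                # Back to axis-angle for the SMPL-X forward pass: each of the
                # 55 joints goes 6D -> rotation matrix -> axis-angle (3 values).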
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
                vertices_rec = self.smplx(
                    betas=tar_beta.reshape(bs*n, 300),
                    transl=torch.zeros_like(rec_trans.reshape(bs*n, 3)),  # translation zeroed for joint extraction
                    expression=torch.zeros_like(tar_exps.reshape(bs*n, 100)),  # expressions zeroed as well
                    jaw_pose=rec_pose[:, 66:69],
                    global_orient=rec_pose[:, :3],
                    body_pose=rec_pose[:, 3:21*3+3],
                    left_hand_pose=rec_pose[:, 25*3:40*3],
                    right_hand_pose=rec_pose[:, 40*3:55*3],
                    return_joints=True,
                    leye_pose=rec_pose[:, 69:72],
                    reye_pose=rec_pose[:, 72:75],
                )
                # vertices_tar = self.smplx(
                #     betas=tar_beta.reshape(bs*n, 300),
                #     transl=torch.zeros_like(rec_trans.reshape(bs*n, 3)),
                #     expression=torch.zeros_like(tar_exps.reshape(bs*n, 100)),
                #     jaw_pose=tar_pose[:, 66:69],
                #     global_orient=tar_pose[:, :3],
                #     body_pose=tar_pose[:, 3:21*3+3],
                #     left_hand_pose=tar_pose[:, 25*3:40*3],
                #     right_hand_pose=tar_pose[:, 40*3:55*3],
                #     return_joints=True,
                #     leye_pose=tar_pose[:, 69:72],
                #     reye_pose=tar_pose[:, 72:75],
                # )
                joints_rec = vertices_rec["joints"].detach().cpu().numpy().reshape(1, n, 127*3)[0, :n, :55*3]  # assumes bs == 1 at test time
                # joints_tar = vertices_tar["joints"].detach().cpu().numpy().reshape(1, n, 127*3)[0, :n, :55*3]
                _ = self.l1_calculator.run(joints_rec)
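                # Beat consistency (BC): extract audio onsets and motion-beat
                # candidates from the joint velocities, then score their
                # temporal alignment at 30 fps; each per-sequence score is
                # length-weighted here and normalized after the loop.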
                if self.alignmenter is not None:
                    in_audio_eval, sr = librosa.load(self.args.data_path + "wave16k/" + test_seq_list.iloc[its]["id"] + ".wav")
                    in_audio_eval = librosa.resample(in_audio_eval, orig_sr=sr, target_sr=self.args.audio_sr)
                    a_offset = int(self.align_mask * (self.args.audio_sr / self.args.pose_fps))
                    onset_bt = self.alignmenter.load_audio(in_audio_eval[:int(self.args.audio_sr / self.args.pose_fps * n)], a_offset, len(in_audio_eval) - a_offset, True)
                    beat_vel = self.alignmenter.load_pose(joints_rec, self.align_mask, n - self.align_mask, 30, True)
                    # print(beat_vel)
                    align += (self.alignmenter.calculate_align(onset_bt, beat_vel, 30) * (n - 2 * self.align_mask))
                tar_pose_axis_np = tar_pose.detach().cpu().numpy()
                rec_pose_axis_np = rec_pose.detach().cpu().numpy()
                rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3)
                rec_exp_np = np.zeros_like(tar_exps.detach().cpu().numpy().reshape(bs*n, 100))  # facial expressions are zeroed in the exported files
                tar_exp_np = np.zeros_like(tar_exps.detach().cpu().numpy().reshape(bs*n, 100))
                tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3)
                gt_npz = np.load(self.args.data_path + self.args.pose_rep + "/" + test_seq_list.iloc[its]["id"] + ".npz", allow_pickle=True)
                if not self.args.train_trans:
                    tar_trans_np = np.zeros_like(tar_trans_np)  # zero out translation when it is not trained
                    rec_trans_np = np.zeros_like(rec_trans_np)
                np.savez(results_save_path + "gt_" + test_seq_list.iloc[its]["id"] + ".npz",
                    betas=gt_npz["betas"],
                    poses=tar_pose_axis_np,
                    expressions=tar_exp_np,
                    trans=tar_trans_np,
                    model="smplx2020",
                    gender="neutral",
                    mocap_frame_rate=30,
                )
                np.savez(results_save_path + "res_" + test_seq_list.iloc[its]["id"] + ".npz",
                    betas=gt_npz["betas"],
                    poses=rec_pose_axis_np,
                    expressions=rec_exp_np,
                    trans=rec_trans_np,
                    model="smplx2020",
                    gender="neutral",
                    mocap_frame_rate=30,
                )
                total_length += n
        latent_out_all = np.concatenate(latent_out, axis=0)
        latent_ori_all = np.concatenate(latent_ori, axis=0)
        fid = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all)
        logger.info(f"fid score: {fid}")
        self.test_recording("fid", fid, epoch)
        align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask)
        logger.info(f"align score: {align_avg}")
        self.test_recording("bc", align_avg, epoch)
        l1div = self.l1_calculator.avg()
        logger.info(f"l1div score: {l1div}")
        self.test_recording("l1div", l1div, epoch)
        # data_tools.result2target_vis(self.args.pose_version, results_save_path, results_save_path, self.test_demo, False)
        end_time = time.time() - start_time
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length / self.args.pose_fps)} s motion")