import os
import sys
from os.path import join as pjoin

# make repo-local packages importable when running from the repo root
sys.path.append(os.getcwd())

import numpy as np  # used below to load Mean.npy / Std.npy
import torch
import torch.distributed as dist
import wandb
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist

import utils.paramUtil as paramUtil
from datasets import Text2MotionDataset
from models import MotionTransformer
from options.train_options import TrainCompOptions
from trainers import DDPMTrainer
from utils.plot_script import *
from utils.utils import *
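# Hedged note: init_dist is imported but never called in this script, so
# unless the process group is initialized elsewhere (e.g. inside the options
# parser or via torch.distributed.init_process_group), get_dist_info() below
# returns rank 0 / world size 1 and training proceeds single-process.
# A typical multi-GPU launch might look like (sketch; flag names may differ):
#   torchrun --nproc_per_node=4 train.py --name my_run --dataset_name t2m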
def build_models(opt, dim_pose):
    """Construct the diffusion denoiser: a transformer over pose sequences."""
    encoder = MotionTransformer(
        input_feats=dim_pose,
        num_frames=opt.max_motion_length,
        num_layers=opt.num_layers,
        latent_dim=opt.latent_dim,
        no_clip=opt.no_clip,
        no_eff=opt.no_eff)
    return encoder
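# Assumed interface (not asserted by this file): the encoder takes noised
# pose batches of shape (B, opt.max_motion_length, dim_pose) together with
# diffusion timesteps and text conditioning; DDPMTrainer wraps it in the
# denoising-diffusion training objective.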
if __name__ == '__main__':
    parser = TrainCompOptions()
    opt = parser.parse()
    rank, world_size = get_dist_info()

    print(f"setting random seed to {opt.seed}")
    set_random_seed(opt.seed)
    opt.device = torch.device("cuda")
    # Anomaly detection helps localize NaN/inf gradients but slows training;
    # disable it for full runs.
    torch.autograd.set_detect_anomaly(True)
    print(f"device id: {torch.cuda.current_device()}")
    print(f"selected device ids: {opt.gpu_id}")

    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, 'meta')
    opt.noise_dir = pjoin(opt.save_root, 'noise')
    # Only rank 0 creates the output directories; other ranks wait at the
    # barrier so the directories exist before anyone writes to them.
    if rank == 0:
        os.makedirs(opt.model_dir, exist_ok=True)
        os.makedirs(opt.meta_dir, exist_ok=True)
        os.makedirs(opt.noise_dir, exist_ok=True)
    if world_size > 1:
        dist.barrier()
    if opt.use_wandb:
        wandb_id = wandb.util.generate_id()
        wandb.init(
            project="text2motion",
            name=opt.experiment_name,
            entity=opt.wandb_user,
            # notes=opt.EXPERIMENT_NOTE,
            config=opt,
            id=wandb_id,
            resume="allow",
            # monitor_gym=True,
            sync_tensorboard=True,
        )
        # opt.wandb = wandb
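        # Once the run is initialized, metrics can be streamed to it during
        # training, e.g. wandb.log({"train/loss": loss}); the actual keys
        # logged are up to DDPMTrainer (assumption, not shown in this file).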
    if opt.dataset_name == 't2m':
        opt.data_root = './data/HumanML3D'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 22
        radius = 4
        fps = 20
        opt.max_motion_length = 196
        dim_pose = 263
        kinematic_chain = paramUtil.t2m_kinematic_chain
    elif opt.dataset_name == 'grab':
        opt.data_root = './data/GRAB'
        opt.motion_dir = pjoin(opt.data_root, 'joints')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.face_text_dir = pjoin(opt.data_root, 'face_texts')
        opt.joints_num = 72  # TODO (elmc): verify; not used for training yet
        # radius = 4  # TODO (elmc): verify; visualization-only
        # fps = 20    # TODO (elmc): verify; visualization-only
        # dim_pose = 212: betas (body shape) and face shape are dropped from
        # the motion data (via to_smplx_params & smplx_dict_to_array)
        dim_pose = 212
        opt.dim_pose = dim_pose
        opt.max_motion_length = 196  # TODO (elmc): verify; set dynamically?
        # TODO (elmc): kinematic chain is believed to be visualization-only;
        # decide between paramUtil.t2m_kinematic_chain and
        # paramUtil.grab_kinematic_chain
        print(f"loading data root: {opt.data_root}")
    elif opt.dataset_name == 'kit':
        opt.data_root = './data/KIT-ML'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 21
        radius = 240 * 8
        fps = 12.5
        dim_pose = 251
        opt.max_motion_length = 196
        kinematic_chain = paramUtil.kit_kinematic_chain
    else:
        raise KeyError(f'Dataset "{opt.dataset_name}" does not exist')
    # TODO (elmc): check dim_word and add back in?
    # dim_word = 300
    mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
    std = np.load(pjoin(opt.data_root, 'Std.npy'))
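    # Mean.npy / Std.npy hold per-dimension statistics over the training set.
    # Assumption (standard in this pipeline): the dataset normalizes motions
    # as (motion - mean) / std, and generated samples are de-normalized with
    # sample * std + mean before visualization/evaluation.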
    train_split_file = pjoin(opt.data_root, 'train.txt')
    print(f"cwd is {os.getcwd()}")
    print(f"train_split_file: {train_split_file}")
    encoder = build_models(opt, dim_pose)
    # Pick the parallelism wrapper: DDP for multi-process runs, DataParallel
    # for a single process driving several GPUs, otherwise a bare CUDA model.
    if world_size > 1:
        encoder = MMDistributedDataParallel(
            encoder.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=True)
    elif opt.data_parallel:
        encoder = MMDataParallel(
            encoder.cuda(opt.gpu_id[0]), device_ids=opt.gpu_id)
    else:
        encoder = encoder.cuda()
    trainer = DDPMTrainer(opt, encoder)
    train_dataset = Text2MotionDataset(opt, mean, std, train_split_file, opt.times)
    print("loaded data, now training")
    trainer.train(train_dataset)