# Author: Elle McFarlane
# Last commit: 15d6c34 ("add gitignore etc")
import os
# make the current working directory importable (expects launch from repo root)
import sys
from os.path import join as pjoin
sys.path.append(os.getcwd())
import torch
import torch.distributed as dist
import utils.paramUtil as paramUtil
import wandb
from datasets import Text2MotionDataset
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist
from models import MotionTransformer
from options.train_options import TrainCompOptions
from trainers import DDPMTrainer
from utils.plot_script import *
from utils.utils import *
def build_models(opt, dim_pose):
    """Instantiate the MotionTransformer denoiser from parsed options.

    Args:
        opt: parsed training options; must expose ``max_motion_length``,
            ``num_layers``, ``latent_dim``, ``no_clip`` and ``no_eff``.
        dim_pose: per-frame pose feature dimension fed to the model.

    Returns:
        A freshly constructed (CPU-resident) MotionTransformer.
    """
    model_config = {
        "input_feats": dim_pose,
        "num_frames": opt.max_motion_length,
        "num_layers": opt.num_layers,
        "latent_dim": opt.latent_dim,
        "no_clip": opt.no_clip,
        "no_eff": opt.no_eff,
    }
    return MotionTransformer(**model_config)
if __name__ == '__main__':
    # Parse CLI options and discover the distributed context (rank 0 is the
    # coordinator; world_size == 1 means single-process training).
    parser = TrainCompOptions()
    opt = parser.parse()
    rank, world_size = get_dist_info()

    print(f"setting random seed to {opt.seed}")
    set_random_seed(opt.seed)

    opt.device = torch.device("cuda")
    # NOTE(review): anomaly detection slows training significantly; it is
    # presumably left on here for debugging NaNs — confirm before long runs.
    torch.autograd.set_detect_anomaly(True)
    print(f"device id: {torch.cuda.current_device()}")
    print(f"selected device ids: {opt.gpu_id}")

    # Output layout: <checkpoints_dir>/<dataset_name>/<name>/{model,meta,noise}
    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, 'meta')
    opt.noise_dir = pjoin(opt.save_root, 'noise')

    # Only rank 0 creates the directories; all other ranks wait at the barrier
    # so they never touch the paths before they exist.
    if rank == 0:
        os.makedirs(opt.model_dir, exist_ok=True)
        os.makedirs(opt.meta_dir, exist_ok=True)
        os.makedirs(opt.noise_dir, exist_ok=True)
    if world_size > 1:
        dist.barrier()

    if opt.use_wandb:
        # Fresh run id with resume="allow" lets a restarted job reattach.
        wandb_id = wandb.util.generate_id()
        wandb.init(
            project="text2motion",
            name=f"{opt.experiment_name}",
            entity=opt.wandb_user,
            # notes=opt.EXPERIMENT_NOTE,
            config=opt,
            id=wandb_id,
            resume="allow",
            sync_tensorboard=True,
        )

    # Per-dataset configuration: data paths, joint count, pose feature dim,
    # and (for some datasets) visualization parameters.
    if opt.dataset_name == 't2m':
        opt.data_root = './data/HumanML3D'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 22
        radius = 4          # visualization extent
        fps = 20            # playback rate for rendered clips
        opt.max_motion_length = 196
        dim_pose = 263
        kinematic_chain = paramUtil.t2m_kinematic_chain
    elif opt.dataset_name == 'grab':
        opt.data_root = './data/GRAB'
        opt.motion_dir = pjoin(opt.data_root, 'joints')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.face_text_dir = pjoin(opt.data_root, 'face_texts')
        opt.joints_num = 72  # TODO (elmc): verify this BUT ALSO I'M NOT USING IT FOR NOW!
        # NOTE(review): unlike the other branches, this one defines neither
        # `radius`, `fps` nor `kinematic_chain` — they appear to be
        # visualization-only and unused by training, but confirm before any
        # code below this branch starts reading them.
        # radius = 4 # TODO (elmc): verify this, think it's only for visualization purposes
        # fps = 20 # TODO (elmc): verify this, also for visualization I think
        dim_pose = 212  # drop betas (body shape) and face-shape from Motion data (via to_smplx_params & smplx_dict_to_array method)
        opt.dim_pose = dim_pose
        opt.max_motion_length = 196  # TODO (elmc): verify this; do this dynamically..??
        # TODO (elmc): verify what this does and if we can use the t2m one
        # kinematic_chain = paramUtil.grab_kinematic_chain
        print(f"loading data root: {opt.data_root}")
    elif opt.dataset_name == 'kit':
        opt.data_root = './data/KIT-ML'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 21
        radius = 240 * 8
        fps = 12.5
        dim_pose = 251
        opt.max_motion_length = 196
        kinematic_chain = paramUtil.kit_kinematic_chain
    else:
        raise KeyError('Dataset Does Not Exist')

    # TODO (elmc): check dim_word and add back in???
    # dim_word = 300
    # Dataset-wide normalization statistics computed offline.
    mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
    std = np.load(pjoin(opt.data_root, 'Std.npy'))
    train_split_file = pjoin(opt.data_root, 'train.txt')
    print(f"cwd is {os.getcwd()}")
    print(f"train_split_file: {train_split_file}")

    # Build the denoiser, then wrap it for the active parallelism mode:
    # DDP for multi-process, DataParallel for multi-GPU single-process,
    # plain .cuda() otherwise.
    encoder = build_models(opt, dim_pose)
    if world_size > 1:
        encoder = MMDistributedDataParallel(
            encoder.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=True)
    elif opt.data_parallel:
        encoder = MMDataParallel(
            encoder.cuda(opt.gpu_id[0]), device_ids=opt.gpu_id)
    else:
        encoder = encoder.cuda()

    trainer = DDPMTrainer(opt, encoder)
    train_dataset = Text2MotionDataset(opt, mean, std, train_split_file, opt.times)
    print("loaded data, now training")
    trainer.train(train_dataset)