Spaces:
Runtime error
Runtime error
import os
from os.path import join as pjoin

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist

# The wildcard import stays first among the project imports so the explicit
# project names imported below still take precedence over anything it exports.
from utils.plot_script import *
import utils.paramUtil as paramUtil
from options.train_options import TrainCompOptions
from models import MotionTransformer
from trainers import DDPMTrainer
from datasets import Text2MotionDataset
def build_models(opt, dim_pose):
    """Construct the MotionTransformer used as the training network.

    Args:
        opt: Options object; this function reads ``max_motion_length``,
            ``num_layers``, ``latent_dim``, ``no_clip`` and ``no_eff``.
        dim_pose: Value forwarded as the model's ``input_feats``.

    Returns:
        The newly created ``MotionTransformer`` instance.
    """
    transformer_kwargs = {
        'input_feats': dim_pose,
        'num_frames': opt.max_motion_length,
        'num_layers': opt.num_layers,
        'latent_dim': opt.latent_dim,
        'no_clip': opt.no_clip,
        'no_eff': opt.no_eff,
    }
    return MotionTransformer(**transformer_kwargs)
if __name__ == '__main__': | |
parser = TrainCompOptions() | |
opt = parser.parse() | |
rank, world_size = get_dist_info() | |
opt.device = torch.device("cuda") | |
torch.autograd.set_detect_anomaly(True) | |
opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) | |
opt.model_dir = pjoin(opt.save_root, 'model') | |
opt.meta_dir = pjoin(opt.save_root, 'meta') | |
if rank == 0: | |
os.makedirs(opt.model_dir, exist_ok=True) | |
os.makedirs(opt.meta_dir, exist_ok=True) | |
if world_size > 1: | |
dist.barrier() | |
if opt.dataset_name == 't2m': | |
opt.data_root = './data/HumanML3D' | |
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') | |
opt.text_dir = pjoin(opt.data_root, 'texts') | |
opt.joints_num = 22 | |
radius = 4 | |
fps = 20 | |
opt.max_motion_length = 196 | |
dim_pose = 263 | |
kinematic_chain = paramUtil.t2m_kinematic_chain | |
elif opt.dataset_name == 'kit': | |
opt.data_root = './data/KIT-ML' | |
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') | |
opt.text_dir = pjoin(opt.data_root, 'texts') | |
opt.joints_num = 21 | |
radius = 240 * 8 | |
fps = 12.5 | |
dim_pose = 251 | |
opt.max_motion_length = 196 | |
kinematic_chain = paramUtil.kit_kinematic_chain | |
else: | |
raise KeyError('Dataset Does Not Exist') | |
dim_word = 300 | |
mean = np.load(pjoin(opt.data_root, 'Mean.npy')) | |
std = np.load(pjoin(opt.data_root, 'Std.npy')) | |
train_split_file = pjoin(opt.data_root, 'train.txt') | |
encoder = build_models(opt, dim_pose) | |
if world_size > 1: | |
encoder = MMDistributedDataParallel( | |
encoder.cuda(), | |
device_ids=[torch.cuda.current_device()], | |
broadcast_buffers=False, | |
find_unused_parameters=True) | |
else: | |
encoder = encoder.cuda() | |
trainer = DDPMTrainer(opt, encoder) | |
train_dataset = Text2MotionDataset(opt, mean, std, train_split_file, opt.times) | |
trainer.train(train_dataset) | |