Spaces:
Runtime error
Runtime error
File size: 2,722 Bytes
12deb01 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
from os.path import join as pjoin
import utils.paramUtil as paramUtil
from options.train_options import TrainCompOptions
from utils.plot_script import *
from models import MotionTransformer
from trainers import DDPMTrainer
from datasets import Text2MotionDataset
from mmcv.runner import get_dist_info, init_dist
from mmcv.parallel import MMDistributedDataParallel
import torch
import torch.distributed as dist
def build_models(opt, dim_pose):
    """Build the MotionTransformer network that the trainer will optimise.

    Args:
        opt: Parsed training options. The fields read here are
            ``max_motion_length``, ``num_layers``, ``latent_dim``,
            ``no_clip`` and ``no_eff``.
        dim_pose: Size of the per-frame pose feature vector. It is passed
            as ``input_feats``.

    Returns:
        The newly constructed ``MotionTransformer`` instance. It is not yet
        wrapped for data parallelism.
    """
    transformer_kwargs = dict(
        input_feats=dim_pose,
        num_frames=opt.max_motion_length,
        num_layers=opt.num_layers,
        latent_dim=opt.latent_dim,
        no_clip=opt.no_clip,
        no_eff=opt.no_eff,
    )
    return MotionTransformer(**transformer_kwargs)
if __name__ == '__main__':
parser = TrainCompOptions()
opt = parser.parse()
rank, world_size = get_dist_info()
opt.device = torch.device("cuda")
torch.autograd.set_detect_anomaly(True)
opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
opt.model_dir = pjoin(opt.save_root, 'model')
opt.meta_dir = pjoin(opt.save_root, 'meta')
if rank == 0:
os.makedirs(opt.model_dir, exist_ok=True)
os.makedirs(opt.meta_dir, exist_ok=True)
if world_size > 1:
dist.barrier()
if opt.dataset_name == 't2m':
opt.data_root = './data/HumanML3D'
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
opt.text_dir = pjoin(opt.data_root, 'texts')
opt.joints_num = 22
radius = 4
fps = 20
opt.max_motion_length = 196
dim_pose = 263
kinematic_chain = paramUtil.t2m_kinematic_chain
elif opt.dataset_name == 'kit':
opt.data_root = './data/KIT-ML'
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
opt.text_dir = pjoin(opt.data_root, 'texts')
opt.joints_num = 21
radius = 240 * 8
fps = 12.5
dim_pose = 251
opt.max_motion_length = 196
kinematic_chain = paramUtil.kit_kinematic_chain
else:
raise KeyError('Dataset Does Not Exist')
dim_word = 300
mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
std = np.load(pjoin(opt.data_root, 'Std.npy'))
train_split_file = pjoin(opt.data_root, 'train.txt')
encoder = build_models(opt, dim_pose)
if world_size > 1:
encoder = MMDistributedDataParallel(
encoder.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=True)
else:
encoder = encoder.cuda()
trainer = DDPMTrainer(opt, encoder)
train_dataset = Text2MotionDataset(opt, mean, std, train_split_file, opt.times)
trainer.train(train_dataset)
|