import numpy as np
import torch
from os.path import join as pjoin
from .humanml.utils.word_vectorizer import WordVectorizer
from .humanml.scripts.motion_process import (process_file, recover_from_ric)
from . import BASEDataModule
from .humanml import (
    Text2MotionDatasetEval,
    Text2MotionDataset,
    Text2MotionDatasetCB,
    MotionDataset,
    MotionDatasetVQ,
    Text2MotionDatasetToken,
    Text2MotionDatasetM2T,
)
from .utils import humanml3d_collate


class HumanML3DDataModule(BASEDataModule):
    """Data module for the HumanML3D dataset.

    The constructor reads dataset paths and length limits from ``cfg``, loads
    the normalisation statistics from ``assets/meta`` and selects which
    dataset classes (``self.Dataset`` / ``self.DatasetEval``) to use based on
    ``cfg.TRAIN.STAGE``. The remaining methods convert motion features between
    normalised / raw / joint representations and toggle the multimodality
    ("MM") evaluation subset.
    """

    def __init__(self, cfg, **kwargs):
        """Configure the data module from ``cfg``.

        Args:
            cfg: Project configuration object. It is read (``DATASET``,
                ``TRAIN``, ``DEBUG``, ``model`` entries) and also written to:
                ``cfg.DATASET.JOINT_TYPE`` and ``cfg.DATASET.NFEATS`` are set
                here.
            **kwargs: Accepted for interface compatibility; not used by this
                constructor body.
        """
        super().__init__(collate_fn=humanml3d_collate)
        self.cfg = cfg
        # Must be called from __init__ itself so that the constructor
        # arguments are the ones captured.
        self.save_hyperparameters(logger=False)

        # Basic info of the dataset
        cfg.DATASET.JOINT_TYPE = 'humanml3d'
        self.name = "humanml3d"
        self.njoints = 22

        # Path to the dataset
        data_root = cfg.DATASET.HUMANML3D.ROOT
        self.hparams.data_root = data_root
        self.hparams.text_dir = pjoin(data_root, "texts")
        self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')

        # Mean and std of the dataset (paths are relative to the current
        # working directory)
        self.hparams.mean = np.load(pjoin('assets/meta', "mean.npy"))
        self.hparams.std = np.load(pjoin('assets/meta', "std.npy"))

        # Mean and std for fair evaluation
        self.hparams.mean_eval = np.load(pjoin('assets/meta', "mean_eval.npy"))
        self.hparams.std_eval = np.load(pjoin('assets/meta', "std_eval.npy"))

        # Length of the dataset
        self.hparams.max_motion_length = cfg.DATASET.HUMANML3D.MAX_MOTION_LEN
        self.hparams.min_motion_length = cfg.DATASET.HUMANML3D.MIN_MOTION_LEN
        self.hparams.max_text_len = cfg.DATASET.HUMANML3D.MAX_TEXT_LEN
        self.hparams.unit_length = cfg.DATASET.HUMANML3D.UNIT_LEN

        # Additional parameters
        self.hparams.debug = cfg.DEBUG
        self.hparams.stage = cfg.TRAIN.STAGE

        # Dataset switch: the evaluation dataset defaults to
        # Text2MotionDatasetEval and is only overridden by the "token" and
        # "m2t" stages.
        self.DatasetEval = Text2MotionDatasetEval
        if cfg.TRAIN.STAGE == "vae":
            # The last dotted component of the VAE target decides between the
            # VQ and the plain motion dataset.
            if cfg.model.params.motion_vae.target.split('.')[-1].lower() == "vqvae":
                self.hparams.win_size = 64
                self.Dataset = MotionDatasetVQ
            else:
                self.Dataset = MotionDataset
        elif 'lm' in cfg.TRAIN.STAGE:
            # Substring match: any stage name containing "lm" lands here.
            self.hparams.code_path = cfg.DATASET.CODE_PATH
            self.hparams.task_path = cfg.DATASET.TASK_PATH
            self.hparams.std_text = cfg.DATASET.HUMANML3D.STD_TEXT
            self.Dataset = Text2MotionDatasetCB
        elif cfg.TRAIN.STAGE == "token":
            self.Dataset = Text2MotionDatasetToken
            self.DatasetEval = Text2MotionDatasetToken
        elif cfg.TRAIN.STAGE == "m2t":
            self.Dataset = Text2MotionDatasetM2T
            self.DatasetEval = Text2MotionDatasetM2T
        else:
            self.Dataset = Text2MotionDataset

        # Get additional info of the dataset
        self.nfeats = 263
        cfg.DATASET.NFEATS = self.nfeats

    def feats2joints(self, features):
        """De-normalise ``features`` and pass them to ``recover_from_ric``.

        Args:
            features: Tensor of normalised motion features. The statistics are
                converted to its dtype and device before use.

        Returns:
            The result of ``recover_from_ric(features, self.njoints)``
            (presumably joint positions; see that helper to confirm).
        """
        mean = torch.tensor(self.hparams.mean).to(features)
        std = torch.tensor(self.hparams.std).to(features)
        features = features * std + mean
        return recover_from_ric(features, self.njoints)

    def joints2feats(self, features):
        """Return the first element of ``process_file(features, self.njoints)``.

        NOTE(review): presumably ``features`` holds joint positions here
        despite the parameter name; confirm against ``process_file``.
        """
        features = process_file(features, self.njoints)[0]
        return features

    def normalize(self, features):
        """Return ``(features - mean) / std`` using the dataset statistics."""
        mean = torch.tensor(self.hparams.mean).to(features)
        std = torch.tensor(self.hparams.std).to(features)
        features = (features - mean) / std
        return features

    def denormalize(self, features):
        """Return ``features * std + mean`` using the dataset statistics."""
        mean = torch.tensor(self.hparams.mean).to(features)
        std = torch.tensor(self.hparams.std).to(features)
        features = features * std + mean
        return features

    def renorm4t2m(self, features):
        """Re-normalise features from the dataset statistics to the eval ones.

        The input is first de-normalised with ``mean`` / ``std`` and then
        normalised with ``mean_eval`` / ``std_eval``, so that the t2m
        evaluators can be used.
        """
        # renorm to t2m norms for using t2m evaluators
        ori_mean = torch.tensor(self.hparams.mean).to(features)
        ori_std = torch.tensor(self.hparams.std).to(features)
        eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
        eval_std = torch.tensor(self.hparams.std_eval).to(features)
        features = features * ori_std + ori_mean
        features = (features - eval_mean) / eval_std
        return features

    def mm_mode(self, mm_on=True):
        """Switch the test dataset between its full list and an MM subset.

        When enabling, the test dataset's current ``name_list`` is saved in
        ``self.name_list`` and replaced by ``cfg.METRIC.MM_NUM_SAMPLES`` names
        drawn without replacement (stored in ``self.mm_list``). When
        disabling, the saved list is put back.

        Enabling twice in a row re-samples from the saved full list instead of
        overwriting the snapshot with the subset. Disabling when nothing was
        ever saved only clears ``self.is_mm``.

        Args:
            mm_on: ``True`` to install the random subset, ``False`` to restore
                the saved list.

        Raises:
            ValueError: From ``np.random.choice`` when more samples are
                requested than names are available.
        """
        if mm_on:
            was_mm = getattr(self, "is_mm", False)
            self.is_mm = True
            current = self.test_dataset.name_list
            # Keep the earlier snapshot only when the dataset still holds the
            # exact subset object installed by a previous call; in every other
            # situation snapshot the current list, as before.
            subset_installed = (
                was_mm
                and hasattr(self, "name_list")
                and hasattr(self, "mm_list")
                and current is self.mm_list
            )
            if not subset_installed:
                self.name_list = current
            self.mm_list = np.random.choice(self.name_list,
                                            self.cfg.METRIC.MM_NUM_SAMPLES,
                                            replace=False)
            self.test_dataset.name_list = self.mm_list
        else:
            self.is_mm = False
            # Nothing to restore if MM mode was never enabled.
            if hasattr(self, "name_list"):
                self.test_dataset.name_list = self.name_list