diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b6d4892a191283bb46f391c868117c0155a94cdb --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +**/*.pyc +.idea/ +__pycache__/ + +deps/ +datasets/ +experiments_t2m/ +experiments_t2m_test/ +experiments_control/ +experiments_control_test/ +experiments_recons/ +experiments_recons_test/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f03fbb7cd30918f38ed7bca1b67b752db82d9359 --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +Copyright Tsinghua University and Shanghai AI Laboratory. All Rights Reserved. + +License for Non-commercial Scientific Research Purposes. + +For more information see . +If you use this software, please cite the corresponding publications +listed on the above website. + +Permission to use, copy, modify, and distribute this software and its +documentation for educational, research, and non-profit purposes only. +Any modification based on this work must be open-source and prohibited +for commercial, pornographic, military, or surveillance use. + +The authors grant you a non-exclusive, worldwide, non-transferable, +non-sublicensable, revocable, royalty-free, and limited license under +our copyright interests to reproduce, distribute, and create derivative +works of the text, videos, and codes solely for your non-commercial +research purposes. + +You must retain, in the source form of any derivative works that you +distribute, all copyright, patent, trademark, and attribution notices +from the source form of this work. + +For commercial uses of this software, please send email to all people +in the author list. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bd2d5a14490d3fd4cfd5dcb9d784a52706174133 --- /dev/null +++ b/README.md @@ -0,0 +1,11 @@ +--- +title: MotionLCM +emoji: 🏎️💨 +colorFrom: yellow +colorTo: pink +sdk: gradio +sdk_version: 4.44.1 +app_file: app.py +pinned: false +python_version: 3.10.12 +--- \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..be98321d516bd1c0b3bc11ae7d692ba442ef3857 --- /dev/null +++ b/app.py @@ -0,0 +1,234 @@ +import os +import time +import random +import datetime +import os.path as osp +from functools import partial + +import tqdm +from omegaconf import OmegaConf + +import torch +import gradio as gr + +from mld.config import get_module_config +from mld.data.get_data import get_dataset +from mld.models.modeltype.mld import MLD +from mld.utils.utils import set_seed +from mld.data.humanml.utils.plot_script import plot_3d_motion + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +WEBSITE = """ +
+<div style="text-align: center;">
+<h2>MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model</h2>
+<p>
+Wenxun Dai<sup>1</sup> &emsp;
+Ling-Hao Chen<sup>1</sup> &emsp;
+Jingbo Wang<sup>2</sup> &emsp;
+Jinpeng Liu<sup>1</sup> &emsp;
+Bo Dai<sup>2</sup> &emsp;
+Yansong Tang<sup>1</sup>
+</p>
+<p>
+<sup>1</sup>Tsinghua University &emsp;
+<sup>2</sup>Shanghai AI Laboratory
+</p>
+</div>
+"""
+
+WEBSITE_bottom = """
+<div style="text-align: center;">
+<p>Space adapted from TMR and MoMask.</p>
+</div>
+""" + +EXAMPLES = [ + "a person does a jump", + "a person waves both arms in the air.", + "The person takes 4 steps backwards.", + "this person bends forward as if to bow.", + "The person was pushed but did not fall.", + "a man walks forward in a snake like pattern.", + "a man paces back and forth along the same line.", + "with arms out to the sides a person walks forward", + "A man bends down and picks something up with his right hand.", + "The man walked forward, spun right on one foot and walked back to his original position.", + "a person slightly bent over with right hand pressing against the air walks forward slowly" +] + +if not os.path.exists("./experiments_t2m/"): + os.system("bash prepare/download_pretrained_models.sh") +if not os.path.exists('./deps/glove/'): + os.system("bash prepare/download_glove.sh") +if not os.path.exists('./deps/sentence-t5-large/'): + os.system("bash prepare/prepare_t5.sh") +if not os.path.exists('./deps/t2m/'): + os.system("bash prepare/download_t2m_evaluators.sh") +if not os.path.exists('./datasets/humanml3d/'): + os.system("bash prepare/prepare_tiny_humanml3d.sh") + +DEFAULT_TEXT = "A person is " +MAX_VIDEOS = 8 +NUM_ROWS = 2 +NUM_COLS = MAX_VIDEOS // NUM_ROWS +EXAMPLES_PER_PAGE = 12 +T2M_CFG = "./configs_v1/motionlcm_t2m.yaml" + +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +print("device: ", device) + +cfg = OmegaConf.load(T2M_CFG) +cfg_root = os.path.dirname(T2M_CFG) +cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root) +cfg = OmegaConf.merge(cfg, cfg_model) +set_seed(cfg.SEED_VALUE) + +name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) +cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str) +vis_dir = osp.join(cfg.output_dir, 'samples') +os.makedirs(cfg.output_dir, exist_ok=False) +os.makedirs(vis_dir, exist_ok=False) + +state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"] +print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS)) + +is_lcm = False +lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG +if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim +print(f'Is LCM: {is_lcm}') + +dataset = get_dataset(cfg) +model = MLD(cfg, dataset) +model.to(device) +model.eval() +model.requires_grad_(False) +model.load_state_dict(state_dict) + +FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE") + + +@torch.no_grad() +def generate(text_, motion_len_): + batch = {"text": [text_] * MAX_VIDEOS, "length": [motion_len_] * MAX_VIDEOS} + + s = time.time() + joints = model(batch)[0] + runtime_infer = round(time.time() - s, 3) + + s = time.time() + path = [] + for i in tqdm.tqdm(range(len(joints))): + uid = random.randrange(999999999) + video_path = osp.join(vis_dir, f"sample_{uid}.mp4") + plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=FPS) + path.append(video_path) + runtime_draw = round(time.time() - s, 3) + + runtime_info = f'Inference {len(joints)} motions, Runtime (Inference): {runtime_infer}s, ' \ + f'Runtime (Draw Skeleton): {runtime_draw}s, device: {device} ' + + return path, runtime_info + + +def generate_component(generate_function, text_, motion_len_, num_inference_steps_, guidance_scale_): + if text_ == DEFAULT_TEXT or text_ == "" or text_ is None: + return [None] * MAX_VIDEOS + ["Please modify the default text prompt."] + + 
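+    # The UI sliders are applied directly to the loaded model: the scheduler step count and the
+    # classifier-free guidance scale are overridden per request, and the requested duration in
+    # seconds is converted to frames and clamped to the supported range [36, 196] (FPS = 20).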
model.cfg.model.scheduler.num_inference_steps = num_inference_steps_ + model.guidance_scale = guidance_scale_ + motion_len_ = max(36, min(int(float(motion_len_) * FPS), 196)) + paths, info = generate_function(text_, motion_len_) + paths = paths + [None] * (MAX_VIDEOS - len(paths)) + return paths + [info] + +theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray") +generate_and_show = partial(generate_component, generate) + +with gr.Blocks(theme=theme) as demo: + gr.HTML(WEBSITE) + videos = [] + + with gr.Row(): + with gr.Column(scale=3): + text = gr.Textbox( + show_label=True, + label="Text prompt", + value=DEFAULT_TEXT, + ) + + with gr.Row(): + with gr.Column(scale=1): + motion_len = gr.Slider( + minimum=1.8, + maximum=9.8, + step=0.2, + value=5.0, + label="Motion length", + info="Motion duration in seconds: [1.8s, 9.8s] (FPS = 20)." + ) + + with gr.Column(scale=1): + num_inference_steps = gr.Slider( + minimum=1, + maximum=4, + step=1, + value=1, + label="Inference steps", + info="Number of inference steps.", + ) + + cfg = gr.Slider( + minimum=1, + maximum=15, + step=0.5, + value=7.5, + label="CFG", + info="Classifier-free diffusion guidance.", + ) + + gen_btn = gr.Button("Generate", variant="primary") + clear = gr.Button("Clear", variant="secondary") + + results = gr.Textbox(show_label=True, + label='Inference info (runtime and device)', + info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.', + interactive=False) + + with gr.Column(scale=2): + examples = gr.Examples( + examples=EXAMPLES, + inputs=[text], + examples_per_page=EXAMPLES_PER_PAGE) + + for i in range(NUM_ROWS): + with gr.Row(): + for j in range(NUM_COLS): + video = gr.Video(autoplay=True, loop=True) + videos.append(video) + + # gr.HTML(WEBSITE_bottom) + + gen_btn.click( + fn=generate_and_show, + inputs=[text, motion_len, num_inference_steps, cfg], + outputs=videos+[results], + ) + text.submit( + fn=generate_and_show, + inputs=[text, motion_len, num_inference_steps, cfg], + outputs=videos+[results], + ) + + def clear_videos(): + return [None] * MAX_VIDEOS + [DEFAULT_TEXT] + [None] + + clear.click(fn=clear_videos, outputs=videos + [text] + [results]) + +demo.launch() diff --git a/configs/mld_t2m.yaml b/configs/mld_t2m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c78700d8cd29234f682cb27fb1b49244063f59f --- /dev/null +++ b/configs/mld_t2m.yaml @@ -0,0 +1,104 @@ +FOLDER: './experiments_t2m' +TEST_FOLDER: './experiments_t2m_test' + +NAME: 'mld_humanml' + +SEED_VALUE: 1234 + +TRAIN: + BATCH_SIZE: 64 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + + PRETRAINED: 'experiments_recons/vae_humanml/vae_humanml.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 3000 + learning_rate: 1e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + model_ema: false + model_ema_steps: 32 + model_ema_decay: 0.999 + +VAL: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + +TEST: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_t2m/mld_humanml/mld_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + DO_MM_TEST: true + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' 
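+  # GloVe word-embedding directory consumed by WordVectorizer
+  # (the key is spelled "VERTILIZER" consistently here and in mld/data/get_data.py).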
+ WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + CONTROL_ARGS: + CONTROL: false + TEMPORAL: false + TRAIN_JOINTS: [0] + TEST_JOINTS: [0] + TRAIN_DENSITY: 'random' + TEST_DENSITY: 100 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: false + WINDOW_SIZE: null + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics'] + +model: + target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_ddim', 'noise_optimizer'] + latent_dim: [16, 32] + guidance_scale: 7.5 + guidance_uncondp: 0.1 + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/modules/denoiser.yaml b/configs/modules/denoiser.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35c3afcae3d24f69a13651e67264c6a731f80c42 --- /dev/null +++ b/configs/modules/denoiser.yaml @@ -0,0 +1,28 @@ +denoiser: + target: mld.models.architectures.mld_denoiser.MldDenoiser + params: + latent_dim: ${model.latent_dim} + hidden_dim: 256 + text_dim: 768 + time_dim: 768 + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + norm_eps: 1e-5 + activation: 'gelu' + norm_post: true + activation_post: null + flip_sin_to_cos: true + freq_shift: 0 + time_act_fn: 'silu' + time_post_act_fn: null + position_embedding: 'learned' + arch: 'trans_enc' + add_mem_pos: true + force_pre_post_proj: true + text_act_fn: null + zero_init_cond: true + controlnet_embed_dim: 256 + controlnet_act_fn: 'silu' diff --git a/configs/modules/motion_vae.yaml b/configs/modules/motion_vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b682c1bb1b76b22f31e1ca03ee061ed6b92c8d40 --- /dev/null +++ b/configs/modules/motion_vae.yaml @@ -0,0 +1,18 @@ +motion_vae: + target: mld.models.architectures.mld_vae.MldVae + params: + nfeats: ${DATASET.NFEATS} + latent_dim: ${model.latent_dim} + hidden_dim: 256 + force_pre_post_proj: true + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + arch: 'encoder_decoder' + normalize_before: false + norm_eps: 1e-5 + activation: 'gelu' + norm_post: true + activation_post: null + position_embedding: 'learned' diff --git a/configs/modules/noise_optimizer.yaml b/configs/modules/noise_optimizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45e1a29ba13abc196ca272bbd771c1a7783843de --- /dev/null +++ b/configs/modules/noise_optimizer.yaml @@ -0,0 +1,15 @@ +noise_optimizer: + target: mld.models.architectures.dno.DNO + params: + optimize: false + max_train_steps: 400 + learning_rate: 0.1 + lr_scheduler: 'cosine' + lr_warmup_steps: 50 + clip_grad: true + loss_hint_type: 'l2' + loss_diff_penalty: 0.000 + loss_correlate_penalty: 100 + visualize_samples: 0 + visualize_ske_steps: [] + output_dir: ${output_dir} diff --git a/configs/modules/scheduler_ddim.yaml b/configs/modules/scheduler_ddim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce1755c23a96d0bdd61250c7a19cfb3f8e15fcda --- /dev/null +++ b/configs/modules/scheduler_ddim.yaml @@ -0,0 +1,14 @@ +scheduler: + target: diffusers.DDIMScheduler + num_inference_steps: 50 + eta: 0.0 + params: + 
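+    # Standard Stable-Diffusion-style noise schedule (1000 training timesteps, scaled-linear
+    # betas); eta = 0.0 above makes the 50-step DDIM sampling deterministic.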
num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' + prediction_type: 'epsilon' + clip_sample: false + # below are for ddim + set_alpha_to_one: false + steps_offset: 1 diff --git a/configs/modules/scheduler_lcm.yaml b/configs/modules/scheduler_lcm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f08995b88082eb610d073b1e3e4913f8435ff5b4 --- /dev/null +++ b/configs/modules/scheduler_lcm.yaml @@ -0,0 +1,19 @@ +scheduler: + target: mld.models.schedulers.scheduling_lcm.LCMScheduler + num_inference_steps: 1 + cfg_step_map: + 1: 8.0 + 2: 12.5 + 4: 13.5 + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' + clip_sample: false + set_alpha_to_one: false + original_inference_steps: 10 + timesteps_step_map: + 1: [799] + 2: [699, 299] + 4: [699, 399, 299, 299] diff --git a/configs/modules/text_encoder.yaml b/configs/modules/text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2cbaaaaae3384471a9d2a2da36d60af5e031489c --- /dev/null +++ b/configs/modules/text_encoder.yaml @@ -0,0 +1,5 @@ +text_encoder: + target: mld.models.architectures.mld_clip.MldTextEncoder + params: + last_hidden_state: false + modelpath: ${model.t5_path} diff --git a/configs/modules/traj_encoder.yaml b/configs/modules/traj_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f3bea115c80786c52fb6969987a9d28ec2e205fa --- /dev/null +++ b/configs/modules/traj_encoder.yaml @@ -0,0 +1,17 @@ +traj_encoder: + target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder + params: + nfeats: ${DATASET.NJOINTS} + latent_dim: ${model.latent_dim} + hidden_dim: 256 + force_post_proj: true + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + norm_eps: 1e-5 + activation: 'gelu' + norm_post: true + activation_post: null + position_embedding: 'learned' diff --git a/configs/motionlcm_control_s.yaml b/configs/motionlcm_control_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a8a58a29aac64e4b88e240a9a8996ed89f3557b --- /dev/null +++ b/configs/motionlcm_control_s.yaml @@ -0,0 +1,113 @@ +FOLDER: './experiments_control/spatial' +TEST_FOLDER: './experiments_control_test/spatial' + +NAME: 'motionlcm_humanml' + +SEED_VALUE: 1234 + +TRAIN: + DATASET: 'humanml3d' + BATCH_SIZE: 128 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + + PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 1e-4 + learning_rate_spatial: 1e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + +VAL: + DATASET: 'humanml3d' + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + +TEST: + DATASET: 'humanml3d' + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_pelvis.ckpt' + # CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_all.ckpt' + + # Testing Args + REPLICATION_TIMES: 1 + DIVERSITY_TIMES: 300 + DO_MM_TEST: false + MAX_NUM_SAMPLES: 1024 + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + 
UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + CONTROL_ARGS: + CONTROL: true + TEMPORAL: false + TRAIN_JOINTS: [0] + TEST_JOINTS: [0] + TRAIN_DENSITY: 'random' + TEST_DENSITY: 100 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: false + WINDOW_SIZE: null + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics', 'ControlMetrics'] + +model: + target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer'] + latent_dim: [16, 32] + guidance_scale: 'dynamic' + + # ControlNet Args + is_controlnet: true + vaeloss: true + vaeloss_type: 'mask' + cond_ratio: 1.0 + control_loss_func: 'l1_smooth' + use_3d: true + lcm_w_min_nax: [5, 15] + lcm_num_ddim_timesteps: 10 + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/motionlcm_control_t.yaml b/configs/motionlcm_control_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34c541e95020b87b470f1c0bc09bc6851cc1f2e2 --- /dev/null +++ b/configs/motionlcm_control_t.yaml @@ -0,0 +1,111 @@ +FOLDER: './experiments_control/temporal' +TEST_FOLDER: './experiments_control_test/temporal' + +NAME: 'motionlcm_humanml' + +SEED_VALUE: 1234 + +TRAIN: + DATASET: 'humanml3d' + BATCH_SIZE: 128 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + + PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 1e-4 + learning_rate_spatial: 1e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + +VAL: + DATASET: 'humanml3d' + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + +TEST: + DATASET: 'humanml3d' + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + DIVERSITY_TIMES: 300 + DO_MM_TEST: false + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + CONTROL_ARGS: + CONTROL: true + TEMPORAL: true + TRAIN_JOINTS: [0, 10, 11, 15, 20, 21] + TEST_JOINTS: [0, 10, 11, 15, 20, 21] + TRAIN_DENSITY: [25, 25] + TEST_DENSITY: 25 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: false + WINDOW_SIZE: null + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics', 'ControlMetrics'] + +model: + target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer'] + latent_dim: [16, 32] + guidance_scale: 'dynamic' + + # ControlNet Args + is_controlnet: true + vaeloss: true + vaeloss_type: 'sum' + cond_ratio: 1.0 + control_loss_func: 'l2' + use_3d: false + lcm_w_min_nax: [5, 15] + lcm_num_ddim_timesteps: 10 + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + 
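+  # The text/motion encoders above and below are the pretrained T2M evaluator networks loaded
+  # from ./deps/t2m; they serve the TM2T/Control metrics, not the generator itself.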
t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/motionlcm_t2m.yaml b/configs/motionlcm_t2m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..23da7d341e03ba04f4ca2dc65cc63fe8d1ed28e0 --- /dev/null +++ b/configs/motionlcm_t2m.yaml @@ -0,0 +1,109 @@ +FOLDER: './experiments_t2m' +TEST_FOLDER: './experiments_t2m_test' + +NAME: 'motionlcm_humanml' + +SEED_VALUE: 1234 + +TRAIN: + BATCH_SIZE: 128 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + + PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 2e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + + # Latent Consistency Distillation Specific Arguments + w_min: 5.0 + w_max: 15.0 + num_ddim_timesteps: 10 + loss_type: 'huber' + huber_c: 0.5 + unet_time_cond_proj_dim: 256 + ema_decay: 0.95 + +VAL: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + +TEST: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + DO_MM_TEST: true + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + CONTROL_ARGS: + CONTROL: false + TEMPORAL: false + TRAIN_JOINTS: [0] + TEST_JOINTS: [0] + TRAIN_DENSITY: 'random' + TEST_DENSITY: 100 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: false + WINDOW_SIZE: null + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics'] + +model: + target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer'] + latent_dim: [16, 32] + guidance_scale: 'dynamic' + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/motionlcm_t2m_clt.yaml b/configs/motionlcm_t2m_clt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d168b5062854eb1d40f0422651a6dde39e033732 --- /dev/null +++ b/configs/motionlcm_t2m_clt.yaml @@ -0,0 +1,69 @@ +FOLDER: './experiments_t2m' +TEST_FOLDER: './experiments_t2m_test' + +NAME: 'motionlcm_humanml' + +SEED_VALUE: 1234 + +TEST: + BATCH_SIZE: 1 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 1 + DIVERSITY_TIMES: 300 + DO_MM_TEST: false + MAX_NUM_SAMPLES: 1024 + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: 
'./datasets/humanml3d' + CONTROL_ARGS: + CONTROL: true + TEMPORAL: false + TRAIN_JOINTS: [0] + TEST_JOINTS: [0] + TRAIN_DENSITY: 'random' + TEST_DENSITY: 100 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: false + WINDOW_SIZE: null + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics', 'ControlMetrics'] + +model: + target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer'] + latent_dim: [16, 32] + guidance_scale: 'dynamic' + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/vae.yaml b/configs/vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fb80dcb23bee921eee82d089ee698d1bdedaceef --- /dev/null +++ b/configs/vae.yaml @@ -0,0 +1,103 @@ +FOLDER: './experiments_recons' +TEST_FOLDER: './experiments_recons_test' + +NAME: 'vae_humanml' + +SEED_VALUE: 1234 + +TRAIN: + BATCH_SIZE: 128 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + PRETRAINED: '' + + validation_steps: -1 + validation_epochs: 100 + checkpointing_steps: -1 + checkpointing_epochs: 100 + max_train_steps: -1 + max_train_epochs: 6000 + learning_rate: 2e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + +VAL: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + +TEST: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_recons/vae_humanml/vae_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + DIVERSITY_TIMES: 300 + DO_MM_TEST: false + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + CONTROL_ARGS: + CONTROL: false + TEMPORAL: false + TRAIN_JOINTS: [0] + TEST_JOINTS: [0] + TRAIN_DENSITY: 'random' + TEST_DESITY: 100 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: true + WINDOW_SIZE: 64 + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics', "PosMetrics"] + +model: + target: ['motion_vae'] + latent_dim: [16, 32] + + # VAE Args + rec_feats_ratio: 1.0 + rec_joints_ratio: 1.0 + rec_velocity_ratio: 0.0 + kl_ratio: 1e-4 + + rec_feats_loss: 'l1_smooth' + rec_joints_loss: 'l1_smooth' + rec_velocity_loss: 'l1_smooth' + mask_loss: true + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + t2m_path: './deps/t2m/' diff --git a/configs_v1/modules/denoiser.yaml b/configs_v1/modules/denoiser.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25825fef0fc1f756ab9f6a79e2b01f4f01ee8e2f --- /dev/null +++ b/configs_v1/modules/denoiser.yaml @@ -0,0 +1,28 @@ +denoiser: + target: mld.models.architectures.mld_denoiser.MldDenoiser + params: + latent_dim: ${model.latent_dim} + hidden_dim: null + text_dim: 768 + time_dim: 768 + ff_size: 1024 + num_layers: 9 + num_heads: 4 + 
dropout: 0.1 + normalize_before: false + norm_eps: 1e-5 + activation: 'gelu' + norm_post: true + activation_post: null + flip_sin_to_cos: true + freq_shift: 0 + time_act_fn: 'silu' + time_post_act_fn: null + position_embedding: 'learned' + arch: 'trans_enc' + add_mem_pos: true + force_pre_post_proj: false + text_act_fn: 'relu' + zero_init_cond: true + controlnet_embed_dim: 256 + controlnet_act_fn: null diff --git a/configs_v1/modules/motion_vae.yaml b/configs_v1/modules/motion_vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4682b573f33267d0c7e73ec08c3328ed4eb296a --- /dev/null +++ b/configs_v1/modules/motion_vae.yaml @@ -0,0 +1,18 @@ +motion_vae: + target: mld.models.architectures.mld_vae.MldVae + params: + nfeats: ${DATASET.NFEATS} + latent_dim: ${model.latent_dim} + hidden_dim: null + force_pre_post_proj: false + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + arch: 'encoder_decoder' + normalize_before: false + norm_eps: 1e-5 + activation: 'gelu' + norm_post: true + activation_post: null + position_embedding: 'learned' diff --git a/configs_v1/modules/scheduler_lcm.yaml b/configs_v1/modules/scheduler_lcm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1102aa75cad465efcc7458d57486d26a93be0ad4 --- /dev/null +++ b/configs_v1/modules/scheduler_lcm.yaml @@ -0,0 +1,11 @@ +scheduler: + target: diffusers.LCMScheduler + num_inference_steps: 1 + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' + clip_sample: false + set_alpha_to_one: false + original_inference_steps: 50 diff --git a/configs_v1/modules/text_encoder.yaml b/configs_v1/modules/text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2cbaaaaae3384471a9d2a2da36d60af5e031489c --- /dev/null +++ b/configs_v1/modules/text_encoder.yaml @@ -0,0 +1,5 @@ +text_encoder: + target: mld.models.architectures.mld_clip.MldTextEncoder + params: + last_hidden_state: false + modelpath: ${model.t5_path} diff --git a/configs_v1/modules/traj_encoder.yaml b/configs_v1/modules/traj_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a56dff8f54037ea22a85dc8bc7a08141636dc1b --- /dev/null +++ b/configs_v1/modules/traj_encoder.yaml @@ -0,0 +1,17 @@ +traj_encoder: + target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder + params: + nfeats: ${DATASET.NJOINTS} + latent_dim: ${model.latent_dim} + hidden_dim: null + force_post_proj: false + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + norm_eps: 1e-5 + activation: 'gelu' + norm_post: true + activation_post: null + position_embedding: 'learned' diff --git a/configs_v1/motionlcm_control_t.yaml b/configs_v1/motionlcm_control_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..23dcccecaebecef62c713a04dad79f73ce55eca9 --- /dev/null +++ b/configs_v1/motionlcm_control_t.yaml @@ -0,0 +1,114 @@ +FOLDER: './experiments_control/temporal' +TEST_FOLDER: './experiments_control_test/temporal' + +NAME: 'motionlcm_humanml' + +SEED_VALUE: 1234 + +TRAIN: + DATASET: 'humanml3d' + BATCH_SIZE: 128 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + + PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 1e-4 + learning_rate_spatial: 1e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 
1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + +VAL: + DATASET: 'humanml3d' + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + +TEST: + DATASET: 'humanml3d' + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t_v1.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + DO_MM_TEST: false + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + CONTROL_ARGS: + CONTROL: true + TEMPORAL: true + TRAIN_JOINTS: [0, 10, 11, 15, 20, 21] + TEST_JOINTS: [0, 10, 11, 15, 20, 21] + TRAIN_DENSITY: [25, 25] + TEST_DENSITY: 25 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: false + WINDOW_SIZE: null + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics', 'ControlMetrics'] + +model: + target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder'] + latent_dim: [1, 256] + guidance_scale: 7.5 + + # ControlNet Args + is_controlnet: true + vaeloss: true + vaeloss_type: 'sum' + cond_ratio: 1.0 + control_loss_func: 'l2' + use_3d: false + lcm_w_min_nax: null + lcm_num_ddim_timesteps: null + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs_v1/motionlcm_t2m.yaml b/configs_v1/motionlcm_t2m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d20a8993dec6d1cdd93ec70965a513dd98a72303 --- /dev/null +++ b/configs_v1/motionlcm_t2m.yaml @@ -0,0 +1,109 @@ +FOLDER: './experiments_t2m' +TEST_FOLDER: './experiments_t2m_test' + +NAME: 'motionlcm_humanml' + +SEED_VALUE: 1234 + +TRAIN: + BATCH_SIZE: 256 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + + PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml_v1.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 2e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + + # Latent Consistency Distillation Specific Arguments + w_min: 5.0 + w_max: 15.0 + num_ddim_timesteps: 50 + loss_type: 'huber' + huber_c: 0.001 + unet_time_cond_proj_dim: 256 + ema_decay: 0.95 + +VAL: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + +TEST: + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + PERSISTENT_WORKERS: true + + CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + DO_MM_TEST: true + +DATASET: + NAME: 'humanml3d' + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + CONTROL_ARGS: + CONTROL: false + TEMPORAL: false + 
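+      # Joint 0 is the pelvis (root) of the 22-joint HumanML3D skeleton; control is disabled
+      # in this text-to-motion config (CONTROL: false), so these joint/density settings are unused.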
TRAIN_JOINTS: [0] + TEST_JOINTS: [0] + TRAIN_DENSITY: 'random' + TEST_DENSITY: 100 + MEAN_STD_PATH: './datasets/humanml_spatial_norm' + SAMPLER: + MAX_LEN: 200 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + PADDING_TO_MAX: false + WINDOW_SIZE: null + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics'] + +model: + target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm'] + latent_dim: [1, 256] + guidance_scale: 7.5 + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0979ad6de8458babcb16c90d64b513332d3c71 --- /dev/null +++ b/demo.py @@ -0,0 +1,196 @@ +import os +import pickle +import sys +import datetime +import logging +import os.path as osp + +from omegaconf import OmegaConf + +import torch + +from mld.config import parse_args +from mld.data.get_data import get_dataset +from mld.models.modeltype.mld import MLD +from mld.models.modeltype.vae import VAE +from mld.utils.utils import set_seed, move_batch_to_device +from mld.data.humanml.utils.plot_script import plot_3d_motion +from mld.utils.temos_utils import remove_padding + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def load_example_hint_input(text_path: str) -> tuple: + with open(text_path, "r") as f: + lines = f.readlines() + + n_frames, control_type_ids, control_hint_ids = [], [], [] + for line in lines: + s = line.strip() + n_frame, control_type_id, control_hint_id = s.split(' ') + n_frames.append(int(n_frame)) + control_type_ids.append(int(control_type_id)) + control_hint_ids.append(int(control_hint_id)) + + return n_frames, control_type_ids, control_hint_ids + + +def load_example_input(text_path: str) -> tuple: + with open(text_path, "r") as f: + lines = f.readlines() + + texts, lens = [], [] + for line in lines: + s = line.strip() + s_l = s.split(" ")[0] + s_t = s[(len(s_l) + 1):] + lens.append(int(s_l)) + texts.append(s_t) + return texts, lens + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str) + vis_dir = osp.join(cfg.output_dir, 'samples') + os.makedirs(cfg.output_dir, exist_ok=False) + os.makedirs(vis_dir, exist_ok=False) + + steam_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log')) + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[steam_handler, file_handler]) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml')) + + state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"] + logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS)) + + # Step 1: Check if the checkpoint is VAE-based. + is_vae = False + vae_key = 'vae.skel_embedding.weight' + if vae_key in state_dict: + is_vae = True + logger.info(f'Is VAE: {is_vae}') + + # Step 2: Check if the checkpoint is MLD-based. 
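+    # (Steps 2-4 below key off parameter names unique to each variant: a diffusion time
+    # embedding implies MLD, a time-embedding cond_proj implies an LCM-distilled denoiser,
+    # and ControlNet conditioning weights imply a control checkpoint.)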
+ is_mld = False + mld_key = 'denoiser.time_embedding.linear_1.weight' + if mld_key in state_dict: + is_mld = True + logger.info(f'Is MLD: {is_mld}') + + # Step 3: Check if the checkpoint is LCM-based. + is_lcm = False + lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG + if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim + logger.info(f'Is LCM: {is_lcm}') + + # Step 4: Check if the checkpoint is Controlnet-based. + cn_key = "controlnet.controlnet_cond_embedding.0.weight" + is_controlnet = True if cn_key in state_dict else False + cfg.model.is_controlnet = is_controlnet + logger.info(f'Is Controlnet: {is_controlnet}') + + if is_mld or is_lcm or is_controlnet: + target_model_class = MLD + else: + target_model_class = VAE + + if cfg.optimize: + assert cfg.model.get('noise_optimizer') is not None + cfg.model.noise_optimizer.params.optimize = True + logger.info('Optimization enabled. Set the batch size to 1.') + logger.info(f'Original batch size: {cfg.TEST.BATCH_SIZE}') + cfg.TEST.BATCH_SIZE = 1 + + dataset = get_dataset(cfg) + model = target_model_class(cfg, dataset) + model.to(device) + model.eval() + model.requires_grad_(False) + logger.info(model.load_state_dict(state_dict)) + + FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE") + + if cfg.example is not None and not is_controlnet: + text, length = load_example_input(cfg.example) + for t, l in zip(text, length): + logger.info(f"{l}: {t}") + + batch = {"length": length, "text": text} + + for rep_i in range(cfg.replication): + with torch.no_grad(): + joints = model(batch)[0] + + num_samples = len(joints) + for i in range(num_samples): + res = dict() + pkl_path = osp.join(vis_dir, f"sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl") + res['joints'] = joints[i].detach().cpu().numpy() + res['text'] = text[i] + res['length'] = length[i] + res['hint'] = None + with open(pkl_path, 'wb') as f: + pickle.dump(res, f) + logger.info(f"Motions are generated here:\n{pkl_path}") + + if not cfg.no_plot: + plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), text[i], fps=FPS) + + else: + test_dataloader = dataset.test_dataloader() + for rep_i in range(cfg.replication): + for batch_id, batch in enumerate(test_dataloader): + batch = move_batch_to_device(batch, device) + with torch.no_grad(): + joints, joints_ref = model(batch) + + num_samples = len(joints) + text = batch['text'] + length = batch['length'] + if 'hint' in batch: + hint, hint_mask = batch['hint'], batch['hint_mask'] + hint = dataset.denorm_spatial(hint) * hint_mask + hint = remove_padding(hint, lengths=length) + else: + hint = None + + for i in range(num_samples): + res = dict() + pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl") + res['joints'] = joints[i].detach().cpu().numpy() + res['text'] = text[i] + res['length'] = length[i] + res['hint'] = hint[i].detach().cpu().numpy() if hint is not None else None + with open(pkl_path, 'wb') as f: + pickle.dump(res, f) + logger.info(f"Motions are generated here:\n{pkl_path}") + + if not cfg.no_plot: + plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), + text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None) + + if rep_i == 0: + res['joints'] = joints_ref[i].detach().cpu().numpy() + with open(pkl_path.replace('.pkl', '_ref.pkl'), 'wb') as f: + pickle.dump(res, f) + 
logger.info(f"Motions are generated here:\n{pkl_path.replace('.pkl', '_ref.pkl')}") + if not cfg.no_plot: + plot_3d_motion(pkl_path.replace('.pkl', '_ref.mp4'), joints_ref[i].detach().cpu().numpy(), + text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None) + + +if __name__ == "__main__": + main() diff --git a/fit.py b/fit.py new file mode 100644 index 0000000000000000000000000000000000000000..0145bccb33ae9539aa703e13207cc6051a6104d2 --- /dev/null +++ b/fit.py @@ -0,0 +1,136 @@ +# borrow from optimization https://github.com/wangsen1312/joints2smpl +import os +import argparse +import pickle + +import h5py +import natsort +import smplx + +import torch + +from mld.transforms.joints2rots import config +from mld.transforms.joints2rots.smplify import SMPLify3D + +parser = argparse.ArgumentParser() +parser.add_argument("--pkl", type=str, default=None, help="pkl motion file") +parser.add_argument("--dir", type=str, default=None, help="pkl motion folder") +parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters") +parser.add_argument("--cuda", type=bool, default=True, help="enables cuda") +parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids") +parser.add_argument("--num_joints", type=int, default=22, help="joint number") +parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence") +parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not") +opt = parser.parse_args() +print(opt) + +if opt.pkl: + paths = [opt.pkl] +elif opt.dir: + paths = [] + file_list = natsort.natsorted(os.listdir(opt.dir)) + for item in file_list: + if item.endswith('.pkl') and not item.endswith("_mesh.pkl"): + paths.append(os.path.join(opt.dir, item)) +else: + raise ValueError(f'{opt.pkl} and {opt.dir} are both None!') + +for path in paths: + # load joints + if os.path.exists(path.replace('.pkl', '_mesh.pkl')): + print(f"{path} is rendered! 
skip!") + continue + + with open(path, 'rb') as f: + data = pickle.load(f) + + joints = data['joints'] + # load predefined something + device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu") + print(config.SMPL_MODEL_DIR) + smplxmodel = smplx.create( + config.SMPL_MODEL_DIR, + model_type="smpl", + gender="neutral", + ext="pkl", + batch_size=joints.shape[0], + ).to(device) + + # load the mean pose as original + smpl_mean_file = config.SMPL_MEAN_FILE + + file = h5py.File(smpl_mean_file, "r") + init_mean_pose = ( + torch.from_numpy(file["pose"][:]) + .unsqueeze(0).repeat(joints.shape[0], 1) + .float() + .to(device) + ) + init_mean_shape = ( + torch.from_numpy(file["shape"][:]) + .unsqueeze(0).repeat(joints.shape[0], 1) + .float() + .to(device) + ) + cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device) + + # initialize SMPLify + smplify = SMPLify3D( + smplxmodel=smplxmodel, + batch_size=joints.shape[0], + joints_category=opt.joint_category, + num_iters=opt.num_smplify_iters, + device=device, + ) + print("initialize SMPLify3D done!") + + print("Start SMPLify!") + keypoints_3d = torch.Tensor(joints).to(device).float() + + if opt.joint_category == "AMASS": + confidence_input = torch.ones(opt.num_joints) + # make sure the foot and ankle + if opt.fix_foot: + confidence_input[7] = 1.5 + confidence_input[8] = 1.5 + confidence_input[10] = 1.5 + confidence_input[11] = 1.5 + else: + print("Such category not settle down!") + + # ----- from initial to fitting ------- + ( + new_opt_vertices, + new_opt_joints, + new_opt_pose, + new_opt_betas, + new_opt_cam_t, + new_opt_joint_loss, + ) = smplify( + init_mean_pose.detach(), + init_mean_shape.detach(), + cam_trans_zero.detach(), + keypoints_3d, + conf_3d=confidence_input.to(device) + ) + + # fix shape + betas = torch.zeros_like(new_opt_betas) + root = keypoints_3d[:, 0, :] + + output = smplxmodel( + betas=betas, + global_orient=new_opt_pose[:, :3], + body_pose=new_opt_pose[:, 3:], + transl=root, + return_verts=True + ) + vertices = output.vertices.detach().cpu().numpy() + floor_height = vertices[..., 1].min() + vertices[..., 1] -= floor_height + data['vertices'] = vertices + + save_file = path.replace('.pkl', '_mesh.pkl') + with open(save_file, 'wb') as f: + pickle.dump(data, f) + print(f'vertices saved in {save_file}') diff --git a/mld/__init__.py b/mld/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/config.py b/mld/config.py new file mode 100644 index 0000000000000000000000000000000000000000..bb08a65c7b0b197b26e8896c9500935dad0f54c0 --- /dev/null +++ b/mld/config.py @@ -0,0 +1,52 @@ +import os +import importlib +from typing import Type, TypeVar +from argparse import ArgumentParser + +from omegaconf import OmegaConf, DictConfig + + +def get_module_config(cfg_model: DictConfig, paths: list[str], cfg_root: str) -> DictConfig: + files = [os.path.join(cfg_root, 'modules', p+'.yaml') for p in paths] + for file in files: + assert os.path.exists(file), f'{file} is not exists.' 
+ with open(file, 'r') as f: + cfg_model.merge_with(OmegaConf.load(f)) + return cfg_model + + +def get_obj_from_str(string: str, reload: bool = False) -> Type: + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config: DictConfig) -> TypeVar: + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def parse_args() -> DictConfig: + parser = ArgumentParser() + parser.add_argument("--cfg", type=str, required=True, help="The main config file") + parser.add_argument('--example', type=str, required=False, help="The input texts and lengths with txt format") + parser.add_argument('--example_hint', type=str, required=False, help="The input hint ids and lengths with txt format") + parser.add_argument('--no-plot', action="store_true", required=False, help="Whether to plot the skeleton-based motion") + parser.add_argument('--replication', type=int, default=1, help="The number of replications of sampling") + parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="The visualization backends: tensorboard or swanlab") + parser.add_argument('--optimize', action='store_true', help="Enable optimization for motion control") + args = parser.parse_args() + + cfg = OmegaConf.load(args.cfg) + cfg_root = os.path.dirname(args.cfg) + cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root) + cfg = OmegaConf.merge(cfg, cfg_model) + + cfg.example = args.example + cfg.example_hint = args.example_hint + cfg.no_plot = args.no_plot + cfg.replication = args.replication + cfg.vis = args.vis + cfg.optimize = args.optimize + return cfg diff --git a/mld/data/__init__.py b/mld/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/data/base.py b/mld/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ce07535387b4faebdcb1c1407684207d6b1c9d0a --- /dev/null +++ b/mld/data/base.py @@ -0,0 +1,58 @@ +import copy +from os.path import join as pjoin +from typing import Any, Callable + +from torch.utils.data import DataLoader + + +class BaseDataModule: + def __init__(self, collate_fn: Callable) -> None: + super(BaseDataModule, self).__init__() + self.collate_fn = collate_fn + self.is_mm = False + + def get_sample_set(self, overrides: dict) -> Any: + sample_params = copy.deepcopy(self.hparams) + sample_params.update(overrides) + split_file = pjoin( + eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"), + self.cfg.TEST.SPLIT + ".txt" + ) + return self.Dataset(split_file=split_file, **sample_params) + + def __getattr__(self, item: str) -> Any: + if item.endswith("_dataset") and not item.startswith("_"): + subset = item[:-len("_dataset")].upper() + item_c = "_" + item + if item_c not in self.__dict__: + split_file = pjoin( + eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"), + eval(f"self.cfg.{subset}.SPLIT") + ".txt" + ) + self.__dict__[item_c] = self.Dataset(split_file=split_file, **self.hparams) + return getattr(self, item_c) + classname = self.__class__.__name__ + raise AttributeError(f"'{classname}' object has no attribute '{item}'") + + def get_dataloader_options(self, stage: str) -> dict: + stage_args = eval(f"self.cfg.{stage.upper()}") + dataloader_options = { + "batch_size": stage_args.BATCH_SIZE, + "num_workers": stage_args.NUM_WORKERS, + "collate_fn": self.collate_fn, + 
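+            # persistent_workers keeps DataLoader worker processes alive across epochs
+            # (only effective when NUM_WORKERS > 0, as in the provided configs).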
"persistent_workers": stage_args.PERSISTENT_WORKERS, + } + return dataloader_options + + def train_dataloader(self) -> DataLoader: + dataloader_options = self.get_dataloader_options('TRAIN') + return DataLoader(self.train_dataset, shuffle=True, **dataloader_options) + + def val_dataloader(self) -> DataLoader: + dataloader_options = self.get_dataloader_options('VAL') + return DataLoader(self.val_dataset, shuffle=False, **dataloader_options) + + def test_dataloader(self) -> DataLoader: + dataloader_options = self.get_dataloader_options('TEST') + dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE + return DataLoader(self.test_dataset, shuffle=False, **dataloader_options) diff --git a/mld/data/data.py b/mld/data/data.py new file mode 100644 index 0000000000000000000000000000000000000000..604eab8b71c2df68cf31dfff2add0d45794471e4 --- /dev/null +++ b/mld/data/data.py @@ -0,0 +1,73 @@ +import copy +from typing import Callable, Optional + +import numpy as np +from omegaconf import DictConfig + +import torch + +from .base import BaseDataModule +from .humanml.dataset import Text2MotionDataset, MotionDataset +from .humanml.scripts.motion_process import recover_from_ric + + +# (nfeats, njoints) +dataset_map = {'humanml3d': (263, 22), 'kit': (251, 21)} + + +class DataModule(BaseDataModule): + + def __init__(self, + name: str, + cfg: DictConfig, + motion_only: bool, + collate_fn: Optional[Callable] = None, + **kwargs) -> None: + super().__init__(collate_fn=collate_fn) + self.cfg = cfg + self.name = name + self.nfeats, self.njoints = dataset_map[name] + self.hparams = copy.deepcopy({**kwargs, 'njoints': self.njoints}) + self.Dataset = MotionDataset if motion_only else Text2MotionDataset + sample_overrides = {"tiny": True, "progress_bar": False} + self._sample_set = self.get_sample_set(overrides=sample_overrides) + + def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor: + raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint) + raw_std = torch.tensor(self._sample_set.raw_std).to(hint) + hint = hint * raw_std + raw_mean + return hint + + def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor: + raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint) + raw_std = torch.tensor(self._sample_set.raw_std).to(hint) + hint = (hint - raw_mean) / raw_std + return hint + + def feats2joints(self, features: torch.Tensor) -> torch.Tensor: + mean = torch.tensor(self.hparams['mean']).to(features) + std = torch.tensor(self.hparams['std']).to(features) + features = features * std + mean + return recover_from_ric(features, self.njoints) + + def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor: + # renorm to t2m norms for using t2m evaluators + ori_mean = torch.tensor(self.hparams['mean']).to(features) + ori_std = torch.tensor(self.hparams['std']).to(features) + eval_mean = torch.tensor(self.hparams['mean_eval']).to(features) + eval_std = torch.tensor(self.hparams['std_eval']).to(features) + features = features * ori_std + ori_mean + features = (features - eval_mean) / eval_std + return features + + def mm_mode(self, mm_on: bool = True) -> None: + if mm_on: + self.is_mm = True + self.name_list = self.test_dataset.name_list + self.mm_list = np.random.choice(self.name_list, + self.cfg.TEST.MM_NUM_SAMPLES, + replace=False) + self.test_dataset.name_list = self.mm_list + else: + self.is_mm = False + self.test_dataset.name_list = self.name_list diff --git a/mld/data/get_data.py b/mld/data/get_data.py new file mode 100644 index 
0000000000000000000000000000000000000000..3af566f7b4c2afb74afa6783f1752d2d1d0c5bc9 --- /dev/null +++ b/mld/data/get_data.py @@ -0,0 +1,79 @@ +from typing import Optional +from os.path import join as pjoin + +import numpy as np + +from omegaconf import DictConfig + +from .data import DataModule +from .base import BaseDataModule +from .utils import mld_collate, mld_collate_motion_only +from .humanml.utils.word_vectorizer import WordVectorizer + + +def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]: + name = "t2m" if dataset_name == "humanml3d" else dataset_name + assert name in ["t2m", "kit"] + if phase in ["val"]: + if name == 't2m': + data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta") + elif name == 'kit': + data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta") + else: + raise ValueError("Only support t2m and kit") + mean = np.load(pjoin(data_root, "mean.npy")) + std = np.load(pjoin(data_root, "std.npy")) + else: + data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") + mean = np.load(pjoin(data_root, "Mean.npy")) + std = np.load(pjoin(data_root, "Std.npy")) + + return mean, std + + +def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]: + if dataset_name.lower() in ["humanml3d", "kit"]: + return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab") + else: + raise ValueError("Only support WordVectorizer for HumanML3D and KIT") + + +dataset_module_map = {"humanml3d": DataModule, "kit": DataModule} +motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"} + + +def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule: + dataset_name = cfg.DATASET.NAME + if dataset_name.lower() in ["humanml3d", "kit"]: + data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") + mean, std = get_mean_std('train', cfg, dataset_name) + mean_eval, std_eval = get_mean_std("val", cfg, dataset_name) + wordVectorizer = None if motion_only else get_WordVectorizer(cfg, dataset_name) + collate_fn = mld_collate_motion_only if motion_only else mld_collate + dataset = dataset_module_map[dataset_name.lower()]( + name=dataset_name.lower(), + cfg=cfg, + motion_only=motion_only, + collate_fn=collate_fn, + mean=mean, + std=std, + mean_eval=mean_eval, + std_eval=std_eval, + w_vectorizer=wordVectorizer, + text_dir=pjoin(data_root, "texts"), + motion_dir=pjoin(data_root, motion_subdir[dataset_name]), + max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN, + min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN, + max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN, + unit_length=eval(f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"), + fps=eval(f"cfg.DATASET.{dataset_name.upper()}.FRAME_RATE"), + padding_to_max=cfg.DATASET.PADDING_TO_MAX, + window_size=cfg.DATASET.WINDOW_SIZE, + control_args=eval(f"cfg.DATASET.{dataset_name.upper()}.CONTROL_ARGS")) + + cfg.DATASET.NFEATS = dataset.nfeats + cfg.DATASET.NJOINTS = dataset.njoints + return dataset + + elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]: + raise NotImplementedError diff --git a/mld/data/humanml/__init__.py b/mld/data/humanml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/data/humanml/common/quaternion.py b/mld/data/humanml/common/quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0135e523300aa970e3d797b666557172fdbe06 --- /dev/null +++ b/mld/data/humanml/common/quaternion.py @@ -0,0 +1,29 @@ 
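+# Minimal quaternion helpers operating on (w, x, y, z) tensors.
+# qinv negates the vector part, which equals the inverse only for unit quaternions;
+# qrot applies v' = v + 2 * (w * (q_v x v) + q_v x (q_v x v)).
+# e.g. a 90-degree rotation about +y maps x to -z:
+#   qrot(torch.tensor([[0.7071, 0.0, 0.7071, 0.0]]), torch.tensor([[1.0, 0.0, 0.0]]))  # ~[[0, 0, -1]]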
+import torch + + +def qinv(q: torch.Tensor) -> torch.Tensor: + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) diff --git a/mld/data/humanml/dataset.py b/mld/data/humanml/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..aed5c9ed80e54ae24fe413c148e541a56eacff55 --- /dev/null +++ b/mld/data/humanml/dataset.py @@ -0,0 +1,348 @@ +import os +import random +import logging +import codecs as cs +from os.path import join as pjoin + +import numpy as np +from rich.progress import track + +import torch +from torch.utils.data import Dataset + +from .scripts.motion_process import recover_from_ric +from .utils.word_vectorizer import WordVectorizer + +logger = logging.getLogger(__name__) + + +class MotionDataset(Dataset): + def __init__(self, mean: np.ndarray, std: np.ndarray, + split_file: str, motion_dir: str, window_size: int, + tiny: bool = False, progress_bar: bool = True, **kwargs) -> None: + self.data = [] + self.lengths = [] + id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + id_list.append(line.strip()) + + maxdata = 10 if tiny else 1e10 + if progress_bar: + enumerator = enumerate( + track( + id_list, + f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}", + )) + else: + enumerator = enumerate(id_list) + + count = 0 + for i, name in enumerator: + if count > maxdata: + break + try: + motion = np.load(pjoin(motion_dir, name + '.npy')) + if motion.shape[0] < window_size: + continue + self.lengths.append(motion.shape[0] - window_size) + self.data.append(motion) + except Exception as e: + print(e) + pass + + self.cumsum = np.cumsum([0] + self.lengths) + if not tiny: + logger.info("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1])) + + self.mean = mean + self.std = std + self.window_size = window_size + + def __len__(self) -> int: + return self.cumsum[-1] + + def __getitem__(self, item: int) -> tuple: + if item != 0: + motion_id = np.searchsorted(self.cumsum, item) - 1 + idx = item - self.cumsum[motion_id] - 1 + else: + motion_id = 0 + idx = 0 + motion = self.data[motion_id][idx:idx + self.window_size] + "Z Normalization" + motion = (motion - self.mean) / self.std + return motion, self.window_size + + +class Text2MotionDataset(Dataset): + + def __init__( + self, + mean: np.ndarray, + std: np.ndarray, + split_file: str, + w_vectorizer: WordVectorizer, + max_motion_length: int, + min_motion_length: int, + max_text_len: int, + unit_length: int, + motion_dir: str, + text_dir: str, + fps: int, + padding_to_max: bool, + njoints: int, + tiny: bool = False, + progress_bar: bool = True, + **kwargs, + ) -> None: + self.w_vectorizer = w_vectorizer + self.max_motion_length = max_motion_length + self.min_motion_length = min_motion_length + self.max_text_len = max_text_len + 
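# (descriptive note, not in the original patch) Each <name>.txt annotation line parsed
# below is expected to look like "caption#token/POS token/POS ...#f_tag#to_tag", where
# f_tag/to_tag delimit an optional sub-clip in seconds ((0.0, 0.0) means the caption
# describes the whole motion). Sub-clips are cropped with
# motion[int(f_tag * fps): int(to_tag * fps)] and stored under a randomly prefixed name
# so they do not collide with the full clip's entry.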
self.unit_length = unit_length + self.padding_to_max = padding_to_max + self.njoints = njoints + + data_dict = {} + id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + id_list.append(line.strip()) + self.id_list = id_list + + maxdata = 10 if tiny else 1e10 + if progress_bar: + enumerator = enumerate( + track( + id_list, + f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}", + )) + else: + enumerator = enumerate(id_list) + count = 0 + bad_count = 0 + new_name_list = [] + length_list = [] + for i, name in enumerator: + if count > maxdata: + break + try: + motion = np.load(pjoin(motion_dir, name + ".npy")) + if len(motion) < self.min_motion_length or len(motion) >= self.max_motion_length: + bad_count += 1 + continue + text_data = [] + flag = False + with cs.open(pjoin(text_dir, name + ".txt")) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split("#") + caption = line_split[0] + tokens = line_split[1].split(" ") + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict["caption"] = caption + text_dict["tokens"] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag * fps): int(to_tag * fps)] + if (len(n_motion)) < self.min_motion_length or \ + len(n_motion) >= self.max_motion_length: + continue + new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name + while new_name in data_dict: + new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name + data_dict[new_name] = { + "motion": n_motion, + "length": len(n_motion), + "text": [text_dict], + } + new_name_list.append(new_name) + length_list.append(len(n_motion)) + except ValueError: + print(line_split) + print(line_split[2], line_split[3], f_tag, to_tag, name) + + if flag: + data_dict[name] = { + "motion": motion, + "length": len(motion), + "text": text_data, + } + new_name_list.append(name) + length_list.append(len(motion)) + count += 1 + except Exception as e: + print(e) + pass + + name_list, length_list = zip( + *sorted(zip(new_name_list, length_list), key=lambda x: x[1])) + + if not tiny: + logger.info(f"Reading {len(self.id_list)} motions from {split_file}.") + logger.info(f"Total {len(name_list)} motions are used.") + logger.info(f"{bad_count} motion sequences not within the length range of " + f"[{self.min_motion_length}, {self.max_motion_length}) are filtered out.") + + self.mean = mean + self.std = std + + control_args = kwargs['control_args'] + self.control_mode = None + if os.path.exists(control_args.MEAN_STD_PATH): + self.raw_mean = np.load(pjoin(control_args.MEAN_STD_PATH, 'Mean_raw.npy')) + self.raw_std = np.load(pjoin(control_args.MEAN_STD_PATH, 'Std_raw.npy')) + else: + self.raw_mean = self.raw_std = None + if not tiny and control_args.CONTROL: + self.t_ctrl = control_args.TEMPORAL + self.training_control_joints = np.array(control_args.TRAIN_JOINTS) + self.testing_control_joints = np.array(control_args.TEST_JOINTS) + self.training_density = control_args.TRAIN_DENSITY + self.testing_density = control_args.TEST_DENSITY + + self.control_mode = 'val' if ('test' in split_file or 'val' in split_file) else 'train' + if self.control_mode == 'train': + logger.info(f'Training Control Joints: {self.training_control_joints}') + logger.info(f'Training Control Density: {self.training_density}') + else: + logger.info(f'Testing Control Joints: 
{self.testing_control_joints}') + logger.info(f'Testing Control Density: {self.testing_density}') + logger.info(f"Temporal Control: {self.t_ctrl}") + + self.data_dict = data_dict + self.name_list = name_list + + def __len__(self) -> int: + return len(self.name_list) + + def random_mask(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + choose_joint = self.testing_control_joints + + length = joints.shape[0] + density = self.testing_density + if density in [1, 2, 5]: + choose_seq_num = density + else: + choose_seq_num = int(length * density / 100) + + if self.t_ctrl: + choose_seq = np.arange(0, choose_seq_num) + else: + choose_seq = np.random.choice(length, choose_seq_num, replace=False) + choose_seq.sort() + + mask_seq = np.zeros((length, self.njoints, 3)) + for cj in choose_joint: + mask_seq[choose_seq, cj] = 1.0 + + joints = (joints - self.raw_mean) / self.raw_std + joints = joints * mask_seq + return joints, mask_seq + + def random_mask_train(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + if self.t_ctrl: + choose_joint = self.training_control_joints + else: + num_joints = len(self.training_control_joints) + num_joints_control = 1 + choose_joint = np.random.choice(num_joints, num_joints_control, replace=False) + choose_joint = self.training_control_joints[choose_joint] + + length = joints.shape[0] + + if self.training_density == 'random': + choose_seq_num = np.random.choice(length - 1, 1) + 1 + else: + choose_seq_num = int(length * random.uniform(self.training_density[0], self.training_density[1]) / 100) + + if self.t_ctrl: + choose_seq = np.arange(0, choose_seq_num) + else: + choose_seq = np.random.choice(length, choose_seq_num, replace=False) + choose_seq.sort() + + mask_seq = np.zeros((length, self.njoints, 3)) + for cj in choose_joint: + mask_seq[choose_seq, cj] = 1 + + joints = (joints - self.raw_mean) / self.raw_std + joints = joints * mask_seq + return joints, mask_seq + + def __getitem__(self, idx: int) -> tuple: + data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data["motion"], data["length"], data["text"] + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data["caption"], text_data["tokens"] + + if len(tokens) < self.max_text_len: + # pad with "unk" + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:self.max_text_len] + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + # Crop the motions in to times of 4, and introduce small variations + if self.unit_length < 10: + coin2 = np.random.choice(["single", "single", "double"]) + else: + coin2 = "single" + + if coin2 == "double": + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == "single": + m_length = (m_length // self.unit_length) * self.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx + m_length] + + hint, hint_mask = None, None + if self.control_mode is not None: + joints = recover_from_ric(torch.from_numpy(motion).float(), self.njoints) + joints = joints.numpy() + if self.control_mode == 'train': + hint, 
hint_mask = self.random_mask_train(joints) + else: + hint, hint_mask = self.random_mask(joints) + + if self.padding_to_max: + padding = np.zeros((self.max_motion_length - m_length, *hint.shape[1:])) + hint = np.concatenate([hint, padding], axis=0) + hint_mask = np.concatenate([hint_mask, padding], axis=0) + + "Z Normalization" + motion = (motion - self.mean) / self.std + + if self.padding_to_max: + padding = np.zeros((self.max_motion_length - m_length, motion.shape[1])) + motion = np.concatenate([motion, padding], axis=0) + + return (word_embeddings, + pos_one_hots, + caption, + sent_len, + motion, + m_length, + "_".join(tokens), + (hint, hint_mask)) diff --git a/mld/data/humanml/scripts/motion_process.py b/mld/data/humanml/scripts/motion_process.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebcef746e94c586d13c8b859ae5e059890d788a --- /dev/null +++ b/mld/data/humanml/scripts/motion_process.py @@ -0,0 +1,51 @@ +import torch + +from ..common.quaternion import qinv, qrot + + +# Recover global angle and positions for rotation dataset +# root_rot_velocity (B, seq_len, 1) +# root_linear_velocity (B, seq_len, 2) +# root_y (B, seq_len, 1) +# ric_data (B, seq_len, (joint_num - 1)*3) +# rot_data (B, seq_len, (joint_num - 1)*6) +# local_velocity (B, seq_len, joint_num*3) +# foot contact (B, seq_len, 4) +def recover_root_rot_pos(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor: + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concat root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + return positions diff --git a/mld/data/humanml/utils/__init__.py b/mld/data/humanml/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/data/humanml/utils/paramUtil.py b/mld/data/humanml/utils/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..019ea21f1a1ce17996c81aa559460b9733029ceb --- /dev/null +++ b/mld/data/humanml/utils/paramUtil.py @@ -0,0 +1,62 @@ +import numpy as np + +# Define a kinematic tree for the skeletal structure +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 
0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0, 0, 0], + [1, 0, 0], + [-1, 0, 0], + [0, 1, 0], + [0, -1, 0], + [0, -1, 0], + [0, 1, 0], + [0, -1, 0], + [0, -1, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + [-1, 0, 0], + [0, 0, 1], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], + [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' diff --git a/mld/data/humanml/utils/plot_script.py b/mld/data/humanml/utils/plot_script.py new file mode 100644 index 0000000000000000000000000000000000000000..dcee7fc2497de4546ed62a9c9da5c74dc0a3372e --- /dev/null +++ b/mld/data/humanml/utils/plot_script.py @@ -0,0 +1,98 @@ +from textwrap import wrap +from typing import Optional + +import numpy as np + +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as p3 +from matplotlib.animation import FuncAnimation +from mpl_toolkits.mplot3d.art3d import Poly3DCollection + +import mld.data.humanml.utils.paramUtil as paramUtil + +skeleton = paramUtil.t2m_kinematic_chain + + +def plot_3d_motion(save_path: str, joints: np.ndarray, title: str, + figsize: tuple[int, int] = (3, 3), + fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton, + hint: Optional[np.ndarray] = None) -> None: + + title = '\n'.join(wrap(title, 20)) + + def init(): + ax.set_xlim3d([-radius / 2, radius / 2]) + ax.set_ylim3d([0, radius]) + ax.set_zlim3d([-radius / 3., radius * 2 / 3.]) + fig.suptitle(title, fontsize=10) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + # Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + # (seq_len, joints_num, 3) + data = joints.copy().reshape(len(joints), -1, 3) + + data *= 1.3 # scale for visualization + if hint is not None: + mask = hint.sum(-1) != 0 + hint = hint[mask] + hint *= 1.3 + + fig = plt.figure(figsize=figsize) + plt.tight_layout() + ax = p3.Axes3D(fig) + init() + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00", + "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", + "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ] + + frame_number = data.shape[0] + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + if hint is not None: + hint[..., 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + def update(index): + ax.lines = [] + ax.collections = [] + ax.view_init(elev=120, azim=-90) + ax.dist = 7.5 + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + + if hint is not None: + ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1], hint[..., 2] - trajec[index, 1], color="#80B79A") + + for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): + if i < 5: + linewidth = 4.0 
+ else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) + ani.save(save_path, fps=fps) + plt.close() diff --git a/mld/data/humanml/utils/word_vectorizer.py b/mld/data/humanml/utils/word_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4135a41d28e8f7b4ae5ef4d5a910665e670a179a --- /dev/null +++ b/mld/data/humanml/utils/word_vectorizer.py @@ -0,0 +1,82 @@ +import pickle +from os.path import join as pjoin + +import numpy as np + + +POS_enumerator = { + 'VERB': 0, + 'NOUN': 1, + 'DET': 2, + 'ADP': 3, + 'NUM': 4, + 'AUX': 5, + 'PRON': 6, + 'ADJ': 7, + 'ADV': 8, + 'Loc_VIP': 9, + 'Body_VIP': 10, + 'Obj_VIP': 11, + 'Act_VIP': 12, + 'Desc_VIP': 13, + 'OTHER': 14 +} + +Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', + 'up', 'down', 'straight', 'curve') + +Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') + +Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') + +Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', + 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', + 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') + +Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', + 'angrily', 'sadly') + +VIP_dict = { + 'Loc_VIP': Loc_list, + 'Body_VIP': Body_list, + 'Obj_VIP': Obj_List, + 'Act_VIP': Act_list, + 'Desc_VIP': Desc_list, +} + + +class WordVectorizer(object): + def __init__(self, meta_root: str, prefix: str) -> None: + vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix)) + words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb')) + word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + def _get_pos_ohot(self, pos: str) -> np.ndarray: + pos_vec = np.zeros(len(POS_enumerator)) + if pos in POS_enumerator: + pos_vec[POS_enumerator[pos]] = 1 + else: + pos_vec[POS_enumerator['OTHER']] = 1 + return pos_vec + + def __len__(self) -> int: + return len(self.word2vec) + + def __getitem__(self, item: str) -> tuple: + word, pos = item.split('/') + if word in self.word2vec: + word_vec = self.word2vec[word] + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + word_vec = self.word2vec['unk'] + pos_vec = self._get_pos_ohot('OTHER') + return word_vec, pos_vec diff --git a/mld/data/utils.py b/mld/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f685848d355000ca0b8b75fe8b2105db3aabedad --- /dev/null +++ b/mld/data/utils.py @@ -0,0 +1,52 @@ +import torch + +from mld.utils.temos_utils import lengths_to_mask + + +def collate_tensors(batch: list) -> torch.Tensor: + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch), ) + tuple(max_size) + canvas = 
batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + + +def mld_collate(batch: list) -> dict: + notnone_batches = [b for b in batch if b is not None] + notnone_batches.sort(key=lambda x: x[3], reverse=True) + adapted_batch = { + "motion": + collate_tensors([torch.tensor(b[4]).float() for b in notnone_batches]), + "text": [b[2] for b in notnone_batches], + "length": [b[5] for b in notnone_batches], + "word_embs": + collate_tensors([torch.tensor(b[0]).float() for b in notnone_batches]), + "pos_ohot": + collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]), + "text_len": + collate_tensors([torch.tensor(b[3]) for b in notnone_batches]), + "tokens": [b[6] for b in notnone_batches] + } + + mask = lengths_to_mask(adapted_batch['length'], adapted_batch['motion'].device, adapted_batch['motion'].shape[1]) + adapted_batch['mask'] = mask + + # collate trajectory + if notnone_batches[0][-1][0] is not None: + adapted_batch['hint'] = collate_tensors([torch.tensor(b[-1][0]).float() for b in notnone_batches]) + adapted_batch['hint_mask'] = collate_tensors([torch.tensor(b[-1][1]).float() for b in notnone_batches]) + + return adapted_batch + + +def mld_collate_motion_only(batch: list) -> dict: + batch = { + "motion": collate_tensors([torch.tensor(b[0]).float() for b in batch]), + "length": [b[1] for b in batch] + } + return batch diff --git a/mld/launch/__init__.py b/mld/launch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/launch/blender.py b/mld/launch/blender.py new file mode 100644 index 0000000000000000000000000000000000000000..09a7fae39721c918050bd15fb3356b58d34b78e9 --- /dev/null +++ b/mld/launch/blender.py @@ -0,0 +1,23 @@ +# Fix blender path +import os +import sys +from argparse import ArgumentParser + +sys.path.append(os.path.expanduser("~/.local/lib/python3.9/site-packages")) + + +# Monkey patch argparse such that +# blender / python parsing works +def parse_args(self, args=None, namespace=None): + if args is not None: + return self.parse_args_bak(args=args, namespace=namespace) + try: + idx = sys.argv.index("--") + args = sys.argv[idx + 1:] # the list after '--' + except ValueError as e: # '--' not in the list: + args = [] + return self.parse_args_bak(args=args, namespace=namespace) + + +setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args) +setattr(ArgumentParser, 'parse_args', parse_args) diff --git a/mld/models/__init__.py b/mld/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/architectures/__init__.py b/mld/models/architectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/architectures/dno.py b/mld/models/architectures/dno.py new file mode 100644 index 0000000000000000000000000000000000000000..c5fcf4c7c9926d8d6645ffad71c786515c442f49 --- /dev/null +++ b/mld/models/architectures/dno.py @@ -0,0 +1,79 @@ +import os + +import torch +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter + + +class DNO(object): + def __init__( + self, + optimize: bool, + max_train_steps: int, + learning_rate: float, + lr_scheduler: str, + lr_warmup_steps: int, + clip_grad: bool, + loss_hint_type: str, + 
loss_diff_penalty: float, + loss_correlate_penalty: float, + visualize_samples: int, + visualize_ske_steps: list[int], + output_dir: str + ) -> None: + + self.optimize = optimize + self.max_train_steps = max_train_steps + self.learning_rate = learning_rate + self.lr_scheduler = lr_scheduler + self.lr_warmup_steps = lr_warmup_steps + self.clip_grad = clip_grad + self.loss_hint_type = loss_hint_type + self.loss_diff_penalty = loss_diff_penalty + self.loss_correlate_penalty = loss_correlate_penalty + + if loss_hint_type == 'l1': + self.loss_hint_func = F.l1_loss + elif loss_hint_type == 'l1_smooth': + self.loss_hint_func = F.smooth_l1_loss + elif loss_hint_type == 'l2': + self.loss_hint_func = F.mse_loss + else: + raise ValueError(f'Invalid loss type: {loss_hint_type}') + + self.visualize_samples = float('inf') if visualize_samples == 'inf' else visualize_samples + assert self.visualize_samples >= 0 + self.visualize_samples_done = 0 + self.visualize_ske_steps = visualize_ske_steps + if len(visualize_ske_steps) > 0: + self.vis_dir = os.path.join(output_dir, 'vis_optimize') + os.makedirs(self.vis_dir) + + self.writer = None + self.output_dir = output_dir + if self.visualize_samples > 0: + self.writer = SummaryWriter(output_dir) + + @property + def do_visualize(self): + return self.visualize_samples_done < self.visualize_samples + + @staticmethod + def noise_regularize_1d(noise: torch.Tensor, stop_at: int = 2, dim: int = 1) -> torch.Tensor: + size = noise.shape[dim] + if size & (size - 1) != 0: + new_size = 2 ** (size - 1).bit_length() + pad = new_size - size + pad_shape = list(noise.shape) + pad_shape[dim] = pad + pad_noise = torch.randn(*pad_shape, device=noise.device) + noise = torch.cat([noise, pad_noise], dim=dim) + size = noise.shape[dim] + + loss = torch.zeros(noise.shape[0], device=noise.device) + while size > stop_at: + rolled_noise = torch.roll(noise, shifts=1, dims=dim) + loss += (noise * rolled_noise).mean(dim=tuple(range(1, noise.ndim))).pow(2) + noise = noise.view(*noise.shape[:dim], size // 2, 2, *noise.shape[dim + 1:]).mean(dim=dim + 1) + size //= 2 + return loss diff --git a/mld/models/architectures/mld_clip.py b/mld/models/architectures/mld_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9b3545e363a651e33270054378fefdaccb2815 --- /dev/null +++ b/mld/models/architectures/mld_clip.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn + +from transformers import AutoModel, AutoTokenizer +from sentence_transformers import SentenceTransformer + + +class MldTextEncoder(nn.Module): + + def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None: + super().__init__() + + if 't5' in modelpath: + self.text_model = SentenceTransformer(modelpath) + self.tokenizer = self.text_model.tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(modelpath) + self.text_model = AutoModel.from_pretrained(modelpath) + + self.max_length = self.tokenizer.model_max_length + if "clip" in modelpath: + self.text_encoded_dim = self.text_model.config.text_config.hidden_size + if last_hidden_state: + self.name = "clip_hidden" + else: + self.name = "clip" + elif "bert" in modelpath: + self.name = "bert" + self.text_encoded_dim = self.text_model.config.hidden_size + elif 't5' in modelpath: + self.name = 't5' + else: + raise ValueError(f"Model {modelpath} not supported") + + def forward(self, texts: list[str]) -> torch.Tensor: + # get prompt text embeddings + if self.name in ["clip", "clip_hidden"]: + text_inputs = self.tokenizer( + texts, + 
padding="max_length", + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + # split into max length Clip can handle + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length] + elif self.name == "bert": + text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + + if self.name == "clip": + # (batch_Size, text_encoded_dim) + text_embeddings = self.text_model.get_text_features( + text_input_ids.to(self.text_model.device)) + # (batch_Size, 1, text_encoded_dim) + text_embeddings = text_embeddings.unsqueeze(1) + elif self.name == "clip_hidden": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model.text_model( + text_input_ids.to(self.text_model.device)).last_hidden_state + elif self.name == "bert": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model( + **text_inputs.to(self.text_model.device)).last_hidden_state + elif self.name == 't5': + text_embeddings = self.text_model.encode(texts, show_progress_bar=False, convert_to_tensor=True, batch_size=len(texts)) + text_embeddings = text_embeddings.unsqueeze(1) + else: + raise NotImplementedError(f"Model {self.name} not implemented") + + return text_embeddings diff --git a/mld/models/architectures/mld_denoiser.py b/mld/models/architectures/mld_denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..01ab62e12ad19fd0161d0eb67447e061a2661844 --- /dev/null +++ b/mld/models/architectures/mld_denoiser.py @@ -0,0 +1,200 @@ +from typing import Optional, Union + +import torch +import torch.nn as nn + +from mld.models.operator.embeddings import TimestepEmbedding, Timesteps +from mld.models.operator.attention import (SkipTransformerEncoder, + SkipTransformerDecoder, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer) +from mld.models.operator.moe import MoeTransformerEncoderLayer, MoeTransformerDecoderLayer +from mld.models.operator.utils import get_clones, get_activation_fn, zero_module +from mld.models.operator.position_encoding import build_position_encoding + + +def load_balancing_loss_func(router_logits: tuple, num_experts: int = 4, topk: int = 2): + router_logits = torch.cat(router_logits, dim=0) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1) + _, selected_experts = torch.topk(routing_weights, topk, dim=-1) + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + router_prob_per_expert = torch.mean(routing_weights, dim=0) + overall_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss + + +class MldDenoiser(nn.Module): + + def __init__(self, + latent_dim: list = [1, 256], + hidden_dim: Optional[int] = None, + text_dim: int = 768, + time_dim: int = 768, + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + normalize_before: bool = False, + norm_eps: float = 1e-5, + activation: str = "gelu", + norm_post: bool = True, + activation_post: Optional[str] = None, + flip_sin_to_cos: bool = True, + freq_shift: float = 0, + time_act_fn: str = 'silu', + time_post_act_fn: Optional[str] = None, + position_embedding: str = "learned", + arch: str = "trans_enc", + add_mem_pos: bool = True, + force_pre_post_proj: bool = False, + text_act_fn: str = 'relu', + time_cond_proj_dim: Optional[int] = 
None, + zero_init_cond: bool = True, + is_controlnet: bool = False, + controlnet_embed_dim: Optional[int] = None, + controlnet_act_fn: str = 'silu', + moe: bool = False, + moe_num_experts: int = 4, + moe_topk: int = 2, + moe_loss_weight: float = 1e-2, + moe_jitter_noise: Optional[float] = None + ) -> None: + super(MldDenoiser, self).__init__() + + self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim + add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1]) + self.latent_pre = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity() + self.latent_post = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity() + + self.arch = arch + self.time_cond_proj_dim = time_cond_proj_dim + + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.moe_loss_weight = moe_loss_weight + + self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(time_dim, self.latent_dim, time_act_fn, post_act_fn=time_post_act_fn, + cond_proj_dim=time_cond_proj_dim, zero_init_cond=zero_init_cond) + self.emb_proj = nn.Sequential(get_activation_fn(text_act_fn), nn.Linear(text_dim, self.latent_dim)) + + self.query_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding) + if self.arch == "trans_enc": + if moe: + encoder_layer = MoeTransformerEncoderLayer( + self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size, + dropout, activation, normalize_before, norm_eps, moe_jitter_noise) + else: + encoder_layer = TransformerEncoderLayer( + self.latent_dim, num_heads, ff_size, dropout, + activation, normalize_before, norm_eps) + + encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None + self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post, + is_controlnet=is_controlnet, is_moe=moe) + + elif self.arch == 'trans_dec': + if add_mem_pos: + self.mem_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding) + else: + self.mem_pos = None + if moe: + decoder_layer = MoeTransformerDecoderLayer( + self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size, + dropout, activation, normalize_before, norm_eps, moe_jitter_noise) + else: + decoder_layer = TransformerDecoderLayer( + self.latent_dim, num_heads, ff_size, dropout, + activation, normalize_before, norm_eps) + + decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None + self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post, + is_controlnet=is_controlnet, is_moe=moe) + else: + raise ValueError(f"Not supported architecture: {self.arch}!") + + self.is_controlnet = is_controlnet + if self.is_controlnet: + embed_dim = controlnet_embed_dim if controlnet_embed_dim is not None else self.latent_dim + modules = [ + nn.Linear(latent_dim[-1], embed_dim), + get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None, + nn.Linear(embed_dim, embed_dim), + get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None, + zero_module(nn.Linear(embed_dim, latent_dim[-1])) + ] + self.controlnet_cond_embedding = nn.Sequential(*[m for m in modules if m is not None]) + + self.controlnet_down_mid_blocks = nn.ModuleList([ + zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)]) + + def forward(self, + sample: torch.Tensor, + timestep: torch.Tensor, + 
encoder_hidden_states: torch.Tensor, + timestep_cond: Optional[torch.Tensor] = None, + controlnet_cond: Optional[torch.Tensor] = None, + controlnet_residuals: Optional[list[torch.Tensor]] = None + ) -> tuple: + + # 0. check if controlnet + if self.is_controlnet: + sample = sample + self.controlnet_cond_embedding(controlnet_cond) + + # 1. dimension matching (pre) + sample = sample.permute(1, 0, 2) + sample = self.latent_pre(sample) + + # 2. time_embedding + timesteps = timestep.expand(sample.shape[1]).clone() + time_emb = self.time_proj(timesteps) + time_emb = time_emb.to(dtype=sample.dtype) + # [1, bs, latent_dim] <= [bs, latent_dim] + time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0) + + # 3. condition + time embedding + # text_emb [seq_len, batch_size, text_dim] <= [batch_size, seq_len, text_dim] + encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2) + # text embedding projection + text_emb_latent = self.emb_proj(encoder_hidden_states) + emb_latent = torch.cat((time_emb, text_emb_latent), 0) + + # 4. transformer + if self.arch == "trans_enc": + xseq = torch.cat((sample, emb_latent), axis=0) + xseq = self.query_pos(xseq) + tokens, intermediates, router_logits = self.encoder(xseq, controlnet_residuals=controlnet_residuals) + elif self.arch == 'trans_dec': + sample = self.query_pos(sample) + if self.mem_pos: + emb_latent = self.mem_pos(emb_latent) + tokens, intermediates, router_logits = self.decoder(sample, emb_latent, + controlnet_residuals=controlnet_residuals) + else: + raise TypeError(f"{self.arch} is not supported") + + router_loss = None + if router_logits is not None: + router_loss = load_balancing_loss_func(router_logits, self.moe_num_experts, self.moe_topk) + router_loss = self.moe_loss_weight * router_loss + + if self.is_controlnet: + control_res_samples = [] + for res, block in zip(intermediates, self.controlnet_down_mid_blocks): + r = block(res) + control_res_samples.append(r) + return control_res_samples, router_loss + elif self.arch == "trans_enc": + sample = tokens[:sample.shape[0]] + elif self.arch == 'trans_dec': + sample = tokens + else: + raise TypeError(f"{self.arch} is not supported") + + # 5. 
dimension matching (post) + sample = self.latent_post(sample) + sample = sample.permute(1, 0, 2) + return sample, router_loss diff --git a/mld/models/architectures/mld_traj_encoder.py b/mld/models/architectures/mld_traj_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0e5df65bf6b0030d246040c51d444f1a1f31df62 --- /dev/null +++ b/mld/models/architectures/mld_traj_encoder.py @@ -0,0 +1,64 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from mld.models.operator.attention import SkipTransformerEncoder, TransformerEncoderLayer +from mld.models.operator.position_encoding import build_position_encoding + + +class MldTrajEncoder(nn.Module): + + def __init__(self, + nfeats: int, + latent_dim: list = [1, 256], + hidden_dim: Optional[int] = None, + force_post_proj: bool = False, + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + normalize_before: bool = False, + norm_eps: float = 1e-5, + activation: str = "gelu", + norm_post: bool = True, + activation_post: Optional[str] = None, + position_embedding: str = "learned") -> None: + super(MldTrajEncoder, self).__init__() + + self.latent_size = latent_dim[0] + self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim + add_post_proj = force_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1]) + self.latent_proj = nn.Linear(self.latent_dim, latent_dim[-1]) if add_post_proj else nn.Identity() + + self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim) + + self.query_pos_encoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + norm_eps + ) + encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None + self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post) + self.global_motion_token = nn.Parameter(torch.randn(self.latent_size, self.latent_dim)) + + def forward(self, features: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + bs, nframes, nfeats = features.shape + x = self.skel_embedding(features) + x = x.permute(1, 0, 2) + dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1)) + dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device) + aug_mask = torch.cat((dist_masks, mask), 1) + xseq = torch.cat((dist, x), 0) + xseq = self.query_pos_encoder(xseq) + global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]] + global_token = self.latent_proj(global_token) + global_token = global_token.permute(1, 0, 2) + return global_token diff --git a/mld/models/architectures/mld_vae.py b/mld/models/architectures/mld_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..32946ec5e2a1afcc3ea0bb40dcb07608dd4e6787 --- /dev/null +++ b/mld/models/architectures/mld_vae.py @@ -0,0 +1,136 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.distributions.distribution import Distribution + +from mld.models.operator.attention import ( + SkipTransformerEncoder, + SkipTransformerDecoder, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer +) +from mld.models.operator.position_encoding import build_position_encoding + + +class MldVae(nn.Module): + + def __init__(self, + nfeats: int, + latent_dim: list = [1, 256], + hidden_dim: Optional[int] = None, + 
force_pre_post_proj: bool = False, + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + arch: str = "encoder_decoder", + normalize_before: bool = False, + norm_eps: float = 1e-5, + activation: str = "gelu", + norm_post: bool = True, + activation_post: Optional[str] = None, + position_embedding: str = "learned") -> None: + super(MldVae, self).__init__() + + self.latent_size = latent_dim[0] + self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim + add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1]) + self.latent_pre = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity() + self.latent_post = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity() + + self.arch = arch + + self.query_pos_encoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + norm_eps + ) + encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None + self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post) + + if self.arch == "all_encoder": + decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None + self.decoder = SkipTransformerEncoder(encoder_layer, num_layers, decoder_norm, activation_post) + elif self.arch == 'encoder_decoder': + self.query_pos_decoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + decoder_layer = TransformerDecoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + norm_eps + ) + decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None + self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post) + else: + raise ValueError(f"Not support architecture: {self.arch}!") + + self.global_motion_token = nn.Parameter(torch.randn(self.latent_size * 2, self.latent_dim)) + self.skel_embedding = nn.Linear(nfeats, self.latent_dim) + self.final_layer = nn.Linear(self.latent_dim, nfeats) + + def forward(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, Distribution]: + z, dist = self.encode(features, mask) + feats_rst = self.decode(z, mask) + return feats_rst, z, dist + + def encode(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, Distribution]: + bs, nframes, nfeats = features.shape + x = self.skel_embedding(features) + x = x.permute(1, 0, 2) + dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1)) + dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device) + aug_mask = torch.cat((dist_masks, mask), 1) + xseq = torch.cat((dist, x), 0) + + xseq = self.query_pos_encoder(xseq) + dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]] + dist = self.latent_pre(dist) + + mu = dist[0:self.latent_size, ...] + logvar = dist[self.latent_size:, ...] 
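# (descriptive note, not in the original patch) The two halves of the token output are
# interpreted as mu and log-variance; the lines below apply the usual reparameterization
# trick: std = exp(0.5 * logvar) (written as logvar.exp().pow(0.5)), then
# latent = Normal(mu, std).rsample(), which keeps the sample differentiable
# with respect to the encoder outputs.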
+ + std = logvar.exp().pow(0.5) + dist = torch.distributions.Normal(mu, std) + latent = dist.rsample() + # [latent_dim[0], batch_size, latent_dim] -> [batch_size, latent_dim[0], latent_dim[1]] + latent = latent.permute(1, 0, 2) + return latent, dist + + def decode(self, z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # [batch_size, latent_dim[0], latent_dim[1]] -> [latent_dim[0], batch_size, latent_dim[1]] + z = self.latent_post(z) + z = z.permute(1, 0, 2) + bs, nframes = mask.shape + queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device) + + if self.arch == "all_encoder": + xseq = torch.cat((z, queries), axis=0) + z_mask = torch.ones((bs, self.latent_size), dtype=torch.bool, device=z.device) + aug_mask = torch.cat((z_mask, mask), axis=1) + xseq = self.query_pos_decoder(xseq) + output = self.decoder(xseq, src_key_padding_mask=~aug_mask)[0][z.shape[0]:] + elif self.arch == "encoder_decoder": + queries = self.query_pos_decoder(queries) + output = self.decoder(tgt=queries, memory=z, tgt_key_padding_mask=~mask)[0] + else: + raise ValueError(f"Not support architecture: {self.arch}!") + + output = self.final_layer(output) + output[~mask.T] = 0 + feats = output.permute(1, 0, 2) + return feats diff --git a/mld/models/architectures/t2m_motionenc.py b/mld/models/architectures/t2m_motionenc.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9783abdd37b641ce3d841b69f2003ee75c3f5b --- /dev/null +++ b/mld/models/architectures/t2m_motionenc.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None: + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True)) + self.out_net = nn.Linear(output_size, output_size) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return self.out_net(outputs) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None: + super(MotionEncoderBiGRUCo, self).__init__() + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU( + hidden_size, hidden_size, batch_first=True, bidirectional=True + ) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size), + ) + + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True) + ) + + def forward(self, inputs: torch.Tensor, m_lens: torch.Tensor) -> torch.Tensor: + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/mld/models/architectures/t2m_textenc.py b/mld/models/architectures/t2m_textenc.py new file mode 100644 index 
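# (descriptive note, not in the original patch) In the T2M evaluator components above,
# MovementConvEncoder downsamples the feature sequence by 4x (two stride-2 convolutions)
# before the BiGRU encoder pools movement features into a single embedding. These GRU
# encoders call pack_padded_sequence with its default enforce_sorted=True, so callers are
# expected to pass batches whose lengths are already sorted in decreasing order.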
0000000000000000000000000000000000000000..076b5be9bc9a79540af37bbf1bbc6b0fe41d5077 --- /dev/null +++ b/mld/models/architectures/t2m_textenc.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + + +class TextEncoderBiGRUCo(nn.Module): + def __init__(self, word_size: int, pos_size: int, hidden_size: int, output_size: int) -> None: + super(TextEncoderBiGRUCo, self).__init__() + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size)) + + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + def forward(self, word_embs: torch.Tensor, pos_onehot: torch.Tensor, + cap_lens: torch.Tensor) -> torch.Tensor: + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/mld/models/metrics/__init__.py b/mld/models/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7104981155af2ec081f0e97724f86a9631a392fa --- /dev/null +++ b/mld/models/metrics/__init__.py @@ -0,0 +1,4 @@ +from .tm2t import TM2TMetrics +from .mm import MMMetrics +from .cm import ControlMetrics +from .pos import PosMetrics diff --git a/mld/models/metrics/cm.py b/mld/models/metrics/cm.py new file mode 100644 index 0000000000000000000000000000000000000000..a03ad2a3b86f3a88979ab675090e9c7b8382d7a7 --- /dev/null +++ b/mld/models/metrics/cm.py @@ -0,0 +1,53 @@ +import torch +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from mld.utils.temos_utils import remove_padding +from .utils import calculate_skating_ratio, calculate_trajectory_error, control_l2 + + +class ControlMetrics(Metric): + + def __init__(self, dataset_name: str, dist_sync_on_step: bool = True) -> None: + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "Control errors" + self.dataset_name = dataset_name + + self.add_state("count_seq", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("skate_ratio_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("dist_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("traj_err", default=[], dist_reduce_fx="cat") + self.traj_err_key = ["traj_fail_20cm", "traj_fail_50cm", "kps_fail_20cm", "kps_fail_50cm", "kps_mean_err(m)"] + + def compute(self) -> dict: + count_seq = self.count_seq.item() + + metrics = dict() + metrics['Skating Ratio'] = self.skate_ratio_sum / count_seq + metrics['Control L2 dist'] = self.dist_sum / count_seq + traj_err = dim_zero_cat(self.traj_err).mean(0) + + for (k, v) in zip(self.traj_err_key, traj_err): + metrics[k] = v + + return metrics + + def update(self, joints: torch.Tensor, hint: torch.Tensor, + hint_mask: torch.Tensor, lengths: list[int]) -> None: + self.count_seq += len(lengths) + + joints_no_padding = remove_padding(joints, lengths) + for j in 
joints_no_padding: + skate_ratio, _ = calculate_skating_ratio(j.unsqueeze(0), self.dataset_name) + self.skate_ratio_sum += skate_ratio[0] + + hint_mask = hint_mask.sum(dim=-1, keepdim=True) != 0 + for j, h, m in zip(joints, hint, hint_mask): + control_error = control_l2(j, h, m) + mean_error = control_error.sum() / m.sum() + self.dist_sum += mean_error + control_error = control_error.reshape(-1) + m = m.reshape(-1) + err_np = calculate_trajectory_error(control_error, mean_error, m) + self.traj_err.append(err_np[None].to(joints.device)) diff --git a/mld/models/metrics/mm.py b/mld/models/metrics/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..347247d94c5d6e0b546147abbf230f7b0bc241e1 --- /dev/null +++ b/mld/models/metrics/mm.py @@ -0,0 +1,40 @@ +import torch +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from .utils import calculate_multimodality_np + + +class MMMetrics(Metric): + + def __init__(self, mm_num_times: int = 10, dist_sync_on_step: bool = True) -> None: + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "MultiModality scores" + + self.mm_num_times = mm_num_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", default=torch.tensor(0), dist_reduce_fx="sum") + + self.metrics = ["MultiModality"] + self.add_state("MultiModality", default=torch.tensor(0.), dist_reduce_fx="sum") + + # cached batches + self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx='cat') + + def compute(self) -> dict: + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # cat all embeddings + all_mm_motions = dim_zero_cat(self.mm_motion_embeddings).cpu().numpy() + metrics['MultiModality'] = calculate_multimodality_np(all_mm_motions, self.mm_num_times) + return metrics + + def update(self, mm_motion_embeddings: torch.Tensor, lengths: list[int]) -> None: + self.count += sum(lengths) + self.count_seq += len(lengths) + + # store all mm motion embeddings + self.mm_motion_embeddings.append(mm_motion_embeddings) diff --git a/mld/models/metrics/pos.py b/mld/models/metrics/pos.py new file mode 100644 index 0000000000000000000000000000000000000000..dc273218d748b7c39d65ce253f095dc29c517220 --- /dev/null +++ b/mld/models/metrics/pos.py @@ -0,0 +1,41 @@ +import torch +from torchmetrics import Metric + +from mld.utils.temos_utils import remove_padding +from .utils import calculate_mpjpe + + +class PosMetrics(Metric): + + def __init__(self, dist_sync_on_step: bool = True) -> None: + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "MPJPE (aligned & unaligned), Feature l2 error" + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("mpjpe_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("mpjpe_unaligned_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("feature_error_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + + def compute(self) -> dict: + metric = dict(MPJPE=self.mpjpe_sum / self.count, + MPJPE_unaligned=self.mpjpe_unaligned_sum / self.count, + FeaError=self.feature_error_sum / self.count) + return metric + + def update(self, joints_ref: torch.Tensor, joints_rst: torch.Tensor, + feats_ref: torch.Tensor, feats_rst: torch.Tensor, lengths: list[int]) -> None: + self.count += sum(lengths) + joints_rst = remove_padding(joints_rst, lengths) + joints_ref = remove_padding(joints_ref, lengths) + feats_ref = remove_padding(feats_ref, 
lengths) + feats_rst = remove_padding(feats_rst, lengths) + + for f1, f2 in zip(feats_ref, feats_rst): + self.feature_error_sum += torch.norm(f1 - f2, p=2) + + for j1, j2 in zip(joints_ref, joints_rst): + mpjpe = torch.sum(calculate_mpjpe(j1, j2)) + self.mpjpe_sum += mpjpe + mpjpe_unaligned = torch.sum(calculate_mpjpe(j1, j2, align_root=False)) + self.mpjpe_unaligned_sum += mpjpe_unaligned diff --git a/mld/models/metrics/tm2t.py b/mld/models/metrics/tm2t.py new file mode 100644 index 0000000000000000000000000000000000000000..2d5476631a3e649fde4afe9967a4748e4a7504d5 --- /dev/null +++ b/mld/models/metrics/tm2t.py @@ -0,0 +1,141 @@ +import torch + +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from .utils import (euclidean_distance_matrix, calculate_top_k, calculate_diversity_np, + calculate_activation_statistics_np, calculate_frechet_distance_np) + + +class TM2TMetrics(Metric): + + def __init__(self, + top_k: int = 3, + R_size: int = 32, + diversity_times: int = 300, + dist_sync_on_step: bool = True) -> None: + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "Matching, FID, and Diversity scores" + + self.top_k = top_k + self.R_size = R_size + self.diversity_times = diversity_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + # Matching scores + self.add_state("Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.Matching_metrics = ["Matching_score", "gt_Matching_score"] + for k in range(1, top_k + 1): + self.add_state( + f"R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"R_precision_top_{str(k)}") + for k in range(1, top_k + 1): + self.add_state( + f"gt_R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}") + + self.metrics.extend(self.Matching_metrics) + + # FID + self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.metrics.append("FID") + + # Diversity + self.add_state("Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.metrics.extend(["Diversity", "gt_Diversity"]) + + # cached batches + self.add_state("text_embeddings", default=[], dist_reduce_fx='cat') + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx='cat') + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx='cat') + + def compute(self) -> dict: + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + shuffle_idx = torch.randperm(count_seq) + all_texts = dim_zero_cat(self.text_embeddings).cpu()[shuffle_idx, :] + all_genmotions = dim_zero_cat(self.recmotion_embeddings).cpu()[shuffle_idx, :] + all_gtmotions = dim_zero_cat(self.gtmotion_embeddings).cpu()[shuffle_idx, :] + + # Compute r-precision + assert count_seq >= self.R_size + top_k_mat = torch.zeros((self.top_k,)) + for i in range(count_seq // self.R_size): + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + group_motions = all_genmotions[i * self.R_size:(i + 1) * self.R_size] + dist_mat = euclidean_distance_matrix(group_texts, group_motions).nan_to_num() + 
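# (descriptive note, not in the original patch) Within each group of R_size samples,
# row i of dist_mat holds the distances from text i to every motion in the group, so the
# diagonal (accumulated via trace() below) is the text-to-matched-motion distance used for
# the Matching score, and the rank of the diagonal entry after argsort determines the
# R-precision top-k hits.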
self.Matching_score += dist_mat.trace() + argmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argmax, top_k=self.top_k).sum(axis=0) + R_count = count_seq // self.R_size * self.R_size + metrics["Matching_score"] = self.Matching_score / R_count + for k in range(self.top_k): + metrics[f"R_precision_top_{str(k + 1)}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq >= self.R_size + top_k_mat = torch.zeros((self.top_k,)) + for i in range(count_seq // self.R_size): + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + group_motions = all_gtmotions[i * self.R_size:(i + 1) * self.R_size] + dist_mat = euclidean_distance_matrix(group_texts, group_motions).nan_to_num() + self.gt_Matching_score += dist_mat.trace() + argmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argmax, top_k=self.top_k).sum(axis=0) + metrics["gt_Matching_score"] = self.gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"gt_R_precision_top_{str(k + 1)}"] = top_k_mat[k] / R_count + + all_genmotions = all_genmotions.numpy() + all_gtmotions = all_gtmotions.numpy() + + # Compute fid + mu, cov = calculate_activation_statistics_np(all_genmotions) + gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + assert count_seq >= self.diversity_times + metrics["Diversity"] = calculate_diversity_np(all_genmotions, self.diversity_times) + metrics["gt_Diversity"] = calculate_diversity_np(all_gtmotions, self.diversity_times) + + return {**metrics} + + def update( + self, + text_embeddings: torch.Tensor, + recmotion_embeddings: torch.Tensor, + gtmotion_embeddings: torch.Tensor, + lengths: list[int]) -> None: + self.count += sum(lengths) + self.count_seq += len(lengths) + + # store all texts and motions + self.text_embeddings.append(text_embeddings.detach()) + self.recmotion_embeddings.append(recmotion_embeddings.detach()) + self.gtmotion_embeddings.append(gtmotion_embeddings.detach()) diff --git a/mld/models/metrics/utils.py b/mld/models/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf3be180b5d1f04d99314ef4d34bba66093462f --- /dev/null +++ b/mld/models/metrics/utils.py @@ -0,0 +1,276 @@ +import numpy as np + +import scipy.linalg +from scipy.ndimage import uniform_filter1d + +import torch +from torch import linalg + + +# Motion Reconstruction + +def calculate_mpjpe(gt_joints: torch.Tensor, pred_joints: torch.Tensor, align_root: bool = True) -> torch.Tensor: + """ + gt_joints: num_poses x num_joints x 3 + pred_joints: num_poses x num_joints x 3 + (obtained from recover_from_ric()) + """ + assert gt_joints.shape == pred_joints.shape, \ + f"GT shape: {gt_joints.shape}, pred shape: {pred_joints.shape}" + + # Align by root (pelvis) + if align_root: + gt_joints = gt_joints - gt_joints[:, [0]] + pred_joints = pred_joints - pred_joints[:, [0]] + + # Compute MPJPE + mpjpe = torch.linalg.norm(pred_joints - gt_joints, dim=-1) # num_poses x num_joints + mpjpe = mpjpe.mean(-1) # num_poses + + return mpjpe + + +# Text-to-Motion + +# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train +def euclidean_distance_matrix(matrix1: torch.Tensor, matrix2: torch.Tensor) -> torch.Tensor: + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dists: N1 x N2 + dists[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * torch.mm(matrix1, matrix2.T) 
# shape (num_test, num_train) + d2 = torch.sum(torch.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) + d3 = torch.sum(torch.square(matrix2), axis=1) # shape (num_train, ) + dists = torch.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def euclidean_distance_matrix_np(matrix1: np.ndarray, matrix2: np.ndarray) -> np.ndarray: + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dists: N1 x N2 + dists[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) + d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) + dists = np.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def calculate_top_k(mat: torch.Tensor, top_k: int) -> torch.Tensor: + size = mat.shape[0] + gt_mat = (torch.unsqueeze(torch.arange(size), 1).to(mat.device).repeat_interleave(size, 1)) + bool_mat = mat == gt_mat + correct_vec = False + top_k_list = [] + for i in range(top_k): + correct_vec = correct_vec | bool_mat[:, i] + top_k_list.append(correct_vec[:, None]) + top_k_mat = torch.cat(top_k_list, dim=1) + return top_k_mat + + +def calculate_activation_statistics(activations: torch.Tensor) -> tuple: + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + activations = activations.cpu().numpy() + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return mu, sigma + + +def calculate_activation_statistics_np(activations: np.ndarray) -> tuple: + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +def calculate_frechet_distance_np( + mu1: np.ndarray, + sigma1: np.ndarray, + mu2: np.ndarray, + sigma2: np.ndarray, + eps: float = 1e-6) -> float: + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert (mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert (sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # Product might be almost singular + covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ("fid calculation produces singular product; " + "adding %s to diagonal of cov estimates") % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def calculate_diversity(activation: torch.Tensor, diversity_times: int) -> float: + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_diversity_np(activation: np.ndarray, diversity_times: int) -> float: + assert len(activation.shape) == 2 + assert activation.shape[0] >= diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = scipy.linalg.norm(activation[first_indices] - + activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int) -> float: + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + second_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + dist = scipy.linalg.norm(activation[:, first_dices] - + activation[:, second_dices], + axis=2) + return dist.mean() + + +# Motion Control + +def calculate_skating_ratio(motions: torch.Tensor, dataset_name: str) -> tuple: + thresh_height = 0.05 + fps = 20.0 + thresh_vel = 0.50 + avg_window = 5 # frames + + # XZ plane, y up + # 10 left, 11 right foot. (HumanML3D) + # 15 left, 20 right foot. 
(KIT) + # motions [bsz, fs, 22 or 21, 3] + + if dataset_name == 'humanml3d': + foot_idx = [10, 11] + elif dataset_name == 'kit': + foot_idx = [15, 20] + else: + raise ValueError(f'Invalid Dataset: {dataset_name}') + + verts_feet = motions[:, :, foot_idx, :].detach().cpu().numpy() # [bsz, fs, 2, 3] + verts_feet_plane_vel = np.linalg.norm(verts_feet[:, 1:, :, [0, 2]] - + verts_feet[:, :-1, :, [0, 2]], axis=-1) * fps # [bsz, fs-1, 2] + vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=1, size=avg_window, mode='constant', origin=0) + + verts_feet_height = verts_feet[:, :, :, 1] # [bsz, fs, 2] + # If feet touch ground in adjacent frames + feet_contact = np.logical_and((verts_feet_height[:, :-1, :] < thresh_height), + (verts_feet_height[:, 1:, :] < thresh_height)) # [bs, fs-1, 2] + # skate velocity + skate_vel = feet_contact * vel_avg + + skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel)) + skating = np.logical_and(skating, (vel_avg > thresh_vel)) + + # Both feet slide + skating = np.logical_or(skating[:, :, 0], skating[:, :, 1]) # [bs, fs-1] + skating_ratio = np.sum(skating, axis=1) / skating.shape[1] + + return skating_ratio, skate_vel + + +def calculate_trajectory_error(dist_error: torch.Tensor, mean_err_traj: torch.Tensor, + mask: torch.Tensor, strict: bool = True) -> torch.Tensor: + if strict: + # Traj fails if any of the key frame fails + traj_fail_02 = 1.0 - int((dist_error <= 0.2).all().item()) + traj_fail_05 = 1.0 - int((dist_error <= 0.5).all().item()) + else: + # Traj fails if the mean error of all keyframes more than the threshold + traj_fail_02 = int((mean_err_traj > 0.2).item()) + traj_fail_05 = int((mean_err_traj > 0.5).item()) + all_fail_02 = (dist_error > 0.2).sum() / mask.sum() + all_fail_05 = (dist_error > 0.5).sum() / mask.sum() + + return torch.tensor([traj_fail_02, traj_fail_05, all_fail_02, all_fail_05, dist_error.sum() / mask.sum()]) + + +def control_l2(motion: torch.Tensor, hint: torch.Tensor, hint_mask: torch.Tensor) -> torch.Tensor: + loss = torch.norm((motion - hint) * hint_mask, p=2, dim=-1) + return loss diff --git a/mld/models/modeltype/__init__.py b/mld/models/modeltype/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/modeltype/base.py b/mld/models/modeltype/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ca700491738e41cbb7d0a8d74635aa9517a1ba01 --- /dev/null +++ b/mld/models/modeltype/base.py @@ -0,0 +1,155 @@ +import os +from typing import Any +from collections import OrderedDict + +import numpy as np +from omegaconf import DictConfig + +import torch +import torch.nn as nn + +from mld.models.metrics import TM2TMetrics, MMMetrics, ControlMetrics, PosMetrics +from mld.models.architectures import t2m_motionenc, t2m_textenc + + +class BaseModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.times = [] + self.text_encoder_times = [] + self.traj_encoder_times = [] + self.diffusion_times = [] + self.vae_decode_times = [] + self.vae_encode_times = [] + self.frames = [] + + def _get_t2m_evaluator(self, cfg: DictConfig) -> None: + self.t2m_moveencoder = t2m_motionenc.MovementConvEncoder( + input_size=cfg.DATASET.NFEATS - 4, + hidden_size=cfg.model.t2m_motionencoder.dim_move_hidden, + output_size=cfg.model.t2m_motionencoder.dim_move_latent) + + self.t2m_motionencoder = t2m_motionenc.MotionEncoderBiGRUCo( + input_size=cfg.model.t2m_motionencoder.dim_move_latent, + 
hidden_size=cfg.model.t2m_motionencoder.dim_motion_hidden, + output_size=cfg.model.t2m_motionencoder.dim_motion_latent) + + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCo( + word_size=cfg.model.t2m_textencoder.dim_word, + pos_size=cfg.model.t2m_textencoder.dim_pos_ohot, + hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden) + + # load pretrained + dataname = cfg.DATASET.NAME + dataname = "t2m" if dataname == "humanml3d" else dataname + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, "text_mot_match/model/finest.tar"), map_location='cpu') + self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"]) + self.t2m_moveencoder.load_state_dict(t2m_checkpoint["movement_encoder"]) + self.t2m_motionencoder.load_state_dict(t2m_checkpoint["motion_encoder"]) + + # freeze params + self.t2m_textencoder.eval() + self.t2m_moveencoder.eval() + self.t2m_motionencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + for p in self.t2m_moveencoder.parameters(): + p.requires_grad = False + for p in self.t2m_motionencoder.parameters(): + p.requires_grad = False + + def test_step(self, batch: dict) -> None: + total_samples = len(self.frames) + message = '' + if len(self.times) > 0: + inference_aits = round(np.sum(self.times) / total_samples, 5) + message += f"\nAverage Inference Time per Sentence ({total_samples}): {inference_aits}\n" + + if len(self.text_encoder_times) > 0: + inference_aits_text = round(np.sum(self.text_encoder_times) / total_samples, 5) + message += f"Average Inference Time per Sentence [Text]: {inference_aits_text}\n" + + if len(self.traj_encoder_times) > 0: + inference_aits_hint = round(np.sum(self.traj_encoder_times) / total_samples, 5) + message += f"Average Inference Time per Sentence [Hint]: {inference_aits_hint}\n" + + if len(self.diffusion_times) > 0: + inference_aits_diff = round(np.sum(self.diffusion_times) / total_samples, 5) + message += f"Average Inference Time per Sentence [Diffusion]: {inference_aits_diff}\n" + + if len(self.vae_encode_times) > 0: + inference_aits_vae_e = round(np.sum(self.vae_encode_times) / total_samples, 5) + message += f"Average Inference Time per Sentence [VAE Encode]: {inference_aits_vae_e}\n" + + if len(self.vae_decode_times) > 0: + inference_aits_vae_d = round(np.sum(self.vae_decode_times) / total_samples, 5) + message += f"Average Inference Time per Sentence [VAE Decode]: {inference_aits_vae_d}\n" + + if len(self.frames) > 0: + message += f"Average length: {round(np.mean(self.frames), 5)}\n" + message += f"FPS: {np.sum(self.frames) / np.sum(self.times)}\n" + + if message: + print(message) + + return self.allsplit_step("test", batch) + + def allsplit_epoch_end(self) -> dict: + res = dict() + if self.datamodule.is_mm and "TM2TMetrics" in self.metric_list: + metric_list = ['MMMetrics'] + else: + metric_list = self.metric_list + for metric in metric_list: + metrics_dict = getattr(self, metric).compute() + # reset metrics + getattr(self, metric).reset() + res.update({ + f"Metrics/{metric}": value.item() + for metric, value in metrics_dict.items() + }) + return res + + def on_save_checkpoint(self, checkpoint: dict) -> None: + state_dict = checkpoint['state_dict'] + if hasattr(self, 'text_encoder'): + clip_k = [] + for k, v in state_dict.items(): + if 'text_encoder' in k: + clip_k.append(k) + for k in clip_k: + del checkpoint['state_dict'][k] + + def load_state_dict(self, state_dict: dict, strict: bool = True) -> Any: + if 
hasattr(self, 'text_encoder'): + clip_state_dict = self.text_encoder.state_dict() + new_state_dict = OrderedDict() + for k, v in clip_state_dict.items(): + new_state_dict['text_encoder.' + k] = v + for k, v in state_dict.items(): + if 'text_encoder' not in k: + new_state_dict[k] = v + return super().load_state_dict(new_state_dict, strict) + else: + return super().load_state_dict(state_dict, strict) + + def configure_metrics(self) -> None: + for metric in self.metric_list: + if metric == "TM2TMetrics": + self.TM2TMetrics = TM2TMetrics( + diversity_times=self.cfg.TEST.DIVERSITY_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP) + elif metric == 'ControlMetrics': + self.ControlMetrics = ControlMetrics(self.datamodule.name, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP) + elif metric == 'PosMetrics': + self.PosMetrics = PosMetrics(dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP) + else: + raise NotImplementedError(f"Do not support Metric Type {metric}.") + + if "TM2TMetrics" in self.metric_list and self.cfg.TEST.DO_MM_TEST: + self.MMMetrics = MMMetrics( + mm_num_times=self.cfg.TEST.MM_NUM_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP) diff --git a/mld/models/modeltype/mld.py b/mld/models/modeltype/mld.py new file mode 100644 index 0000000000000000000000000000000000000000..7edf838c05bf3364538f64f44491b59fcfcf7327 --- /dev/null +++ b/mld/models/modeltype/mld.py @@ -0,0 +1,608 @@ +import time +import inspect +import logging +from typing import Optional + +import tqdm +import numpy as np +from omegaconf import DictConfig + +import torch +import torch.nn.functional as F +from diffusers.optimization import get_scheduler + +from mld.data.base import BaseDataModule +from mld.config import instantiate_from_config +from mld.utils.temos_utils import lengths_to_mask, remove_padding +from mld.utils.utils import count_parameters, get_guidance_scale_embedding, extract_into_tensor, control_loss_calculate +from mld.data.humanml.utils.plot_script import plot_3d_motion + +from .base import BaseModel + +logger = logging.getLogger(__name__) + + +class MLD(BaseModel): + def __init__(self, cfg: DictConfig, datamodule: BaseDataModule) -> None: + super().__init__() + + self.cfg = cfg + self.nfeats = cfg.DATASET.NFEATS + self.njoints = cfg.DATASET.NJOINTS + self.latent_dim = cfg.model.latent_dim + self.guidance_scale = cfg.model.guidance_scale + self.datamodule = datamodule + + if cfg.model.guidance_scale == 'dynamic': + s_cfg = cfg.model.scheduler + self.guidance_scale = s_cfg.cfg_step_map[s_cfg.num_inference_steps] + logger.info(f'Guidance Scale set as {self.guidance_scale}') + + self.text_encoder = instantiate_from_config(cfg.model.text_encoder) + self.vae = instantiate_from_config(cfg.model.motion_vae) + self.denoiser = instantiate_from_config(cfg.model.denoiser) + + self.scheduler = instantiate_from_config(cfg.model.scheduler) + self.alphas = torch.sqrt(self.scheduler.alphas_cumprod) + self.sigmas = torch.sqrt(1 - self.scheduler.alphas_cumprod) + + self._get_t2m_evaluator(cfg) + + self.metric_list = cfg.METRIC.TYPE + self.configure_metrics() + + self.feats2joints = datamodule.feats2joints + + self.vae_scale_factor = cfg.model.get("vae_scale_factor", 1.0) + self.guidance_uncondp = cfg.model.get('guidance_uncondp', 0.0) + + logger.info(f"vae_scale_factor: {self.vae_scale_factor}") + logger.info(f"prediction_type: {self.scheduler.config.prediction_type}") + logger.info(f"guidance_scale: {self.guidance_scale}") + logger.info(f"guidance_uncondp: {self.guidance_uncondp}") + + 
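+        # Optional motion-ControlNet branch: the controlnet below reuses the denoiser
+        # architecture (same config with is_controlnet=True), and traj_encoder embeds the
+        # spatial hint into controlnet_cond, which is injected as residuals during denoising.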
self.is_controlnet = cfg.model.get('is_controlnet', False) + if self.is_controlnet: + c_cfg = self.cfg.model.denoiser.copy() + c_cfg['params']['is_controlnet'] = True + self.controlnet = instantiate_from_config(c_cfg) + self.traj_encoder = instantiate_from_config(cfg.model.traj_encoder) + + self.vaeloss = cfg.model.get('vaeloss', False) + self.vaeloss_type = cfg.model.get('vaeloss_type', 'sum') + self.cond_ratio = cfg.model.get('cond_ratio', 0.0) + self.rot_ratio = cfg.model.get('rot_ratio', 0.0) + self.control_loss_func = cfg.model.get('control_loss_func', 'l2') + if self.vaeloss and self.cond_ratio == 0.0 and self.rot_ratio == 0.0: + raise ValueError("Error: When 'vaeloss' is True, 'cond_ratio' and 'rot_ratio' cannot both be 0.") + self.use_3d = cfg.model.get('use_3d', False) + self.guess_mode = cfg.model.get('guess_mode', False) + if self.guess_mode and not self.do_classifier_free_guidance: + raise ValueError( + "Invalid configuration: 'guess_mode' is enabled, but 'do_classifier_free_guidance' is not. " + "Ensure that 'do_classifier_free_guidance' is True (MLD) when 'guess_mode' is active." + ) + self.lcm_w_min_nax = cfg.model.get('lcm_w_min_nax') + self.lcm_num_ddim_timesteps = cfg.model.get('lcm_num_ddim_timesteps') + if (self.lcm_w_min_nax is not None or self.lcm_num_ddim_timesteps is not None) and self.denoiser.time_cond_proj_dim is None: + raise ValueError( + "Invalid configuration: When either 'lcm_w_min_nax' or 'lcm_num_ddim_timesteps' is not None, " + "'denoiser.time_cond_proj_dim' must be None (MotionLCM)." + ) + + logger.info(f"vaeloss: {self.vaeloss}, " + f"vaeloss_type: {self.vaeloss_type}, " + f"cond_ratio: {self.cond_ratio}, " + f"rot_ratio: {self.rot_ratio}, " + f"control_loss_func: {self.control_loss_func}") + logger.info(f"use_3d: {self.use_3d}, " + f"guess_mode: {self.guess_mode}") + logger.info(f"lcm_w_min_nax: {self.lcm_w_min_nax}, " + f"lcm_num_ddim_timesteps: {self.lcm_num_ddim_timesteps}") + + time.sleep(2) # 留个心眼 + + self.dno = instantiate_from_config(cfg.model['noise_optimizer']) \ + if cfg.model.get('noise_optimizer') else None + + self.summarize_parameters() + + @property + def do_classifier_free_guidance(self) -> bool: + return self.guidance_scale > 1 and self.denoiser.time_cond_proj_dim is None + + def summarize_parameters(self) -> None: + logger.info(f'VAE Encoder: {count_parameters(self.vae.encoder)}M') + logger.info(f'VAE Decoder: {count_parameters(self.vae.decoder)}M') + logger.info(f'Denoiser: {count_parameters(self.denoiser)}M') + + if self.is_controlnet: + traj_encoder = count_parameters(self.traj_encoder) + controlnet = count_parameters(self.controlnet) + logger.info(f'ControlNet: {controlnet}M') + logger.info(f'Trajectory Encoder: {traj_encoder}M') + + def forward(self, batch: dict) -> tuple: + texts = batch["text"] + feats_ref = batch.get("motion") + lengths = batch["length"] + hint = batch.get('hint') + hint_mask = batch.get('hint_mask') + + if self.do_classifier_free_guidance: + texts = texts + [""] * len(texts) + + text_emb = self.text_encoder(texts) + + controlnet_cond = None + if self.is_controlnet: + assert hint is not None + hint_reshaped = hint.view(hint.shape[0], hint.shape[1], -1) + hint_mask_reshaped = hint_mask.view(hint_mask.shape[0], hint_mask.shape[1], -1).sum(dim=-1) != 0 + controlnet_cond = self.traj_encoder(hint_reshaped, hint_mask_reshaped) + + latents = torch.randn((len(lengths), *self.latent_dim), device=text_emb.device) + mask = batch.get('mask', lengths_to_mask(lengths, text_emb.device)) + + if hint is not None and self.dno 
and self.dno.optimize: + latents = self._optimize_latents( + latents, text_emb, texts, lengths, mask, hint, hint_mask, + controlnet_cond=controlnet_cond, feats_ref=feats_ref) + + latents = self._diffusion_reverse(latents, text_emb, controlnet_cond=controlnet_cond) + feats_rst = self.vae.decode(latents / self.vae_scale_factor, mask) + + joints = self.feats2joints(feats_rst.detach().cpu()) + joints = remove_padding(joints, lengths) + + joints_ref = None + if feats_ref is not None: + joints_ref = self.feats2joints(feats_ref.detach().cpu()) + joints_ref = remove_padding(joints_ref, lengths) + + return joints, joints_ref + + def predicted_origin(self, model_output: torch.Tensor, timesteps: torch.Tensor, sample: torch.Tensor) -> tuple: + self.alphas = self.alphas.to(model_output.device) + self.sigmas = self.sigmas.to(model_output.device) + alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) + + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - sigmas * model_output) / alphas + pred_epsilon = model_output + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alphas * model_output) / sigmas + else: + raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.") + + return pred_original_sample, pred_epsilon + + @torch.enable_grad() + def _optimize_latents( + self, + latents: torch.Tensor, + encoder_hidden_states: torch.Tensor, + texts: list[str], lengths: list[int], mask: torch.Tensor, + hint: torch.Tensor, hint_mask: torch.Tensor, + controlnet_cond: Optional[torch.Tensor] = None, + feats_ref: Optional[torch.Tensor] = None + ) -> torch.Tensor: + + current_latents = latents.clone().requires_grad_(True) + optimizer = torch.optim.Adam([current_latents], lr=self.dno.learning_rate) + lr_scheduler = get_scheduler( + self.dno.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=self.dno.lr_warmup_steps, + num_training_steps=self.dno.max_train_steps) + + do_visualize = self.dno.do_visualize + vis_id = self.dno.visualize_samples_done + hint_3d = self.datamodule.denorm_spatial(hint) * hint_mask + for step in tqdm.tqdm(range(1, self.dno.max_train_steps + 1)): + z_pred = self._diffusion_reverse(current_latents, encoder_hidden_states, controlnet_cond=controlnet_cond) + feats_rst = self.vae.decode(z_pred / self.vae_scale_factor, mask) + joints_rst = self.feats2joints(feats_rst) + + loss_hint = self.dno.loss_hint_func(joints_rst, hint_3d, reduction='none') * hint_mask + loss_hint = loss_hint.sum(dim=[1, 2, 3]) / hint_mask.sum(dim=[1, 2, 3]) + loss_diff = (current_latents - latents).norm(p=2, dim=[1, 2]) + loss_correlate = self.dno.noise_regularize_1d(current_latents, dim=1) + loss = loss_hint + self.dno.loss_correlate_penalty * loss_correlate + self.dno.loss_diff_penalty * loss_diff + loss_mean = loss.mean() + optimizer.zero_grad() + loss_mean.backward() + + grad_norm = current_latents.grad.norm(p=2, dim=[1, 2], keepdim=True) + if self.dno.clip_grad: + current_latents.grad.data /= grad_norm + + # Visualize + if do_visualize: + control_error = torch.norm((joints_rst - hint_3d) * hint_mask, p=2, dim=-1) + control_error = control_error.sum(dim=[1, 2]) / hint_mask.mean(dim=-1).sum(dim=[1, 2]) + for batch_id in range(latents.shape[0]): + metrics = { + 'loss': loss[batch_id].item(), + 'loss_hint': loss_hint[batch_id].mean().item(), + 'loss_diff': loss_diff[batch_id].item(), + 'loss_correlate': 
loss_correlate[batch_id].item(), + 'grad_norm': grad_norm[batch_id].item(), + 'lr': lr_scheduler.get_last_lr()[0], + 'control_error': control_error[batch_id].item() + } + for metric_name, metric_value in metrics.items(): + self.dno.writer.add_scalar(f'Optimize_{vis_id + batch_id}/{metric_name}', metric_value, step) + + if step in self.dno.visualize_ske_steps: + joints_rst_no_pad = joints_rst[batch_id][:lengths[batch_id]].detach().cpu().numpy() + hint_3d_no_pad = hint_3d[batch_id][:lengths[batch_id]].detach().cpu().numpy() + plot_3d_motion(f'{self.dno.vis_dir}/vis_id_{vis_id + batch_id}_step_{step}.mp4', + joints=joints_rst_no_pad, title=texts[batch_id], hint=hint_3d_no_pad, + fps=eval(f"self.cfg.DATASET.{self.cfg.DATASET.NAME.upper()}.FRAME_RATE")) + + optimizer.step() + lr_scheduler.step() + + if feats_ref is not None and do_visualize and len(self.dno.visualize_ske_steps) > 0: + joints_ref = self.feats2joints(feats_ref) + for batch_id in range(latents.shape[0]): + joints_ref_no_pad = joints_ref[batch_id][:lengths[batch_id]].detach().cpu().numpy() + hint_3d_no_pad = hint_3d[batch_id][:lengths[batch_id]].detach().cpu().numpy() + plot_3d_motion(f'{self.dno.vis_dir}/vis_id_{vis_id + batch_id}_ref.mp4', + joints=joints_ref_no_pad, title=texts[batch_id], hint=hint_3d_no_pad, + fps=eval(f"self.cfg.DATASET.{self.cfg.DATASET.NAME.upper()}.FRAME_RATE")) + + self.dno.visualize_samples_done += latents.shape[0] + return current_latents.detach() + + def _diffusion_reverse( + self, + latents: torch.Tensor, + encoder_hidden_states: torch.Tensor, + controlnet_cond: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # set timesteps + self.scheduler.set_timesteps(self.cfg.model.scheduler.num_inference_steps) + timesteps = self.scheduler.timesteps.to(encoder_hidden_states.device) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs = {} + if "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()): + extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta + + timestep_cond = None + if self.denoiser.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(latents.shape[0]) + timestep_cond = get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.denoiser.time_cond_proj_dim + ).to(device=latents.device, dtype=latents.dtype) + + if self.is_controlnet and self.do_classifier_free_guidance and not self.guess_mode: + controlnet_cond = torch.cat([controlnet_cond] * 2) + + for i, t in tqdm.tqdm(enumerate(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = (torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + controlnet_residuals = None + if self.is_controlnet: + if self.do_classifier_free_guidance and self.guess_mode: + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = encoder_hidden_states.chunk(2)[0] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = encoder_hidden_states + + controlnet_residuals = self.controlnet( + sample=control_model_input, + timestep=t, + timestep_cond=timestep_cond, + 
encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_cond)[0] + + if self.do_classifier_free_guidance and self.guess_mode: + controlnet_residuals = [torch.cat([d, torch.zeros_like(d)], dim=1) for d in controlnet_residuals] + + # predict the noise residual + model_output = self.denoiser( + sample=latent_model_input, + timestep=t, + timestep_cond=timestep_cond, + encoder_hidden_states=encoder_hidden_states, + controlnet_residuals=controlnet_residuals)[0] + + # perform guidance + if self.do_classifier_free_guidance: + model_output_text, model_output_uncond = model_output.chunk(2) + model_output = model_output_uncond + self.guidance_scale * (model_output_text - model_output_uncond) + + latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample + + return latents + + def _diffusion_process(self, latents: torch.Tensor, encoder_hidden_states: torch.Tensor, + hint: Optional[torch.Tensor] = None, hint_mask: Optional[torch.Tensor] = None) -> dict: + + controlnet_cond = None + if self.is_controlnet: + assert hint is not None + hint_reshaped = hint.view(hint.shape[0], hint.shape[1], -1) + hint_mask_reshaped = hint_mask.view(hint_mask.shape[0], hint_mask.shape[1], -1).sum(-1) != 0 + controlnet_cond = self.traj_encoder(hint_reshaped, mask=hint_mask_reshaped) + + # [batch_size, n_token, latent_dim] + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + if self.denoiser.time_cond_proj_dim is not None and self.lcm_num_ddim_timesteps is not None: + step_size = self.scheduler.config.num_train_timesteps // self.lcm_num_ddim_timesteps + candidate_timesteps = torch.arange( + start=step_size - 1, + end=self.scheduler.config.num_train_timesteps, + step=step_size, + device=latents.device + ) + timesteps = candidate_timesteps[torch.randint( + low=0, + high=candidate_timesteps.size(0), + size=(bsz,), + device=latents.device + )] + else: + timesteps = torch.randint( + 0, + self.scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device + ) + timesteps = timesteps.long() + noisy_latents = self.scheduler.add_noise(latents.clone(), noise, timesteps) + + timestep_cond = None + if self.denoiser.time_cond_proj_dim is not None: + if self.lcm_w_min_nax is None: + w = torch.tensor(self.guidance_scale - 1).repeat(latents.shape[0]) + else: + w = (self.lcm_w_min_nax[1] - self.lcm_w_min_nax[0]) * torch.rand((bsz,)) + self.lcm_w_min_nax[0] + timestep_cond = get_guidance_scale_embedding( + w, embedding_dim=self.denoiser.time_cond_proj_dim + ).to(device=latents.device, dtype=latents.dtype) + + controlnet_residuals = None + router_loss_controlnet = None + if self.is_controlnet: + controlnet_residuals, router_loss_controlnet = self.controlnet( + sample=noisy_latents, + timestep=timesteps, + timestep_cond=timestep_cond, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond) + + model_output, router_loss = self.denoiser( + sample=noisy_latents, + timestep=timesteps, + timestep_cond=timestep_cond, + encoder_hidden_states=encoder_hidden_states, + controlnet_residuals=controlnet_residuals) + + latents_pred, noise_pred = self.predicted_origin(model_output, timesteps, noisy_latents) + + n_set = { + "noise": noise, + "noise_pred": noise_pred, + "sample_pred": latents_pred, + "sample_gt": latents, + "router_loss": router_loss_controlnet if self.is_controlnet else router_loss + } + return n_set + + def train_diffusion_forward(self, batch: dict) -> dict: + feats_ref = batch["motion"] + mask = batch['mask'] + hint = batch.get('hint', None) + 
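+        # 'hint' carries the (normalized) target joint positions used for spatial control and
+        # 'hint_mask' (fetched next) marks which frames/joints are constrained; both appear to
+        # be [batch, frames, joints, 3] tensors and are flattened before the traj_encoder.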
hint_mask = batch.get('hint_mask', None) + + with torch.no_grad(): + z, dist = self.vae.encode(feats_ref, mask) + z = z * self.vae_scale_factor + + text = batch["text"] + text = [ + "" if np.random.rand(1) < self.guidance_uncondp else i + for i in text + ] + text_emb = self.text_encoder(text) + n_set = self._diffusion_process(z, text_emb, hint=hint, hint_mask=hint_mask) + + loss_dict = dict() + + if self.denoiser.time_cond_proj_dim is not None: + # LCM (only used in motion ControlNet) + model_pred, target = n_set['sample_pred'], n_set['sample_gt'] + # Performance comparison: l2 loss > huber loss when training controlnet for LCM + diff_loss = F.mse_loss(model_pred, target, reduction="mean") + else: + # DM + if self.scheduler.config.prediction_type == "epsilon": + model_pred, target = n_set['noise_pred'], n_set['noise'] + elif self.scheduler.config.prediction_type == "sample": + model_pred, target = n_set['sample_pred'], n_set['sample_gt'] + else: + raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.") + diff_loss = F.mse_loss(model_pred, target, reduction="mean") + + loss_dict['diff_loss'] = diff_loss + + # Router loss + loss_dict['router_loss'] = n_set['router_loss'] if n_set['router_loss'] is not None \ + else torch.tensor(0., device=diff_loss.device) + + if self.is_controlnet and self.vaeloss: + feats_rst = self.vae.decode(n_set['sample_pred'] / self.vae_scale_factor, mask) + + if self.cond_ratio != 0: + joints_rst = self.feats2joints(feats_rst) + if self.use_3d: + hint = self.datamodule.denorm_spatial(hint) + else: + joints_rst = self.datamodule.norm_spatial(joints_rst) + hint_mask = hint_mask.sum(-1, keepdim=True) != 0 + cond_loss = control_loss_calculate(self.vaeloss_type, self.control_loss_func, joints_rst, hint, + hint_mask) + loss_dict['cond_loss'] = self.cond_ratio * cond_loss + else: + loss_dict['cond_loss'] = torch.tensor(0., device=diff_loss.device) + + if self.rot_ratio != 0: + mask = mask.unsqueeze(-1) + rot_loss = control_loss_calculate(self.vaeloss_type, self.control_loss_func, feats_rst, feats_ref, mask) + loss_dict['rot_loss'] = self.rot_ratio * rot_loss + else: + loss_dict['rot_loss'] = torch.tensor(0., device=diff_loss.device) + + else: + loss_dict['cond_loss'] = loss_dict['rot_loss'] = torch.tensor(0., device=diff_loss.device) + + total_loss = sum(loss_dict.values()) + loss_dict['loss'] = total_loss + return loss_dict + + def t2m_eval(self, batch: dict) -> dict: + texts = batch["text"] + feats_ref = batch["motion"] + mask = batch['mask'] + lengths = batch["length"] + word_embs = batch["word_embs"] + pos_ohot = batch["pos_ohot"] + text_lengths = batch["text_len"] + hint = batch.get('hint', None) + hint_mask = batch.get('hint_mask', None) + + start = time.time() + + if self.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + feats_ref = feats_ref.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + mask = mask.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + text_lengths = text_lengths.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + hint = hint and hint.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + hint_mask = hint_mask and hint_mask.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.do_classifier_free_guidance: + texts = texts + [""] * len(texts) + + text_st = time.time() + 
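+        # Each stage below is wall-clock timed (text encoding, trajectory encoding, diffusion
+        # sampling, VAE decoding); the accumulated per-stage times feed the average inference
+        # time report printed in BaseModel.test_step.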
text_emb = self.text_encoder(texts) + text_et = time.time() + self.text_encoder_times.append(text_et - text_st) + + controlnet_cond = None + if self.is_controlnet: + assert hint is not None + hint_st = time.time() + hint_reshaped = hint.view(hint.shape[0], hint.shape[1], -1) + hint_mask_reshaped = hint_mask.view(hint_mask.shape[0], hint_mask.shape[1], -1).sum(dim=-1) != 0 + controlnet_cond = self.traj_encoder(hint_reshaped, hint_mask_reshaped) + hint_et = time.time() + self.traj_encoder_times.append(hint_et - hint_st) + + diff_st = time.time() + + latents = torch.randn((feats_ref.shape[0], *self.latent_dim), device=text_emb.device) + + if hint is not None and self.dno and self.dno.optimize: + latents = self._optimize_latents( + latents, text_emb, texts, lengths, mask, hint, hint_mask, + controlnet_cond=controlnet_cond, feats_ref=feats_ref) + + latents = self._diffusion_reverse(latents, text_emb, controlnet_cond=controlnet_cond) + + diff_et = time.time() + self.diffusion_times.append(diff_et - diff_st) + + vae_st = time.time() + feats_rst = self.vae.decode(latents / self.vae_scale_factor, mask) + vae_et = time.time() + self.vae_decode_times.append(vae_et - vae_st) + + self.frames.extend(lengths) + + end = time.time() + self.times.append(end - start) + + # joints recover + joints_rst = self.feats2joints(feats_rst) + joints_ref = self.feats2joints(feats_ref) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + feats_ref = self.datamodule.renorm4t2m(feats_ref) + + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=feats_ref.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + feats_ref = feats_ref[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, eval(f"self.cfg.DATASET.{self.cfg.DATASET.NAME.upper()}.UNIT_LEN"), + rounding_mode="floor") + + recons_mov = self.t2m_moveencoder(feats_rst[..., :-4]).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(feats_ref[..., :-4]).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + text_emb = self.t2m_textencoder(word_embs, pos_ohot, text_lengths)[align_idx] + + rs_set = {"m_ref": feats_ref, "m_rst": feats_rst, + "lat_t": text_emb, "lat_m": motion_emb, "lat_rm": recons_emb, + "joints_ref": joints_ref, "joints_rst": joints_rst} + + if 'hint' in batch: + hint_3d = self.datamodule.denorm_spatial(batch['hint']) * batch['hint_mask'] + rs_set['hint'] = hint_3d + rs_set['hint_mask'] = batch['hint_mask'] + + return rs_set + + def allsplit_step(self, split: str, batch: dict) -> Optional[dict]: + if split in ["test", "val"]: + rs_set = self.t2m_eval(batch) + + if self.datamodule.is_mm: + metric_list = ['MMMetrics'] + else: + metric_list = self.metric_list + + for metric in metric_list: + if metric == "TM2TMetrics": + getattr(self, metric).update( + rs_set["lat_t"], + rs_set["lat_rm"], + rs_set["lat_m"], + batch["length"]) + elif metric == "MMMetrics" and self.datamodule.is_mm: + getattr(self, metric).update(rs_set["lat_rm"].unsqueeze(0), batch["length"]) + elif metric == 'ControlMetrics': + getattr(self, metric).update(rs_set["joints_rst"], rs_set['hint'], + rs_set['hint_mask'], batch['length']) + else: + raise TypeError(f"Not support this metric: {metric}.") + + if split in ["train", "val"]: + loss_dict = self.train_diffusion_forward(batch) + return loss_dict diff --git a/mld/models/modeltype/vae.py b/mld/models/modeltype/vae.py new file 
mode 100644 index 0000000000000000000000000000000000000000..d1099f5bd87252378d053056589231ea46b40f16 --- /dev/null +++ b/mld/models/modeltype/vae.py @@ -0,0 +1,222 @@ +import time +import logging +from typing import Optional + +import numpy as np +from omegaconf import DictConfig + +import torch +import torch.nn.functional as F + +from mld.data.base import BaseDataModule +from mld.config import instantiate_from_config +from mld.utils.temos_utils import remove_padding +from mld.utils.utils import count_parameters, get_guidance_scale_embedding, extract_into_tensor, sum_flat +from .base import BaseModel + +logger = logging.getLogger(__name__) + + +class VAE(BaseModel): + def __init__(self, cfg: DictConfig, datamodule: BaseDataModule) -> None: + super().__init__() + + self.cfg = cfg + self.datamodule = datamodule + self.njoints = cfg.DATASET.NJOINTS + + self.vae = instantiate_from_config(cfg.model.motion_vae) + + self._get_t2m_evaluator(cfg) + + self.metric_list = cfg.METRIC.TYPE + self.configure_metrics() + + self.feats2joints = datamodule.feats2joints + + self.rec_feats_ratio = cfg.model.rec_feats_ratio + self.rec_joints_ratio = cfg.model.rec_joints_ratio + self.rec_velocity_ratio = cfg.model.rec_velocity_ratio + self.kl_ratio = cfg.model.kl_ratio + + self.rec_feats_loss = cfg.model.rec_feats_loss + self.rec_joints_loss = cfg.model.rec_joints_loss + self.rec_velocity_loss = cfg.model.rec_velocity_loss + self.mask_loss = cfg.model.mask_loss + + logger.info(f"latent_dim: {cfg.model.latent_dim}") + logger.info(f"rec_feats_ratio: {self.rec_feats_ratio}, " + f"rec_joints_ratio: {self.rec_joints_ratio}, " + f"rec_velocity_ratio: {self.rec_velocity_ratio}, " + f"kl_ratio: {self.kl_ratio}") + logger.info(f"rec_feats_loss: {self.rec_feats_loss}, " + f"rec_joints_loss: {self.rec_joints_loss}, " + f"rec_velocity_loss: {self.rec_velocity_loss}") + logger.info(f"mask_loss: {cfg.model.mask_loss}") + + self.summarize_parameters() + + def summarize_parameters(self) -> None: + logger.info(f'VAE Encoder: {count_parameters(self.vae.encoder)}M') + logger.info(f'VAE Decoder: {count_parameters(self.vae.decoder)}M') + + def forward(self, batch: dict) -> tuple: + feats_ref = batch['motion'] + lengths = batch["length"] + mask = batch['mask'] + + z, dist_m = self.vae.encode(feats_ref, mask) + feats_rst = self.vae.decode(z, mask) + + joints = self.feats2joints(feats_rst.detach().cpu()) + joints = remove_padding(joints, lengths) + + joints_ref = None + if feats_ref is not None: + joints_ref = self.feats2joints(feats_ref.detach().cpu()) + joints_ref = remove_padding(joints_ref, lengths) + + return joints, joints_ref + + def loss_calculate(self, a: torch.Tensor, b: torch.Tensor, loss_type: str, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + mask = None if not self.mask_loss else mask + if loss_type == 'l1': + loss = F.l1_loss(a, b, reduction='none') + elif loss_type == 'l1_smooth': + loss = F.smooth_l1_loss(a, b, reduction='none') + elif loss_type == 'l2': + loss = F.mse_loss(a, b, reduction='none') + else: + raise ValueError(f'Unknown loss type: {loss_type}') + + if mask is not None: + loss = (loss.mean(dim=-1) * mask).sum(-1) / mask.sum(-1) + return loss.mean() + + def train_vae_forward(self, batch: dict) -> dict: + feats_ref = batch['motion'] + mask = batch['mask'] + + z, dist_m = self.vae.encode(feats_ref, mask) + feats_rst = self.vae.decode(z, mask) + + loss_dict = dict( + rec_feats_loss=torch.tensor(0., device=z.device), + rec_joints_loss=torch.tensor(0., device=z.device), + 
rec_velocity_loss=torch.tensor(0., device=z.device), + kl_loss=torch.tensor(0., device=z.device)) + + if self.rec_feats_ratio > 0: + rec_feats_loss = self.loss_calculate(feats_ref, feats_rst, self.rec_feats_loss, mask) + loss_dict['rec_feats_loss'] = rec_feats_loss * self.rec_feats_ratio + + if self.rec_joints_ratio > 0: + joints_rst = self.feats2joints(feats_rst).reshape(mask.size(0), mask.size(1), -1) + joints_ref = self.feats2joints(feats_ref).reshape(mask.size(0), mask.size(1), -1) + rec_joints_loss = self.loss_calculate(joints_ref, joints_rst, self.rec_joints_loss, mask) + loss_dict['rec_joints_loss'] = rec_joints_loss * self.rec_joints_ratio + + if self.rec_velocity_ratio > 0: + rec_velocity_loss = self.loss_calculate(feats_ref[..., 4: (self.njoints - 1) * 3 + 4], + feats_rst[..., 4: (self.njoints - 1) * 3 + 4], + self.rec_velocity_loss, mask) + loss_dict['rec_velocity_loss'] = rec_velocity_loss * self.rec_velocity_ratio + + if self.kl_ratio > 0: + mu_ref = torch.zeros_like(dist_m.loc) + scale_ref = torch.ones_like(dist_m.scale) + dist_ref = torch.distributions.Normal(mu_ref, scale_ref) + kl_loss = torch.distributions.kl_divergence(dist_m, dist_ref).mean() + loss_dict['kl_loss'] = kl_loss * self.kl_ratio + + loss = sum([v for v in loss_dict.values()]) + loss_dict['loss'] = loss + return loss_dict + + def t2m_eval(self, batch: dict) -> dict: + feats_ref_ori = batch["motion"] + mask = batch['mask'] + lengths = batch["length"] + word_embs = batch["word_embs"] + pos_ohot = batch["pos_ohot"] + text_lengths = batch["text_len"] + + start = time.time() + + vae_st_e = time.time() + z, dist_m = self.vae.encode(feats_ref_ori, mask) + vae_et_e = time.time() + self.vae_encode_times.append(vae_et_e - vae_st_e) + + vae_st_d = time.time() + feats_rst_ori = self.vae.decode(z, mask) + vae_et_d = time.time() + self.vae_decode_times.append(vae_et_d - vae_st_d) + + end = time.time() + self.times.append(end - start) + self.frames.extend(lengths) + + # joints recover + joints_rst = self.feats2joints(feats_rst_ori) + joints_ref = self.feats2joints(feats_ref_ori) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst_ori) + feats_ref = self.datamodule.renorm4t2m(feats_ref_ori) + + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=feats_ref.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + feats_ref = feats_ref[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, eval(f"self.cfg.DATASET.{self.cfg.DATASET.NAME.upper()}.UNIT_LEN"), + rounding_mode="floor") + + recons_mov = self.t2m_moveencoder(feats_rst[..., :-4]).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + gt_mov = self.t2m_moveencoder(feats_ref[..., :-4]).detach() + gt_emb = self.t2m_motionencoder(gt_mov, m_lens) + + # t2m text encoder + text_emb = self.t2m_textencoder(word_embs, pos_ohot, text_lengths)[align_idx] + + rs_set = { + "m_ref": feats_ref, + "m_rst": feats_rst, + "m_ref_ori": feats_ref_ori, + "m_rst_ori": feats_rst_ori, + "lat_t": text_emb, + "lat_m": gt_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst + } + return rs_set + + def allsplit_step(self, split: str, batch: dict) -> Optional[dict]: + if split in ["test"]: + rs_set = self.t2m_eval(batch) + + for metric in self.metric_list: + if metric == "TM2TMetrics": + getattr(self, metric).update( + rs_set["lat_t"], + rs_set["lat_rm"], + rs_set["lat_m"], + batch["length"]) + elif metric == "PosMetrics": + 
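+                # PosMetrics consumes the recovered joints plus the pre-renormalization feature
+                # tensors (m_ref_ori / m_rst_ori) and reports root-aligned and unaligned MPJPE
+                # as well as a feature-space L2 reconstruction error (mld/models/metrics/pos.py).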
getattr(self, metric).update(rs_set["joints_ref"], + rs_set["joints_rst"], + rs_set["m_ref_ori"], + rs_set["m_rst_ori"], + batch["length"]) + else: + raise TypeError(f"Not support this metric {metric}.") + + if split in ["train", "val"]: + loss_dict = self.train_vae_forward(batch) + return loss_dict diff --git a/mld/models/operator/__init__.py b/mld/models/operator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/operator/attention.py b/mld/models/operator/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8d784624299eb4b6d51e30f0cb8dfcd30625266c --- /dev/null +++ b/mld/models/operator/attention.py @@ -0,0 +1,399 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from .utils import get_clone, get_clones, get_activation_fn + + +class SkipTransformerEncoder(nn.Module): + def __init__(self, encoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None, + act: Optional[str] = None, is_controlnet: bool = False, is_moe: bool = False) -> None: + super().__init__() + self.d_model = encoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + self.act = get_activation_fn(act) + self.is_controlnet = is_controlnet + self.is_moe = is_moe + assert num_layers % 2 == 1 + + num_block = (num_layers - 1) // 2 + self.input_blocks = get_clones(encoder_layer, num_block) + self.middle_block = get_clone(encoder_layer) + self.output_blocks = get_clones(encoder_layer, num_block) + self.linear_blocks = get_clones(nn.Linear(2 * self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def maybe_controlnet_moe( + self, x: torch.Tensor, controlnet_residuals: Optional[list[torch.Tensor]] = None, + all_intermediates: Optional[tuple] = None, all_router_logits: Optional[tuple] = None + ) -> tuple: + if self.is_moe: + all_router_logits += (x[1],) + x = x[0] + + if controlnet_residuals is not None: + x = x + controlnet_residuals.pop() + + if self.is_controlnet: + all_intermediates += (x,) + + return x, controlnet_residuals, all_intermediates, all_router_logits + + def forward(self, src: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + controlnet_residuals: Optional[list[torch.Tensor]] = None) -> tuple: + x = src + xs = [] + + all_intermediates = () if self.is_controlnet else None + all_router_logits = () if self.is_moe else None + if controlnet_residuals is not None: + controlnet_residuals.reverse() + + for module in self.input_blocks: + x = module(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + x, controlnet_residuals, all_intermediates, all_router_logits = self.maybe_controlnet_moe( + x, controlnet_residuals, all_intermediates, all_router_logits) + xs.append(x) + + x = self.middle_block(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + x, controlnet_residuals, all_intermediates, all_router_logits = self.maybe_controlnet_moe( + x, controlnet_residuals, all_intermediates, all_router_logits) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + x, controlnet_residuals, all_intermediates, all_router_logits = self.maybe_controlnet_moe( + x, controlnet_residuals, all_intermediates, 
all_router_logits) + + if self.norm: + x = self.act(self.norm(x)) + + return x, all_intermediates, all_router_logits + + +class SkipTransformerDecoder(nn.Module): + def __init__(self, decoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None, + act: Optional[str] = None, is_controlnet: bool = False, is_moe: bool = False) -> None: + super().__init__() + self.d_model = decoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + self.act = get_activation_fn(act) + self.is_controlnet = is_controlnet + self.is_moe = is_moe + assert num_layers % 2 == 1 + + num_block = (num_layers - 1) // 2 + self.input_blocks = get_clones(decoder_layer, num_block) + self.middle_block = get_clone(decoder_layer) + self.output_blocks = get_clones(decoder_layer, num_block) + self.linear_blocks = get_clones(nn.Linear(2 * self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def maybe_controlnet_moe( + self, x: torch.Tensor, controlnet_residuals: Optional[list[torch.Tensor]] = None, + all_intermediates: Optional[tuple] = None, all_router_logits: Optional[tuple] = None + ) -> tuple: + if self.is_moe: + all_router_logits += (x[1],) + x = x[0] + + if self.is_controlnet: + x = x + controlnet_residuals.pop() + all_intermediates += (x,) + + return x, controlnet_residuals, all_intermediates, all_router_logits + + def forward(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + controlnet_residuals: Optional[list[torch.Tensor]] = None) -> tuple: + x = tgt + xs = [] + + all_intermediates = () if self.is_controlnet else None + all_router_logits = () if self.is_moe else None + if controlnet_residuals is not None: + controlnet_residuals.reverse() + + for module in self.input_blocks: + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + x, controlnet_residuals, all_intermediates, all_router_logits = self.maybe_controlnet_moe( + x, controlnet_residuals, all_intermediates, all_router_logits) + xs.append(x) + + x = self.middle_block(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + x, controlnet_residuals, all_intermediates, all_router_logits = self.maybe_controlnet_moe( + x, controlnet_residuals, all_intermediates, all_router_logits) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + x, controlnet_residuals, all_intermediates, all_router_logits = self.maybe_controlnet_moe( + x, controlnet_residuals, all_intermediates, all_router_logits) + + if self.norm: + x = self.act(self.norm(x)) + + return x, all_intermediates, all_router_logits + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None, + act: Optional[str] = None, return_intermediate: bool = False) -> None: + super().__init__() + self.layers = get_clones(encoder_layer, 
num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + self.norm = norm + self.act = get_activation_fn(act) + + def forward(self, src: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + controlnet_residuals: Optional[list[torch.Tensor]] = None) -> torch.Tensor: + output = src + intermediate = [] + index = 0 + for layer in self.layers: + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + + if controlnet_residuals is not None: + output = output + controlnet_residuals[index] + index += 1 + + if self.return_intermediate: + intermediate.append(output) + + if self.norm: + output = self.act(self.norm(output)) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None, + act: Optional[str] = None, return_intermediate: bool = False) -> None: + super().__init__() + self.layers = get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + self.norm = norm + self.act = get_activation_fn(act) + + def forward(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + controlnet_residuals: Optional[list[torch.Tensor]] = None) -> torch.Tensor: + output = tgt + intermediate = [] + index = 0 + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + + if controlnet_residuals is not None: + output = output + controlnet_residuals[index] + index += 1 + + if self.return_intermediate: + intermediate.append(output) + + if self.norm: + output = self.act(self.norm(output)) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: str = "relu", normalize_before: bool = False, norm_eps: float = 1e-5) -> None: + super(TransformerEncoderLayer, self).__init__() + self.d_model = d_model + self.activation_name = activation + self.normalize_before = normalize_before + + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.linear1 = nn.Linear(d_model, dim_feedforward if activation != 'geglu' else dim_feedforward * 2) + self.activation = get_activation_fn(activation) if activation != 'geglu' else nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward_post(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + src2 = self.self_attn(src, src, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + if self.activation_name == 'geglu': + src2, gate = self.linear1(src).chunk(2, dim=-1) + src2 = src2 * self.activation(gate) + else: + 
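+            # Standard feed-forward path: activation(linear1(src)) -> dropout -> linear2.
+            # The GEGLU branch above instead splits linear1's doubled output into a value and
+            # a gate and computes value * GELU(gate).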
src2 = self.activation(self.linear1(src)) + src2 = self.linear2(self.dropout(src2)) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + src2 = self.norm1(src) + src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + if self.activation_name == 'geglu': + src2, gate = self.linear1(src2).chunk(2, dim=-1) + src2 = src2 * self.activation(gate) + else: + src2 = self.activation(self.linear1(src2)) + src2 = self.linear2(self.dropout(src2)) + src = src + self.dropout2(src2) + return src + + def forward(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask) + return self.forward_post(src, src_mask, src_key_padding_mask) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: str = "relu", normalize_before: bool = False, norm_eps: float = 1e-5) -> None: + super(TransformerDecoderLayer, self).__init__() + self.d_model = d_model + self.activation_name = activation + self.normalize_before = normalize_before + + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.linear1 = nn.Linear(d_model, dim_feedforward if activation != 'geglu' else dim_feedforward * 2) + self.activation = get_activation_fn(activation) if activation != 'geglu' else nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + def forward_post(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=tgt, key=memory, value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + if self.activation_name == 'geglu': + tgt2, gate = self.linear1(tgt).chunk(2, dim=-1) + tgt2 = tgt2 * self.activation(gate) + else: + tgt2 = self.activation(self.linear1(tgt)) + tgt2 = self.linear2(self.dropout(tgt2)) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + tgt2 = self.norm1(tgt) + tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, + 
key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=tgt2, key=memory, value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + if self.activation_name == 'geglu': + tgt2, gate = self.linear1(tgt2).chunk(2, dim=-1) + tgt2 = tgt2 * self.activation(gate) + else: + tgt2 = self.activation(self.linear1(tgt2)) + tgt2 = self.linear2(self.dropout(tgt2)) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask) diff --git a/mld/models/operator/conv.py b/mld/models/operator/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..899425ad5966e79e6778fc0409623b60f65f594b --- /dev/null +++ b/mld/models/operator/conv.py @@ -0,0 +1,139 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from .utils import get_activation_fn + + +class ResConv1DBlock(nn.Module): + def __init__(self, n_in: int, n_state: int, dilation: int = 1, activation: str = 'silu', dropout: float = 0.1, + norm: Optional[str] = None, norm_groups: int = 32, norm_eps: float = 1e-5) -> None: + super(ResConv1DBlock, self).__init__() + + self.norm = norm + if norm == "LN": + self.norm1 = nn.LayerNorm(n_in, eps=norm_eps) + self.norm2 = nn.LayerNorm(n_in, eps=norm_eps) + elif norm == "GN": + self.norm1 = nn.GroupNorm(num_groups=norm_groups, num_channels=n_in, eps=norm_eps) + self.norm2 = nn.GroupNorm(num_groups=norm_groups, num_channels=n_in, eps=norm_eps) + elif norm == "BN": + self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=norm_eps) + self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=norm_eps) + else: + self.norm1 = nn.Identity() + self.norm2 = nn.Identity() + + self.activation = get_activation_fn(activation) + + self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding=dilation, dilation=dilation) + self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_orig = x + if self.norm == "LN": + x = self.norm1(x.transpose(-2, -1)) + x = self.activation(x.transpose(-2, -1)) + else: + x = self.norm1(x) + x = self.activation(x) + + x = self.conv1(x) + + if self.norm == "LN": + x = self.norm2(x.transpose(-2, -1)) + x = self.activation(x.transpose(-2, -1)) + else: + x = self.norm2(x) + x = self.activation(x) + + x = self.conv2(x) + x = self.dropout(x) + x = x + x_orig + return x + + +class Resnet1D(nn.Module): + def __init__(self, n_in: int, n_state: int, n_depth: int, reverse_dilation: bool = True, + dilation_growth_rate: int = 3, activation: str = 'relu', dropout: float = 0.1, + norm: Optional[str] = None, norm_groups: int = 32, norm_eps: float = 1e-5) -> None: + super(Resnet1D, self).__init__() + blocks = [ResConv1DBlock(n_in, n_state, dilation=dilation_growth_rate ** depth, activation=activation, + dropout=dropout, norm=norm, norm_groups=norm_groups, norm_eps=norm_eps) + for depth in range(n_depth)] + if reverse_dilation: + blocks = 
blocks[::-1] + self.model = nn.Sequential(*blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + +class ResEncoder(nn.Module): + def __init__(self, + in_width: int = 263, + mid_width: int = 512, + out_width: int = 512, + down_t: int = 2, + stride_t: int = 2, + n_depth: int = 3, + dilation_growth_rate: int = 3, + activation: str = 'relu', + dropout: float = 0.1, + norm: Optional[str] = None, + norm_groups: int = 32, + norm_eps: float = 1e-5, + double_z: bool = False) -> None: + super(ResEncoder, self).__init__() + + blocks = [] + filter_t, pad_t = stride_t * 2, stride_t // 2 + blocks.append(nn.Conv1d(in_width, mid_width, 3, 1, 1)) + blocks.append(get_activation_fn(activation)) + + for i in range(down_t): + block = nn.Sequential( + nn.Conv1d(mid_width, mid_width, filter_t, stride_t, pad_t), + Resnet1D(mid_width, mid_width, n_depth, reverse_dilation=True, dilation_growth_rate=dilation_growth_rate, + activation=activation, dropout=dropout, norm=norm, norm_groups=norm_groups, norm_eps=norm_eps)) + blocks.append(block) + blocks.append(nn.Conv1d(mid_width, out_width * 2 if double_z else out_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x.permute(0, 2, 1)) # B x C x T + + +class ResDecoder(nn.Module): + def __init__(self, + in_width: int = 263, + mid_width: int = 512, + out_width: int = 512, + down_t: int = 2, + stride_t: int = 2, + n_depth: int = 3, + dilation_growth_rate: int = 3, + activation: str = 'relu', + dropout: float = 0.1, + norm: Optional[str] = None, + norm_groups: int = 32, + norm_eps: float = 1e-5) -> None: + super(ResDecoder, self).__init__() + blocks = [nn.Conv1d(out_width, mid_width, 3, 1, 1), get_activation_fn(activation)] + + for i in range(down_t): + block = nn.Sequential( + Resnet1D(mid_width, mid_width, n_depth, reverse_dilation=True, dilation_growth_rate=dilation_growth_rate, + activation=activation, dropout=dropout, norm=norm, norm_groups=norm_groups, norm_eps=norm_eps), + nn.Upsample(scale_factor=stride_t, mode='nearest'), + nn.Conv1d(mid_width, mid_width, 3, 1, 1)) + blocks.append(block) + blocks.append(nn.Conv1d(mid_width, mid_width, 3, 1, 1)) + blocks.append(get_activation_fn(activation)) + blocks.append(nn.Conv1d(mid_width, in_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x).permute(0, 2, 1) # B x T x C diff --git a/mld/models/operator/embeddings.py b/mld/models/operator/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..cec01358a22846e457b0fdad77e1ff6d8abe14e3 --- /dev/null +++ b/mld/models/operator/embeddings.py @@ -0,0 +1,98 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn + +from .utils import get_activation_fn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = 
torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str, + out_dim: Optional[int] = None, post_act_fn: Optional[str] = None, + cond_proj_dim: Optional[int] = None, zero_init_cond: bool = True) -> None: + super(TimestepEmbedding, self).__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + if zero_init_cond: + self.cond_proj.weight.data.fill_(0.0) + else: + self.cond_proj = None + + self.act = get_activation_fn(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation_fn(post_act_fn) + + def forward(self, sample: torch.Tensor, timestep_cond: Optional[torch.Tensor] = None) -> torch.Tensor: + if timestep_cond is not None: + sample = sample + self.cond_proj(timestep_cond) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, + downscale_freq_shift: float) -> None: + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift) + return t_emb diff --git a/mld/models/operator/moe.py b/mld/models/operator/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..c06b9dc8bd26679078be30d7e62ea207580b9b61 --- /dev/null +++ b/mld/models/operator/moe.py @@ -0,0 +1,182 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import get_clones, get_activation_fn + + +class SparseMoeMLP(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, dropout: float, activation: str) -> None: + super(SparseMoeMLP, self).__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = get_activation_fn(activation) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.linear2(self.dropout(self.activation(self.linear1(hidden_states)))) + + +class SparseMoeBlock(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, dropout: float, activation: str, + num_experts: int, topk: int, jitter_noise: Optional[float] = None) -> None: + super(SparseMoeBlock, self).__init__() + self.topk = topk + self.num_experts = num_experts + self.jitter_noise = jitter_noise + + self.gate = nn.Linear(d_model, num_experts) + self.experts = get_clones(SparseMoeMLP(d_model, dim_feedforward, dropout, activation), num_experts) + + def forward(self, hidden_states: torch.Tensor) -> tuple: + sequence_length, batch_size, hidden_dim = 
hidden_states.shape + if self.training and self.jitter_noise is not None: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + routing_weights = F.softmax(router_logits, dim=-1) + routing_weights, selected_experts = torch.topk(routing_weights, self.topk, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[top_x] + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + final_hidden_states.index_add_(0, top_x, current_hidden_states) + final_hidden_states = final_hidden_states.reshape(sequence_length, batch_size, hidden_dim) + return final_hidden_states, router_logits + + +class MoeTransformerEncoderLayer(nn.Module): + + def __init__(self, d_model: int, nhead: int, num_experts: int, topk: int, dim_feedforward: int = 2048, + dropout: float = 0.1, activation: str = "relu", normalize_before: bool = False, + norm_eps: float = 1e-5, jitter_noise: Optional[float] = None) -> None: + super(MoeTransformerEncoderLayer, self).__init__() + self.d_model = d_model + self.normalize_before = normalize_before + + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.moe = SparseMoeBlock( + d_model, dim_feedforward, dropout, activation, + num_experts=num_experts, topk=topk, jitter_noise=jitter_noise + ) + self.norm1 = nn.LayerNorm(d_model, eps=norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward_post(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None) -> tuple: + src2 = self.self_attn(src, src, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2, logits = self.moe(src) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src, logits + + def forward_pre(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None) -> tuple: + src2 = self.norm1(src) + src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2, logits = self.moe(src2) + src = src + self.dropout2(src2) + return src, logits + + def forward(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None) -> tuple: + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask) + return self.forward_post(src, src_mask, src_key_padding_mask) + + +class MoeTransformerDecoderLayer(nn.Module): + + def __init__(self, d_model: int, nhead: int, num_experts: int, topk: int, dim_feedforward: int = 2048, + dropout: float = 0.1, activation: str = "relu", normalize_before: bool = False, + norm_eps: float = 1e-5, jitter_noise: Optional[float] = None) -> None: + 
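# Same layout as TransformerDecoderLayer, but the feed-forward MLP is replaced by a sparse MoE block that also returns its router logits. + 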
super(MoeTransformerDecoderLayer, self).__init__() + self.d_model = d_model + self.normalize_before = normalize_before + + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.moe = SparseMoeBlock( + d_model, dim_feedforward, dropout, activation, + num_experts=num_experts, topk=topk, jitter_noise=jitter_noise + ) + + self.norm1 = nn.LayerNorm(d_model, eps=norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + def forward_post(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> tuple: + tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=tgt, key=memory, value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2, logits = self.moe(tgt) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt, logits + + def forward_pre(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> tuple: + tgt2 = self.norm1(tgt) + tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=tgt2, key=memory, value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2, logits = self.moe(tgt2) + tgt = tgt + self.dropout3(tgt2) + return tgt, logits + + def forward(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> tuple: + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask) diff --git a/mld/models/operator/position_encoding.py b/mld/models/operator/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..a0cee22d0aa7ec8ea3415bc63fcb3330e5f7fc5c --- /dev/null +++ b/mld/models/operator/position_encoding.py @@ -0,0 +1,57 @@ +import numpy as np + +import torch +import torch.nn as nn + + +class PositionEmbeddingSine1D(nn.Module): + + def __init__(self, d_model: int, max_len: int = 500, batch_first: bool = False) -> None: + super().__init__() + self.batch_first = batch_first + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange( + 0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position 
* div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.batch_first: + x = x + self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + x = x + self.pe[:x.shape[0], :] + return x + + +class PositionEmbeddingLearned1D(nn.Module): + + def __init__(self, d_model: int, max_len: int = 500, batch_first: bool = False) -> None: + super().__init__() + self.batch_first = batch_first + self.pe = nn.Parameter(torch.zeros(max_len, 1, d_model)) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.uniform_(self.pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.batch_first: + x = x + self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + x = x + self.pe[:x.shape[0], :] + return x + + +def build_position_encoding(N_steps: int, position_embedding: str = "sine") -> nn.Module: + if position_embedding == 'sine': + position_embedding = PositionEmbeddingSine1D(N_steps) + elif position_embedding == 'learned': + position_embedding = PositionEmbeddingLearned1D(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + return position_embedding diff --git a/mld/models/operator/utils.py b/mld/models/operator/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72093aad19cf0741a1fb6910a36205f467e088b6 --- /dev/null +++ b/mld/models/operator/utils.py @@ -0,0 +1,37 @@ +import copy +from typing import Optional + +import torch.nn as nn + + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU() +} + + +def get_clone(module: nn.Module) -> nn.Module: + return copy.deepcopy(module) + + +def get_clones(module: nn.Module, N: int) -> nn.ModuleList: + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + + +def get_activation_fn(act_fn: Optional[str] = None) -> nn.Module: + if act_fn is None: + return nn.Identity() + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +def zero_module(module: nn.Module) -> nn.Module: + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/mld/models/schedulers/__init__.py b/mld/models/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/schedulers/scheduling_lcm.py b/mld/models/schedulers/scheduling_lcm.py new file mode 100644 index 0000000000000000000000000000000000000000..447d6d2e98699af1410616acbe6fc3a715022143 --- /dev/null +++ b/mld/models/schedulers/scheduling_lcm.py @@ -0,0 +1,22 @@ +from typing import Optional, Union + +import torch +import diffusers + + +class LCMScheduler(diffusers.schedulers.LCMScheduler): + def __init__(self, timesteps_step_map: Optional[dict] = None, **kwargs) -> None: + super(LCMScheduler, self).__init__(**kwargs) + self.timesteps_step_map = timesteps_step_map + + def set_timesteps(self, num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, **kwargs) -> None: + if self.timesteps_step_map is None: + super().set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs) + else: + assert num_inference_steps is not None + self.num_inference_steps = num_inference_steps + timesteps = self.timesteps_step_map[num_inference_steps] + assert all([timestep < self.config.num_train_timesteps for timestep in timesteps]) + 
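# Use the fixed timestep schedule configured for this step count instead of the evenly spaced schedule the parent class would compute. + 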
self.timesteps = torch.tensor(timesteps).to(device=device, dtype=torch.long) + self._step_index = None diff --git a/mld/render/__init__.py b/mld/render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/render/blender/__init__.py b/mld/render/blender/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a82255db45b763479586f83f0b7c904387b814ba --- /dev/null +++ b/mld/render/blender/__init__.py @@ -0,0 +1 @@ +from .render import render diff --git a/mld/render/blender/camera.py b/mld/render/blender/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..bc3938ca2e3946438d3a77260cd777b3d4598a0a --- /dev/null +++ b/mld/render/blender/camera.py @@ -0,0 +1,35 @@ +import bpy + + +class Camera: + def __init__(self, first_root, mode): + camera = bpy.data.objects['Camera'] + + # initial position + camera.location.x = 7.36 + camera.location.y = -6.93 + camera.location.z = 5.6 + + # wider point of view + if mode == "sequence": + camera.data.lens = 65 + elif mode == "frame": + camera.data.lens = 130 + elif mode == "video": + camera.data.lens = 110 + + self.mode = mode + self.camera = camera + + self.camera.location.x += first_root[0] + self.camera.location.y += first_root[1] + + self._root = first_root + + def update(self, new_root): + delta_root = new_root - self._root + + self.camera.location.x += delta_root[0] + self.camera.location.y += delta_root[1] + + self._root = new_root diff --git a/mld/render/blender/floor.py b/mld/render/blender/floor.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4b44a2d204de6f76f3bddd9215219484f2a94d --- /dev/null +++ b/mld/render/blender/floor.py @@ -0,0 +1,63 @@ +import bpy +from .materials import floor_mat + + +def plot_floor(data, big_plane=True): + # Create a floor + minx, miny, _ = data.min(axis=(0, 1)) + maxx, maxy, _ = data.max(axis=(0, 1)) + + location = ((maxx + minx)/2, (maxy + miny)/2, 0) + # a little bit bigger + scale = (1.08*(maxx - minx)/2, 1.08*(maxy - miny)/2, 1) + + bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) + + bpy.ops.transform.resize(value=scale, orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', + constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, + proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, + use_proportional_projected=False, release_confirm=True) + obj = bpy.data.objects["Plane"] + obj.name = "SmallPlane" + obj.data.name = "SmallPlane" + + if not big_plane: + obj.active_material = floor_mat(color=(0.2, 0.2, 0.2, 1)) + else: + obj.active_material = floor_mat(color=(0.1, 0.1, 0.1, 1)) + + if big_plane: + location = ((maxx + minx)/2, (maxy + miny)/2, -0.01) + bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) + + bpy.ops.transform.resize(value=[2*x for x in scale], orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', + constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, + proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, + use_proportional_projected=False, release_confirm=True) + + obj = bpy.data.objects["Plane"] + obj.name = "BigPlane" + obj.data.name = "BigPlane" + obj.active_material = floor_mat(color=(0.2, 0.2, 0.2, 
1)) + + +def show_trajectory(coords): + for i, coord in enumerate(coords): + import matplotlib + cmap = matplotlib.cm.get_cmap('Greens') + begin = 0.45 + end = 1.0 + frac = i / len(coords) + rgb_color = cmap(begin + (end - begin) * frac) + + x, y, z = coord + bpy.ops.mesh.primitive_uv_sphere_add(radius=0.04, location=(x, y, z)) + obj = bpy.context.active_object + + mat = bpy.data.materials.new(name="SphereMaterial") + obj.data.materials.append(mat) + mat.use_nodes = True + bsdf = mat.node_tree.nodes["Principled BSDF"] + bsdf.inputs['Base Color'].default_value = rgb_color + + bpy.ops.object.mode_set(mode='OBJECT') diff --git a/mld/render/blender/materials.py b/mld/render/blender/materials.py new file mode 100644 index 0000000000000000000000000000000000000000..2152b433a0decea7d368d5b0c3da9f749ec1aaa9 --- /dev/null +++ b/mld/render/blender/materials.py @@ -0,0 +1,70 @@ +import bpy + + +def clear_material(material): + if material.node_tree: + material.node_tree.links.clear() + material.node_tree.nodes.clear() + + +def colored_material_diffuse_BSDF(r, g, b, a=1, roughness=0.127451): + materials = bpy.data.materials + material = materials.new(name="body") + material.use_nodes = True + clear_material(material) + nodes = material.node_tree.nodes + links = material.node_tree.links + output = nodes.new(type='ShaderNodeOutputMaterial') + diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') + diffuse.inputs["Color"].default_value = (r, g, b, a) + diffuse.inputs["Roughness"].default_value = roughness + links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) + return material + + +# keys: +# ['Base Color', 'Subsurface', 'Subsurface Radius', 'Subsurface Color', 'Metallic', 'Specular', 'Specular Tint', 'Roughness', 'Anisotropic', 'Anisotropic Rotation', 'Sheen', 1Sheen Tint', 'Clearcoat', 'Clearcoat Roughness', 'IOR', 'Transmission', 'Transmission Roughness', 'Emission', 'Emission Strength', 'Alpha', 'Normal', 'Clearcoat Normal', 'Tangent'] +DEFAULT_BSDF_SETTINGS = {"Subsurface": 0.15, + "Subsurface Radius": [1.1, 0.2, 0.1], + "Metallic": 0.3, + "Specular": 0.5, + "Specular Tint": 0.5, + "Roughness": 0.75, + "Anisotropic": 0.25, + "Anisotropic Rotation": 0.25, + "Sheen": 0.75, + "Sheen Tint": 0.5, + "Clearcoat": 0.5, + "Clearcoat Roughness": 0.5, + "IOR": 1.450, + "Transmission": 0.1, + "Transmission Roughness": 0.1, + "Emission": (0, 0, 0, 1), + "Emission Strength": 0.0, + "Alpha": 1.0} + + +def body_material(r, g, b, a=1, name="body", oldrender=True): + if oldrender: + material = colored_material_diffuse_BSDF(r, g, b, a=a) + else: + materials = bpy.data.materials + material = materials.new(name=name) + material.use_nodes = True + nodes = material.node_tree.nodes + diffuse = nodes["Principled BSDF"] + inputs = diffuse.inputs + + settings = DEFAULT_BSDF_SETTINGS.copy() + settings["Base Color"] = (r, g, b, a) + settings["Subsurface Color"] = (r, g, b, a) + settings["Subsurface"] = 0.0 + + for setting, val in settings.items(): + inputs[setting].default_value = val + + return material + + +def floor_mat(color=(0.1, 0.1, 0.1, 1), roughness=0.127451): + return colored_material_diffuse_BSDF(color[0], color[1], color[2], a=color[3], roughness=roughness) diff --git a/mld/render/blender/meshes.py b/mld/render/blender/meshes.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8998332daa3cb14142c1a1a7e07461876339a --- /dev/null +++ b/mld/render/blender/meshes.py @@ -0,0 +1,82 @@ +import numpy as np + +from .materials import body_material + +# Orange +GEN_SMPL = body_material(0.658, 0.214, 
0.0114) +# Green +GT_SMPL = body_material(0.035, 0.415, 0.122) + + +class Meshes: + def __init__(self, data, gt, mode, trajectory, faces_path, always_on_floor, oldrender=True): + data, trajectory = prepare_meshes(data, trajectory, always_on_floor=always_on_floor) + + self.faces = np.load(faces_path) + print(faces_path) + self.data = data + self.mode = mode + self.oldrender = oldrender + + self.N = len(data) + # self.trajectory = data[:, :, [0, 1]].mean(1) + self.trajectory = trajectory + + if gt: + self.mat = GT_SMPL + else: + self.mat = GEN_SMPL + + def get_sequence_mat(self, frac): + import matplotlib + # cmap = matplotlib.cm.get_cmap('Blues') + cmap = matplotlib.cm.get_cmap('Oranges') + # begin = 0.60 + # end = 0.90 + begin = 0.50 + end = 0.90 + rgb_color = cmap(begin + (end-begin)*frac) + mat = body_material(*rgb_color, oldrender=self.oldrender) + return mat + + def get_root(self, index): + return self.data[index].mean(0) + + def get_mean_root(self): + return self.data.mean((0, 1)) + + def load_in_blender(self, index, mat): + vertices = self.data[index] + faces = self.faces + name = f"{str(index).zfill(4)}" + + from .tools import load_numpy_vertices_into_blender + load_numpy_vertices_into_blender(vertices, faces, name, mat) + + return name + + def __len__(self): + return self.N + + +def prepare_meshes(data, trajectory, always_on_floor=False): + # Swap axis (gravity=Z instead of Y) + data = data[..., [2, 0, 1]] + + if trajectory is not None: + trajectory = trajectory[..., [2, 0, 1]] + mask = trajectory.sum(-1) != 0 + trajectory = trajectory[mask] + + # Remove the floor + height_offset = data[..., 2].min() + data[..., 2] -= height_offset + + if trajectory is not None: + trajectory[..., 2] -= height_offset + + # Put all the body on the floor + if always_on_floor: + data[..., 2] -= data[..., 2].min(1)[:, None] + + return data, trajectory diff --git a/mld/render/blender/render.py b/mld/render/blender/render.py new file mode 100644 index 0000000000000000000000000000000000000000..4f43ec749e791ca5677f909813a13a862cb61f01 --- /dev/null +++ b/mld/render/blender/render.py @@ -0,0 +1,142 @@ +import os +import shutil + +import bpy + +from .camera import Camera +from .floor import plot_floor, show_trajectory +from .sampler import get_frameidx +from .scene import setup_scene +from .tools import delete_objs + +from mld.render.video import Video + + +def prune_begin_end(data, perc): + to_remove = int(len(data) * perc) + if to_remove == 0: + return data + return data[to_remove:-to_remove] + + +def render_current_frame(path): + bpy.context.scene.render.filepath = path + bpy.ops.render.render(use_viewport=True, write_still=True) + + +def render(npydata, trajectory, path, mode, faces_path, gt=False, + exact_frame=None, num=8, always_on_floor=False, denoising=True, + oldrender=True, res="high", accelerator='gpu', device=[0], fps=20): + + if mode == 'video': + if always_on_floor: + frames_folder = path.replace(".pkl", "_of_frames") + else: + frames_folder = path.replace(".pkl", "_frames") + + if os.path.exists(frames_folder.replace("_frames", ".mp4")) or os.path.exists(frames_folder): + print(f"pkl is rendered or under rendering {path}") + return + + os.makedirs(frames_folder, exist_ok=False) + + elif mode == 'sequence': + path = path.replace('.pkl', '.png') + img_name, ext = os.path.splitext(path) + if always_on_floor: + img_name += "_of" + img_path = f"{img_name}{ext}" + if os.path.exists(img_path): + print(f"pkl is rendered or under rendering {img_path}") + return + + elif mode == 'frame': + path = 
path.replace('.pkl', '.png') + img_name, ext = os.path.splitext(path) + if always_on_floor: + img_name += "_of" + img_path = f"{img_name}_{exact_frame}{ext}" + if os.path.exists(img_path): + print(f"pkl is rendered or under rendering {img_path}") + return + else: + raise ValueError(f'Invalid mode: {mode}') + + # Setup the scene (lights / render engine / resolution etc) + setup_scene(res=res, denoising=denoising, oldrender=oldrender, accelerator=accelerator, device=device) + + # remove X% of beginning and end + # as it is almost always static + # in this part + # if mode == "sequence": + # perc = 0.2 + # npydata = prune_begin_end(npydata, perc) + + from .meshes import Meshes + data = Meshes(npydata, gt=gt, mode=mode, trajectory=trajectory, + faces_path=faces_path, always_on_floor=always_on_floor) + + # Number of frames possible to render + nframes = len(data) + + # Show the trajectory + if trajectory is not None: + show_trajectory(data.trajectory) + + # Create a floor + plot_floor(data.data, big_plane=False) + + # initialize the camera + camera = Camera(first_root=data.get_root(0), mode=mode) + + frameidx = get_frameidx(mode=mode, nframes=nframes, + exact_frame=exact_frame, + frames_to_keep=num) + + nframes_to_render = len(frameidx) + + # center the camera to the middle + if mode == "sequence": + camera.update(data.get_mean_root()) + + imported_obj_names = [] + for index, frameidx in enumerate(frameidx): + if mode == "sequence": + frac = index / (nframes_to_render - 1) + mat = data.get_sequence_mat(frac) + else: + mat = data.mat + camera.update(data.get_root(frameidx)) + + islast = index == (nframes_to_render - 1) + + obj_name = data.load_in_blender(frameidx, mat) + name = f"{str(index).zfill(4)}" + + if mode == "video": + path = os.path.join(frames_folder, f"frame_{name}.png") + else: + path = img_path + + if mode == "sequence": + imported_obj_names.extend(obj_name) + elif mode == "frame": + camera.update(data.get_root(frameidx)) + + if mode != "sequence" or islast: + render_current_frame(path) + delete_objs(obj_name) + + # remove every object created + delete_objs(imported_obj_names) + delete_objs(["Plane", "myCurve", "Cylinder"]) + + if mode == "video": + video = Video(frames_folder, fps=fps) + vid_path = frames_folder.replace("_frames", ".mp4") + video.save(out_path=vid_path) + shutil.rmtree(frames_folder) + print(f"remove tmp fig folder and save video in {vid_path}") + + else: + print(f"Frame generated at: {img_path}") diff --git a/mld/render/blender/sampler.py b/mld/render/blender/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..149c6d99d05e5614bc291e0bec32940c4bbbc8ed --- /dev/null +++ b/mld/render/blender/sampler.py @@ -0,0 +1,17 @@ +import numpy as np + + +def get_frameidx(mode, nframes, exact_frame, frames_to_keep): + if mode == "sequence": + frameidx = np.linspace(0, nframes - 1, frames_to_keep) + frameidx = np.round(frameidx).astype(int) + frameidx = list(frameidx) + return frameidx + elif mode == "frame": + index_frame = int(exact_frame*nframes) + frameidx = [index_frame] + elif mode == "video": + frameidx = range(0, nframes) + else: + raise ValueError(f"Not support {mode} render mode") + return frameidx diff --git a/mld/render/blender/scene.py b/mld/render/blender/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..79c09731756322e70c4433c26337e5a5f22e856a --- /dev/null +++ b/mld/render/blender/scene.py @@ -0,0 +1,94 @@ +import bpy + + +def setup_renderer(denoising=True, oldrender=True, accelerator="gpu", device=[0]): + 
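# Render with the Cycles engine; for the GPU accelerator, enable CUDA and activate only the requested device ids. + 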
bpy.context.scene.render.engine = "CYCLES" + bpy.data.scenes[0].render.engine = "CYCLES" + if accelerator.lower() == "gpu": + bpy.context.preferences.addons[ + "cycles" + ].preferences.compute_device_type = "CUDA" + bpy.context.scene.cycles.device = "GPU" + i = 0 + bpy.context.preferences.addons["cycles"].preferences.get_devices() + for d in bpy.context.preferences.addons["cycles"].preferences.devices: + if i in device: # gpu id + d["use"] = 1 + print(d["name"], "".join(str(i) for i in device)) + else: + d["use"] = 0 + i += 1 + + if denoising: + bpy.context.scene.cycles.use_denoising = True + + bpy.context.scene.render.tile_x = 256 + bpy.context.scene.render.tile_y = 256 + bpy.context.scene.cycles.samples = 64 + + if not oldrender: + bpy.context.scene.view_settings.view_transform = "Standard" + bpy.context.scene.render.film_transparent = True + bpy.context.scene.display_settings.display_device = "sRGB" + bpy.context.scene.view_settings.gamma = 1.2 + bpy.context.scene.view_settings.exposure = -0.75 + + +# Setup scene +def setup_scene( + res="high", denoising=True, oldrender=True, accelerator="gpu", device=[0] +): + scene = bpy.data.scenes["Scene"] + assert res in ["ultra", "high", "med", "low"] + if res == "high": + scene.render.resolution_x = 1280 + scene.render.resolution_y = 1024 + elif res == "med": + scene.render.resolution_x = 1280 // 2 + scene.render.resolution_y = 1024 // 2 + elif res == "low": + scene.render.resolution_x = 1280 // 4 + scene.render.resolution_y = 1024 // 4 + elif res == "ultra": + scene.render.resolution_x = 1280 * 2 + scene.render.resolution_y = 1024 * 2 + + scene.render.film_transparent = True + world = bpy.data.worlds["World"] + world.use_nodes = True + bg = world.node_tree.nodes["Background"] + bg.inputs[0].default_value[:3] = (1.0, 1.0, 1.0) + bg.inputs[1].default_value = 1.0 + + # Remove default cube + if "Cube" in bpy.data.objects: + bpy.data.objects["Cube"].select_set(True) + bpy.ops.object.delete() + + bpy.ops.object.light_add( + type="SUN", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) + ) + bpy.data.objects["Sun"].data.energy = 1.5 + + # rotate camera + bpy.ops.object.empty_add( + type="PLAIN_AXES", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) + ) + bpy.ops.transform.resize( + value=(10, 10, 10), + orient_type="GLOBAL", + orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), + orient_matrix_type="GLOBAL", + mirror=True, + use_proportional_edit=False, + proportional_edit_falloff="SMOOTH", + proportional_size=1, + use_proportional_connected=False, + use_proportional_projected=False, + ) + bpy.ops.object.select_all(action="DESELECT") + + setup_renderer( + denoising=denoising, oldrender=oldrender, accelerator=accelerator, device=device + ) + return scene diff --git a/mld/render/blender/tools.py b/mld/render/blender/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..57ea7f2c59a41dd633fa37a56dc317733945373e --- /dev/null +++ b/mld/render/blender/tools.py @@ -0,0 +1,40 @@ +import bpy +import numpy as np + + +# see this for more explanation +# https://gist.github.com/iyadahmed/7c7c0fae03c40bd87e75dc7059e35377 +# This should be solved with new version of blender +class ndarray_pydata(np.ndarray): + def __bool__(self) -> bool: + return len(self) > 0 + + +def load_numpy_vertices_into_blender(vertices, faces, name, mat): + mesh = bpy.data.meshes.new(name) + mesh.from_pydata(vertices, [], faces.view(ndarray_pydata)) + mesh.validate() + + obj = bpy.data.objects.new(name, mesh) + bpy.context.scene.collection.objects.link(obj) + + 
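# Select only the new object, assign its material, apply smooth shading, then clear the selection. + 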
bpy.ops.object.select_all(action='DESELECT') + obj.select_set(True) + obj.active_material = mat + bpy.context.view_layer.objects.active = obj + bpy.ops.object.shade_smooth() + bpy.ops.object.select_all(action='DESELECT') + return True + + +def delete_objs(names): + if not isinstance(names, list): + names = [names] + # bpy.ops.object.mode_set(mode='OBJECT') + bpy.ops.object.select_all(action='DESELECT') + for obj in bpy.context.scene.objects: + for name in names: + if obj.name.startswith(name) or obj.name.endswith(name): + obj.select_set(True) + bpy.ops.object.delete() + bpy.ops.object.select_all(action='DESELECT') diff --git a/mld/render/video.py b/mld/render/video.py new file mode 100644 index 0000000000000000000000000000000000000000..451170bb67b73f3856fbc70c3fda3b51f151ac49 --- /dev/null +++ b/mld/render/video.py @@ -0,0 +1,66 @@ +import moviepy.editor as mp +import os +import imageio + + +def mask_png(frames): + for frame in frames: + im = imageio.imread(frame) + im[im[:, :, 3] < 1, :] = 255 + imageio.imwrite(frame, im[:, :, 0:3]) + return + + +class Video: + def __init__(self, frame_path: str, fps: float = 12.5, res="high"): + frame_path = str(frame_path) + self.fps = fps + + self._conf = {"codec": "libx264", + "fps": self.fps, + "audio_codec": "aac", + "temp_audiofile": "temp-audio.m4a", + "remove_temp": True} + + if res == "low": + bitrate = "500k" + else: + bitrate = "5000k" + + self._conf = {"bitrate": bitrate, + "fps": self.fps} + + # Load video + # video = mp.VideoFileClip(video1_path, audio=False) + # Load with frames + frames = [os.path.join(frame_path, x) + for x in sorted(os.listdir(frame_path))] + + # mask background white for videos + mask_png(frames) + + video = mp.ImageSequenceClip(frames, fps=fps) + self.video = video + self.duration = video.duration + + def add_text(self, text): + # needs ImageMagick + video_text = mp.TextClip(text, + font='Amiri', + color='white', + method='caption', + align="center", + size=(self.video.w, None), + fontsize=30) + video_text = video_text.on_color(size=(self.video.w, video_text.h + 5), + color=(0, 0, 0), + col_opacity=0.6) + # video_text = video_text.set_pos('bottom') + video_text = video_text.set_pos('top') + + self.video = mp.CompositeVideoClip([self.video, video_text]) + + def save(self, out_path): + out_path = str(out_path) + self.video.subclip(0, self.duration).write_videofile( + out_path, **self._conf) diff --git a/mld/transforms/__init__.py b/mld/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/transforms/joints2rots/__init__.py b/mld/transforms/joints2rots/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/transforms/joints2rots/config.py b/mld/transforms/joints2rots/config.py new file mode 100644 index 0000000000000000000000000000000000000000..cd15c7581a68ed6e97f32a6f070d4bbd5a7befb8 --- /dev/null +++ b/mld/transforms/joints2rots/config.py @@ -0,0 +1,31 @@ +# Map joints Name to SMPL joints idx +JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, + 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, + 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22, + 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23, + 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, + 'LCollar': 13, 'RCollar': 14 +} + +full_smpl_idx = range(24) +key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] + 
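+# AMASS variant of the joint map: the hand joints (LHand/RHand) are omitted and only the first 22 SMPL joints are used (see amass_idx below).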
+AMASS_JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, + 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, + 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, + 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, + 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, + 'LCollar': 13, 'RCollar': 14, +} +amass_idx = range(22) +amass_smpl_idx = range(22) + +SMPL_MODEL_DIR = "./deps/smpl_models" +GMM_MODEL_DIR = "./deps/smpl_models" +SMPL_MEAN_FILE = "./deps/smpl_models/neutral_smpl_mean_params.h5" +# for collision +Part_Seg_DIR = "./deps/smpl_models/smplx_parts_segm.pkl" diff --git a/mld/transforms/joints2rots/customloss.py b/mld/transforms/joints2rots/customloss.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f3a922184c1efc5a3198cfe120c0e2ab7b6554 --- /dev/null +++ b/mld/transforms/joints2rots/customloss.py @@ -0,0 +1,100 @@ +import torch + +from mld.transforms.joints2rots import config + + +def gmof(x, sigma): + """ + Geman-McClure error function + """ + x_squared = x ** 2 + sigma_squared = sigma ** 2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + + +def angle_prior(pose): + """ + Angle prior that penalizes unnatural bending of the knees and elbows + """ + # We subtract 3 because pose does not include the global rotation of the model + return torch.exp( + pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2 + + +def body_fitting_loss_3d(body_pose, preserve_pose, + betas, model_joints, camera_translation, + j3d, pose_prior, + joints3d_conf, + sigma=100, pose_prior_weight=4.78 * 1.5, + shape_prior_weight=5.0, angle_prior_weight=15.2, + joint_loss_weight=500.0, + pose_preserve_weight=0.0, + use_collision=False, + model_vertices=None, model_faces=None, + search_tree=None, pen_distance=None, filter_faces=None, + collision_loss_weight=1000 + ): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + + joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma) + + joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1) + joint3d_loss = ((joint_loss_weight ** 2) * joint3d_loss_part).sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + collision_loss = 0.0 + # Calculate the loss due to interpenetration + if use_collision: + triangles = torch.index_select( + model_vertices, 1, + model_faces).view(batch_size, -1, 3, 3) + + with torch.no_grad(): + collision_idxs = search_tree(triangles) + + # Remove unwanted collisions + if filter_faces is not None: + collision_idxs = filter_faces(collision_idxs) + + if collision_idxs.ge(0).sum().item() > 0: + collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs)) + + pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 2).sum(dim=-1) + + total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss + + return total_loss.sum() + + +def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est, + j3d, joints_category="orig", depth_loss_weight=100.0): + """ + Loss function for camera optimization. 
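+ Aligns the torso joints (hips and shoulders) of the SMPL model with the target 3D joints and penalizes deviation of the camera translation from its initial estimate.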
+ """ + model_joints = model_joints + camera_t + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category == "orig": + select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category == "AMASS": + select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + j3d_error_loss = (j3d[:, select_joints_ind] - model_joints[:, gt_joints_ind]) ** 2 + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight ** 2) * (camera_t - camera_t_est) ** 2 + + total_loss = j3d_error_loss + depth_loss + return total_loss.sum() diff --git a/mld/transforms/joints2rots/prior.py b/mld/transforms/joints2rots/prior.py new file mode 100644 index 0000000000000000000000000000000000000000..70e8d19558661c712f1c3de82340d994c837dd3d --- /dev/null +++ b/mld/transforms/joints2rots/prior.py @@ -0,0 +1,208 @@ +import os +import sys +import pickle + +import numpy as np + +import torch +import torch.nn as nn + +DEFAULT_DTYPE = torch.float32 + + +def create_prior(prior_type, **kwargs): + if prior_type == 'gmm': + prior = MaxMixturePrior(**kwargs) + elif prior_type == 'l2': + return L2Prior(**kwargs) + elif prior_type == 'angle': + return SMPLifyAnglePrior(**kwargs) + elif prior_type == 'none' or prior_type is None: + # Don't use any pose prior + def no_prior(*args, **kwargs): + return 0.0 + prior = no_prior + else: + raise ValueError('Prior {}'.format(prior_type) + ' is not implemented') + return prior + + +class SMPLifyAnglePrior(nn.Module): + def __init__(self, dtype=DEFAULT_DTYPE, **kwargs): + super(SMPLifyAnglePrior, self).__init__() + + # Indices for the rotation angle of + # 55: left elbow, 90deg bend at -np.pi/2 + # 58: right elbow, 90deg bend at np.pi/2 + # 12: left knee, 90deg bend at np.pi/2 + # 15: right knee, 90deg bend at np.pi/2 + angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64) + angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long) + self.register_buffer('angle_prior_idxs', angle_prior_idxs) + + angle_prior_signs = np.array([1, -1, -1, -1], + dtype=np.float32 if dtype == torch.float32 + else np.float64) + angle_prior_signs = torch.tensor(angle_prior_signs, + dtype=dtype) + self.register_buffer('angle_prior_signs', angle_prior_signs) + + def forward(self, pose, with_global_pose=False): + ''' Returns the angle prior loss for the given pose + + Args: + pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle + representation of the rotations of the joints of the SMPL model. + Kwargs: + with_global_pose: Whether the pose vector also contains the global + orientation of the SMPL model. If not then the indices must be + corrected. + Returns: + A sze (B) tensor containing the angle prior loss for each element + in the batch. 
+ ''' + angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3 + return torch.exp(pose[:, angle_prior_idxs] * + self.angle_prior_signs).pow(2) + + +class L2Prior(nn.Module): + def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs): + super(L2Prior, self).__init__() + + def forward(self, module_input, *args): + return torch.sum(module_input.pow(2)) + + +class MaxMixturePrior(nn.Module): + + def __init__(self, prior_folder='prior', + num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16, + use_merged=True, + **kwargs): + super(MaxMixturePrior, self).__init__() + + if dtype == DEFAULT_DTYPE: + np_dtype = np.float32 + elif dtype == torch.float64: + np_dtype = np.float64 + else: + print('Unknown float type {}, exiting!'.format(dtype)) + sys.exit(-1) + + self.num_gaussians = num_gaussians + self.epsilon = epsilon + self.use_merged = use_merged + gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) + + full_gmm_fn = os.path.join(prior_folder, gmm_fn) + if not os.path.exists(full_gmm_fn): + print('The path to the mixture prior "{}"'.format(full_gmm_fn) + + ' does not exist, exiting!') + sys.exit(-1) + + with open(full_gmm_fn, 'rb') as f: + gmm = pickle.load(f, encoding='latin1') + + if type(gmm) == dict: + means = gmm['means'].astype(np_dtype) + covs = gmm['covars'].astype(np_dtype) + weights = gmm['weights'].astype(np_dtype) + elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): + means = gmm.means_.astype(np_dtype) + covs = gmm.covars_.astype(np_dtype) + weights = gmm.weights_.astype(np_dtype) + else: + print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) + sys.exit(-1) + + self.register_buffer('means', torch.tensor(means, dtype=dtype)) + + self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) + + precisions = [np.linalg.inv(cov) for cov in covs] + precisions = np.stack(precisions).astype(np_dtype) + + self.register_buffer('precisions', + torch.tensor(precisions, dtype=dtype)) + + # The constant term: + sqrdets = np.array([(np.sqrt(np.linalg.det(c))) + for c in gmm['covars']]) + const = (2 * np.pi)**(69 / 2.) 
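+ # (2 * pi)^(d/2) factor of the Gaussian normalizer, with d = 69 (23 body joints x 3 axis-angle values).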
+ + nll_weights = np.asarray(gmm['weights'] / (const * + (sqrdets / sqrdets.min()))) + nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) + self.register_buffer('nll_weights', nll_weights) + + weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) + self.register_buffer('weights', weights) + + self.register_buffer('pi_term', + torch.log(torch.tensor(2 * np.pi, dtype=dtype))) + + cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) + for cov in covs] + self.register_buffer('cov_dets', + torch.tensor(cov_dets, dtype=dtype)) + + # The dimensionality of the random variable + self.random_var_dim = self.means.shape[1] + + def get_mean(self): + ''' Returns the mean of the mixture ''' + mean_pose = torch.matmul(self.weights, self.means) + return mean_pose + + def merged_log_likelihood(self, pose, betas): + diff_from_mean = pose.unsqueeze(dim=1) - self.means + + prec_diff_prod = torch.einsum('mij,bmj->bmi', + [self.precisions, diff_from_mean]) + diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) + + curr_loglikelihood = 0.5 * diff_prec_quadratic - \ + torch.log(self.nll_weights) + # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + + # self.random_var_dim * self.pi_term + + # diff_prec_quadratic + # ) - torch.log(self.weights) + + min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) + return min_likelihood + + def log_likelihood(self, pose, betas, *args, **kwargs): + ''' Create graph operation for negative log-likelihood calculation + ''' + likelihoods = [] + + for idx in range(self.num_gaussians): + mean = self.means[idx] + prec = self.precisions[idx] + cov = self.covs[idx] + diff_from_mean = pose - mean + + curr_loglikelihood = torch.einsum('bj,ji->bi', + [diff_from_mean, prec]) + curr_loglikelihood = torch.einsum('bi,bi->b', + [curr_loglikelihood, + diff_from_mean]) + cov_term = torch.log(torch.det(cov) + self.epsilon) + curr_loglikelihood += 0.5 * (cov_term + + self.random_var_dim * + self.pi_term) + likelihoods.append(curr_loglikelihood) + + log_likelihoods = torch.stack(likelihoods, dim=1) + min_idx = torch.argmin(log_likelihoods, dim=1) + weight_component = self.nll_weights[:, min_idx] + weight_component = -torch.log(weight_component) + + return weight_component + log_likelihoods[:, min_idx] + + def forward(self, pose, betas): + if self.use_merged: + return self.merged_log_likelihood(pose, betas) + else: + return self.log_likelihood(pose, betas) diff --git a/mld/transforms/joints2rots/smplify.py b/mld/transforms/joints2rots/smplify.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ed9dc87dcffb78257f88e997f5c4aabffcce20 --- /dev/null +++ b/mld/transforms/joints2rots/smplify.py @@ -0,0 +1,274 @@ +import os +import pickle + +from tqdm import tqdm + +import torch + +from mld.transforms.joints2rots import config +from mld.transforms.joints2rots.customloss import camera_fitting_loss_3d, body_fitting_loss_3d +from mld.transforms.joints2rots.prior import MaxMixturePrior + + +@torch.no_grad() +def guess_init_3d(model_joints, j3d, joints_category="orig"): + """ + Initialize the camera translation via triangle similarity, by using the torso joints . 
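+ The translation is initialized as the mean offset between the target torso joints and the corresponding model joints.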
+ """ + # get the indexed four + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category == "orig": + joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category == "AMASS": + joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1) + init_t = sum_init_t / 4.0 + return init_t + + +# SMPLIfy 3D +class SMPLify3D(): + """Implementation of SMPLify, use 3D joints.""" + + def __init__(self, + smplxmodel, + step_size=1e-2, + batch_size=1, + num_iters=100, + use_collision=False, + use_lbfgs=True, + joints_category="orig", + device=torch.device('cuda:0'), + ): + + # Store options + self.batch_size = batch_size + self.device = device + self.step_size = step_size + + self.num_iters = num_iters + # --- choose optimizer + self.use_lbfgs = use_lbfgs + # GMM pose prior + self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR, + num_gaussians=8, + dtype=torch.float32).to(device) + # collision part + self.use_collision = use_collision + if self.use_collision: + self.part_segm_fn = config.Part_Seg_DIR + + # reLoad SMPL-X model + self.smpl = smplxmodel + + self.model_faces = smplxmodel.faces_tensor.view(-1) + + # select joint joint_category + self.joints_category = joints_category + + if joints_category == "orig": + self.smpl_index = config.full_smpl_idx + self.corr_index = config.full_smpl_idx + elif joints_category == "AMASS": + self.smpl_index = config.amass_smpl_idx + self.corr_index = config.amass_idx + else: + self.smpl_index = None + self.corr_index = None + print("NO SUCH JOINTS CATEGORY!") + + # ---- get the man function here ------ + def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0): + """Perform body fitting. 
+ Input: + init_pose: SMPL pose estimate + init_betas: SMPL betas estimate + init_cam_t: Camera translation estimate + j3d: joints 3d aka keypoints + conf_3d: confidence for 3d joints + seq_ind: index of the sequence + Returns: + vertices: Vertices of optimized shape + joints: 3D joints of optimized shape + pose: SMPL pose parameters of optimized shape + betas: SMPL beta parameters of optimized shape + camera_translation: Camera translation + """ + + # # # add the mesh inter-section to avoid + search_tree = None + pen_distance = None + filter_faces = None + + if self.use_collision: + from mesh_intersection.bvh_search_tree import BVH + import mesh_intersection.loss as collisions_loss + from mesh_intersection.filter_faces import FilterFaces + + search_tree = BVH(max_collisions=8) + + pen_distance = collisions_loss.DistanceFieldPenetrationLoss( + sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True) + + if self.part_segm_fn: + # Read the part segmentation + part_segm_fn = os.path.expandvars(self.part_segm_fn) + with open(part_segm_fn, 'rb') as faces_parents_file: + face_segm_data = pickle.load(faces_parents_file, encoding='latin1') + faces_segm = face_segm_data['segm'] + faces_parents = face_segm_data['parents'] + # Create the module used to filter invalid collision pairs + filter_faces = FilterFaces( + faces_segm=faces_segm, faces_parents=faces_parents, + ign_part_pairs=None).to(device=self.device) + + # Split SMPL pose to body pose and global orientation + body_pose = init_pose[:, 3:].detach().clone() + global_orient = init_pose[:, :3].detach().clone() + betas = init_betas.detach().clone() + + # use guess 3d to get the initial + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach() + camera_translation = init_cam_t.clone() + + preserve_pose = init_pose[:, 3:].detach().clone() + # -------------Step 1: Optimize camera translation and body orientation-------- + # Optimize only camera translation and body orientation + body_pose.requires_grad = False + betas.requires_grad = False + global_orient.requires_grad = True + camera_translation.requires_grad = True + + camera_opt_params = [global_orient, camera_translation] + + if self.use_lbfgs: + camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(10): + def closure(): + camera_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints, camera_translation, + init_cam_t, j3d, self.joints_category) + loss.backward() + return loss + + camera_optimizer.step(closure) + else: + camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(20): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation, + init_cam_t, j3d[:, self.corr_index], self.joints_category) + camera_optimizer.zero_grad() + loss.backward() + camera_optimizer.step() + + # Fix camera translation after optimizing camera + # --------Step 2: Optimize body joints -------------------------- + # Optimize only the body pose and global orientation of the body + 
body_pose.requires_grad = True + global_orient.requires_grad = True + camera_translation.requires_grad = True + + # --- if we use the sequence, fix the shape + if seq_ind == 0: + betas.requires_grad = True + body_opt_params = [body_pose, betas, global_orient, camera_translation] + else: + betas.requires_grad = False + body_opt_params = [body_pose, global_orient, camera_translation] + + if self.use_lbfgs: + body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + + for i in tqdm(range(self.num_iters), desc=f"LBFGS iter: "): + def closure(): + body_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], + camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + pose_preserve_weight=5.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, + filter_faces=filter_faces) + loss.backward() + return loss + + body_optimizer.step(closure) + else: + body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(self.num_iters): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], + camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, + filter_faces=filter_faces) + body_optimizer.zero_grad() + loss.backward() + body_optimizer.step() + + # Get final loss value + with torch.no_grad(): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas, return_full_pose=True) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], + camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, model_vertices=model_vertices, + model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, + filter_faces=filter_faces) + + vertices = smpl_output.vertices.detach() + joints = smpl_output.joints.detach() + pose = torch.cat([global_orient, body_pose], dim=-1).detach() + betas = betas.detach() + + return vertices, joints, pose, betas, camera_translation, final_loss diff --git a/mld/utils/__init__.py b/mld/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/utils/temos_utils.py b/mld/utils/temos_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..22b62750d1dd0dca0ca2b85af58109043b0559c1 --- /dev/null +++ b/mld/utils/temos_utils.py @@ -0,0 +1,18 @@ +import torch + + +def lengths_to_mask(lengths: list[int], + device: torch.device, + max_len: int = None) -> torch.Tensor: + lengths = torch.tensor(lengths, 
device=device) + max_len = max_len if max_len else max(lengths) + mask = torch.arange(max_len, device=device).expand( + len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def remove_padding(tensors: torch.Tensor, lengths: list[int]) -> list: + return [ + tensor[:tensor_length] + for tensor, tensor_length in zip(tensors, lengths) + ] diff --git a/mld/utils/utils.py b/mld/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d85397cf05a84e5aa13a26135095606518a22e41 --- /dev/null +++ b/mld/utils/utils.py @@ -0,0 +1,99 @@ +import random + +import numpy as np + +from rich import get_console +from rich.table import Table + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def print_table(title: str, metrics: dict) -> None: + table = Table(title=title) + + table.add_column("Metrics", style="cyan", no_wrap=True) + table.add_column("Value", style="magenta") + + for key, value in metrics.items(): + table.add_row(key, str(value)) + + console = get_console() + console.print(table, justify="center") + + +def move_batch_to_device(batch: dict, device: torch.device) -> dict: + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(device) + return batch + + +def count_parameters(module: nn.Module) -> float: + num_params = sum(p.numel() for p in module.parameters()) + return round(num_params / 1e6, 3) + + +def get_guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + assert len(w.shape) == 1 + w = w * 1000.0 + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def sum_flat(tensor: torch.Tensor) -> torch.Tensor: + return tensor.sum(dim=list(range(1, len(tensor.shape)))) + + +def control_loss_calculate( + vaeloss_type: str, loss_func: str, src: torch.Tensor, + tgt: torch.Tensor, mask: torch.Tensor +) -> torch.Tensor: + + if loss_func == 'l1': + loss = F.l1_loss(src, tgt, reduction='none') + elif loss_func == 'l1_smooth': + loss = F.smooth_l1_loss(src, tgt, reduction='none') + elif loss_func == 'l2': + loss = F.mse_loss(src, tgt, reduction='none') + else: + raise ValueError(f'Unknown loss func: {loss_func}') + + if vaeloss_type == 'sum': + loss = loss.sum(-1, keepdims=True) * mask + loss = loss.sum() / mask.sum() + elif vaeloss_type == 'sum_mask': + loss = loss.sum(-1, keepdims=True) * mask + loss = sum_flat(loss) / sum_flat(mask) + loss = loss.mean() + elif vaeloss_type == 'mask': + loss = sum_flat(loss * mask) + n_entries = src.shape[-1] + non_zero_elements = sum_flat(mask) * n_entries + loss = loss / non_zero_elements + loss = loss.mean() + else: + raise ValueError(f'Unsupported vaeloss_type: {vaeloss_type}') + + return loss diff --git a/prepare/download_glove.sh b/prepare/download_glove.sh new file mode 100644 index 
0000000000000000000000000000000000000000..7bb914f20ef6f1ac3e3979a9b7eab6f18c282f0f --- /dev/null +++ b/prepare/download_glove.sh @@ -0,0 +1,12 @@ +mkdir -p deps/ +cd deps/ + +echo -e "Downloading glove (in use by the evaluators)" +gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing +rm -rf glove + +unzip glove.zip +echo -e "Cleaning\n" +rm glove.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/prepare/download_pretrained_models.sh b/prepare/download_pretrained_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..80db402c2b673a3047d7c2c43ec022156b841875 --- /dev/null +++ b/prepare/download_pretrained_models.sh @@ -0,0 +1,17 @@ +echo -e "Downloading experiments_recons!" +gdown --fuzzy https://drive.google.com/file/d/15zFDitcOLhjbQ0CaOoM-QNKQUeyJw-Om/view?usp=sharing +unzip experiments_recons.zip + +echo -e "Downloading experiments_t2m!" +gdown --fuzzy https://drive.google.com/file/d/1U7homKobR2gaDLfL5flS3N0g7e0a_AQd/view?usp=sharing +unzip experiments_t2m.zip + +echo -e "Downloading experiments_control!" +gdown --fuzzy https://drive.google.com/file/d/1o6oFdH5dgQJNB5J2rDGKCMw3FDr9gUkW/view?usp=sharing +unzip experiments_control.zip + +rm experiments_recons.zip +rm experiments_t2m.zip +rm experiments_control.zip + +echo -e "Downloading done!" diff --git a/prepare/download_smpl_models.sh b/prepare/download_smpl_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..33777c17720c0df4abfb10afc5926684fb8e157a --- /dev/null +++ b/prepare/download_smpl_models.sh @@ -0,0 +1,12 @@ +mkdir -p deps/ +cd deps/ + +echo -e "Downloading smpl models" +gdown --fuzzy https://drive.google.com/file/d/1J2pTxrar_q689Du5r3jES343fZUmCs_y/view?usp=sharing +rm -rf smpl_models + +unzip smpl_models.zip +echo -e "Cleaning\n" +rm smpl_models.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/prepare/download_t2m_evaluators.sh b/prepare/download_t2m_evaluators.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f38467b6d43686592b14ba5b538bf1788ee3595 --- /dev/null +++ b/prepare/download_t2m_evaluators.sh @@ -0,0 +1,13 @@ +mkdir -p deps/ +cd deps/ + +echo "The t2m evaluators will be stored in the './deps' folder" + +echo "Downloading" +gdown --fuzzy https://drive.google.com/file/d/16hyR4XlEyksVyNVjhIWK684Lrm_7_pvX/view?usp=sharing +echo "Extracting" +unzip t2m.zip +echo "Cleaning" +rm t2m.zip + +echo "Downloading done!" diff --git a/prepare/prepare_bert.sh b/prepare/prepare_bert.sh new file mode 100644 index 0000000000000000000000000000000000000000..a05c9d1085fc418e61727d64256d52717bd4c68f --- /dev/null +++ b/prepare/prepare_bert.sh @@ -0,0 +1,5 @@ +mkdir -p deps/ +cd deps/ +git lfs install +git clone https://huggingface.co/distilbert-base-uncased +cd .. diff --git a/prepare/prepare_clip.sh b/prepare/prepare_clip.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf786225bdcd891623a69fea6a16f86f7c155cf7 --- /dev/null +++ b/prepare/prepare_clip.sh @@ -0,0 +1,5 @@ +mkdir -p deps/ +cd deps/ +git lfs install +git clone https://huggingface.co/openai/clip-vit-large-patch14 +cd .. diff --git a/prepare/prepare_t5.sh b/prepare/prepare_t5.sh new file mode 100644 index 0000000000000000000000000000000000000000..507dcc2b15eac88d7ca7c7afbd85faaa9d9927d4 --- /dev/null +++ b/prepare/prepare_t5.sh @@ -0,0 +1,5 @@ +mkdir -p deps/ +cd deps/ +git lfs install +git clone https://huggingface.co/sentence-transformers/sentence-t5-large +cd .. 
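+
+# Note: the prepare_*.sh scripts in this folder are intended to be run from the
+# repository root (each one creates and enters deps/ or datasets/) and assume that
+# gdown, unzip and git-lfs are available on PATH.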
diff --git a/prepare/prepare_tiny_humanml3d.sh b/prepare/prepare_tiny_humanml3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..a9691edc94dfeaf434e3848bd83a1ccf52fb5c71 --- /dev/null +++ b/prepare/prepare_tiny_humanml3d.sh @@ -0,0 +1,6 @@ +mkdir -p datasets/ +cd datasets/ + +gdown --fuzzy https://drive.google.com/file/d/1Mg_3RnWmRt0tk_lyLRRiOZg1W-Fu4wLL/view?usp=sharing +unzip humanml3d_tiny.zip +rm humanml3d_tiny.zip diff --git a/render.py b/render.py new file mode 100644 index 0000000000000000000000000000000000000000..b59188774e826fd8f1b82dd11fdd0e720b3ff285 --- /dev/null +++ b/render.py @@ -0,0 +1,80 @@ +import os +import pickle +import sys +from argparse import ArgumentParser + +try: + import bpy + sys.path.append(os.path.dirname(bpy.data.filepath)) +except ImportError: + raise ImportError( + "Blender is not properly installed or not launch properly. See README.md to have instruction on how to install and use blender.") + +import mld.launch.blender # noqa +from mld.render.blender import render + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--pkl", type=str, default=None, help="pkl motion file") + parser.add_argument("--dir", type=str, default=None, help="pkl motion folder") + parser.add_argument("--mode", type=str, default="sequence", help="render target: video, sequence, frame") + parser.add_argument("--res", type=str, default="high") + parser.add_argument("--denoising", type=bool, default=True) + parser.add_argument("--oldrender", type=bool, default=True) + parser.add_argument("--accelerator", type=str, default='gpu', help='accelerator device') + parser.add_argument("--device", type=int, nargs='+', default=[0], help='gpu ids') + parser.add_argument("--faces_path", type=str, default='./deps/smpl_models/smplh/smplh.faces') + parser.add_argument("--always_on_floor", action="store_true", help='put all the body on the floor (not recommended)') + parser.add_argument("--gt", type=str, default=False, help='green for gt, otherwise orange') + parser.add_argument("--fps", type=int, default=20, help="the frame rate of the rendered video") + parser.add_argument("--num", type=int, default=8, help="the number of frames rendered in 'sequence' mode") + parser.add_argument("--exact_frame", type=float, default=0.5, help="the frame id selected under 'frame' mode ([0, 1])") + cfg = parser.parse_args() + return cfg + + +def render_cli() -> None: + cfg = parse_args() + + if cfg.pkl: + paths = [cfg.pkl] + elif cfg.dir: + paths = [] + file_list = os.listdir(cfg.dir) + for item in file_list: + if item.endswith("_mesh.pkl"): + paths.append(os.path.join(cfg.dir, item)) + else: + raise ValueError(f'{cfg.pkl} and {cfg.dir} are both None!') + + for path in paths: + try: + with open(path, 'rb') as f: + pkl = pickle.load(f) + data = pkl['vertices'] + trajectory = pkl['hint'] + + except FileNotFoundError: + print(f"{path} not found") + continue + + render( + data, + trajectory, + path, + exact_frame=cfg.exact_frame, + num=cfg.num, + mode=cfg.mode, + faces_path=cfg.faces_path, + always_on_floor=cfg.always_on_floor, + oldrender=cfg.oldrender, + res=cfg.res, + gt=cfg.gt, + accelerator=cfg.accelerator, + device=cfg.device, + fps=cfg.fps) + + +if __name__ == "__main__": + render_cli() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b19f9e282840b17b06e77fe272dbd55bf06fcd22 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +--extra-index-url https://download.pytorch.org/whl/cu116 +torch==1.13.1 +gdown 
+omegaconf +rich +swanlab==0.3.1 +torchmetrics==1.3.2 +scipy==1.11.2 +matplotlib==3.3.4 +transformers==4.38.0 +sentence-transformers==2.2.2 +diffusers==0.24.0 +tensorboard==2.15.1 +h5py==3.11.0 +smplx==0.1.28 +chumpy==0.70 +numpy==1.23.1 +natsort==8.4.0 diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..205fdf22a742bac32c965ac6c68ea791dc1ba3df --- /dev/null +++ b/test.py @@ -0,0 +1,160 @@ +import os +import sys +import json +import datetime +import logging +import os.path as osp +from typing import Union + + +import numpy as np +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +from torch.utils.data import DataLoader + +from mld.config import parse_args +from mld.data.get_data import get_dataset +from mld.models.modeltype.mld import MLD +from mld.models.modeltype.vae import VAE +from mld.utils.utils import print_table, set_seed, move_batch_to_device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def get_metric_statistics(values: np.ndarray, replication_times: int) -> tuple: + mean = np.mean(values, axis=0) + std = np.std(values, axis=0) + conf_interval = 1.96 * std / np.sqrt(replication_times) + return mean, conf_interval + + +@torch.no_grad() +def test_one_epoch(model: Union[VAE, MLD], dataloader: DataLoader, device: torch.device) -> dict: + for batch in tqdm(dataloader): + batch = move_batch_to_device(batch, device) + model.test_step(batch) + metrics = model.allsplit_epoch_end() + return metrics + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str) + os.makedirs(cfg.output_dir, exist_ok=False) + + steam_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log')) + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[steam_handler, file_handler]) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml')) + + state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"] + logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS)) + + # Step 1: Check if the checkpoint is VAE-based. + is_vae = False + vae_key = 'vae.skel_embedding.weight' + if vae_key in state_dict: + is_vae = True + logger.info(f'Is VAE: {is_vae}') + + # Step 2: Check if the checkpoint is MLD-based. + is_mld = False + mld_key = 'denoiser.time_embedding.linear_1.weight' + if mld_key in state_dict: + is_mld = True + logger.info(f'Is MLD: {is_mld}') + + # Step 3: Check if the checkpoint is LCM-based. + is_lcm = False + lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG + if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim + logger.info(f'Is LCM: {is_lcm}') + + # Step 4: Check if the checkpoint is Controlnet-based. 
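+ # Each variant is detected by a parameter name that exists only in that variant's
+ # state_dict, so no extra metadata needs to be stored in the checkpoint. A minimal
+ # sketch of the same idea (hypothetical helper, simplified to return a single label;
+ # the real script checks each key independently):
+ #
+ #   def detect_variant(sd: dict) -> str:
+ #       if 'controlnet.controlnet_cond_embedding.0.weight' in sd: return 'controlnet'
+ #       if 'denoiser.time_embedding.cond_proj.weight' in sd: return 'lcm'
+ #       if 'denoiser.time_embedding.linear_1.weight' in sd: return 'mld'
+ #       if 'vae.skel_embedding.weight' in sd: return 'vae'
+ #       raise ValueError('unknown checkpoint type')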
+ cn_key = "controlnet.controlnet_cond_embedding.0.weight" + is_controlnet = True if cn_key in state_dict else False + cfg.model.is_controlnet = is_controlnet + logger.info(f'Is Controlnet: {is_controlnet}') + + if is_mld or is_lcm or is_controlnet: + target_model_class = MLD + else: + target_model_class = VAE + + if cfg.optimize: + assert cfg.model.get('noise_optimizer') is not None + cfg.model.noise_optimizer.params.optimize = True + logger.info('Optimization enabled. Set the batch size to 1.') + logger.info(f'Original batch size: {cfg.TEST.BATCH_SIZE}') + cfg.TEST.BATCH_SIZE = 1 + + dataset = get_dataset(cfg) + test_dataloader = dataset.test_dataloader() + model = target_model_class(cfg, dataset) + model.to(device) + model.eval() + model.requires_grad_(False) + logger.info(model.load_state_dict(state_dict)) + + all_metrics = {} + replication_times = cfg.TEST.REPLICATION_TIMES + max_num_samples = cfg.TEST.get('MAX_NUM_SAMPLES') + name_list = test_dataloader.dataset.name_list + # calculate metrics + for i in range(replication_times): + if max_num_samples is not None: + chosen_list = np.random.choice(name_list, max_num_samples, replace=False) + test_dataloader.dataset.name_list = chosen_list + + metrics_type = ", ".join(cfg.METRIC.TYPE) + logger.info(f"Evaluating {metrics_type} - Replication {i}") + metrics = test_one_epoch(model, test_dataloader, device) + + if "TM2TMetrics" in metrics_type and cfg.TEST.DO_MM_TEST: + # mm metrics + logger.info(f"Evaluating MultiModality - Replication {i}") + dataset.mm_mode(True) + test_mm_dataloader = dataset.test_dataloader() + mm_metrics = test_one_epoch(model, test_mm_dataloader, device) + metrics.update(mm_metrics) + dataset.mm_mode(False) + + print_table(f"Metrics@Replication-{i}", metrics) + logger.info(metrics) + + for key, item in metrics.items(): + if key not in all_metrics: + all_metrics[key] = [item] + else: + all_metrics[key] += [item] + + all_metrics_new = dict() + for key, item in all_metrics.items(): + mean, conf_interval = get_metric_statistics(np.array(item), replication_times) + all_metrics_new[key + "/mean"] = mean + all_metrics_new[key + "/conf_interval"] = conf_interval + print_table(f"Mean Metrics", all_metrics_new) + all_metrics_new.update(all_metrics) + # save metrics to file + metric_file = osp.join(cfg.output_dir, f"metrics.json") + with open(metric_file, "w", encoding="utf-8") as f: + json.dump(all_metrics_new, f, indent=4) + logger.info(f"Testing done, the metrics are saved to {str(metric_file)}") + + +if __name__ == "__main__": + main() diff --git a/train_mld.py b/train_mld.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d06a334a8337da0447c71c5c431ae90cc3c17b --- /dev/null +++ b/train_mld.py @@ -0,0 +1,232 @@ +import os +import sys +import logging +import datetime +import os.path as osp + +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +import swanlab +import diffusers +import transformers +from torch.utils.tensorboard import SummaryWriter +from diffusers.optimization import get_scheduler + +from mld.config import parse_args +from mld.data.get_data import get_dataset +from mld.models.modeltype.mld import MLD +from mld.utils.utils import print_table, set_seed, move_batch_to_device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + 
cfg.output_dir = osp.join(cfg.FOLDER, name_time_str) + os.makedirs(cfg.output_dir, exist_ok=False) + os.makedirs(f"{cfg.output_dir}/checkpoints", exist_ok=False) + if cfg.TRAIN.model_ema: + os.makedirs(f"{cfg.output_dir}/checkpoints_ema", exist_ok=False) + + if cfg.vis == "tb": + writer = SummaryWriter(cfg.output_dir) + elif cfg.vis == "swanlab": + writer = swanlab.init(project="MotionLCM", + experiment_name=os.path.normpath(cfg.output_dir).replace(os.path.sep, "-"), + suffix=None, config=dict(**cfg), logdir=cfg.output_dir) + else: + raise ValueError(f"Invalid vis method: {cfg.vis}") + + stream_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log')) + handlers = [file_handler, stream_handler] + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=handlers) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml')) + + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + dataset = get_dataset(cfg) + train_dataloader = dataset.train_dataloader() + val_dataloader = dataset.val_dataloader() + + model = MLD(cfg, dataset) + + assert cfg.TRAIN.PRETRAINED, "cfg.TRAIN.PRETRAINED must not be None." + logger.info(f"Loading pre-trained model: {cfg.TRAIN.PRETRAINED}") + state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"] + logger.info(model.load_state_dict(state_dict, strict=False)) + + model.vae.requires_grad_(False) + model.text_encoder.requires_grad_(False) + model.vae.eval() + model.text_encoder.eval() + model.to(device) + + logger.info("learning_rate: {}".format(cfg.TRAIN.learning_rate)) + optimizer = torch.optim.AdamW( + model.denoiser.parameters(), + lr=cfg.TRAIN.learning_rate, + betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2), + weight_decay=cfg.TRAIN.adam_weight_decay, + eps=cfg.TRAIN.adam_epsilon) + + if cfg.TRAIN.max_train_steps == -1: + assert cfg.TRAIN.max_train_epochs != -1 + cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader) + + if cfg.TRAIN.checkpointing_steps == -1: + assert cfg.TRAIN.checkpointing_epochs != -1 + cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader) + + if cfg.TRAIN.validation_steps == -1: + assert cfg.TRAIN.validation_epochs != -1 + cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader) + + lr_scheduler = get_scheduler( + cfg.TRAIN.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.TRAIN.lr_warmup_steps, + num_training_steps=cfg.TRAIN.max_train_steps) + + # EMA + model_ema = None + if cfg.TRAIN.model_ema: + alpha = 1.0 - cfg.TRAIN.model_ema_decay + logger.info(f'EMA alpha: {alpha}') + model_ema = torch.optim.swa_utils.AveragedModel(model, device, lambda p0, p1, _: (1 - alpha) * p0 + alpha * p1) + + # Train! 
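+ # Note on the EMA above: AveragedModel uses
+ # avg_fn(p_ema, p_cur, _) = (1 - alpha) * p_ema + alpha * p_cur with
+ # alpha = 1 - cfg.TRAIN.model_ema_decay, i.e. an exponential moving average of the
+ # model weights, refreshed every cfg.TRAIN.model_ema_steps optimization steps.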
+ logger.info("***** Running training *****") + logging.info(f" Num examples = {len(train_dataloader.dataset)}") + logging.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}") + logging.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}") + logging.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}") + + global_step = 0 + + @torch.no_grad() + def validation(target_model: MLD, ema: bool = False) -> tuple: + target_model.denoiser.eval() + val_loss_list = [] + for val_batch in tqdm(val_dataloader): + val_batch = move_batch_to_device(val_batch, device) + val_loss_dict = target_model.allsplit_step(split='val', batch=val_batch) + val_loss_list.append(val_loss_dict) + metrics = target_model.allsplit_epoch_end() + metrics[f"Val/loss"] = sum([d['loss'] for d in val_loss_list]).item() / len(val_dataloader) + metrics[f"Val/diff_loss"] = sum([d['diff_loss'] for d in val_loss_list]).item() / len(val_dataloader) + metrics[f"Val/router_loss"] = sum([d['router_loss'] for d in val_loss_list]).item() / len(val_dataloader) + max_val_rp1 = metrics['Metrics/R_precision_top_1'] + min_val_fid = metrics['Metrics/FID'] + print_table(f'Validation@Step-{global_step}', metrics) + for mk, mv in metrics.items(): + mk = mk + '_EMA' if ema else mk + if cfg.vis == "tb": + writer.add_scalar(mk, mv, global_step=global_step) + elif cfg.vis == "swanlab": + writer.log({mk: mv}, step=global_step) + target_model.denoiser.train() + return max_val_rp1, min_val_fid + + max_rp1, min_fid = validation(model) + if cfg.TRAIN.model_ema: + validation(model_ema.module, ema=True) + + progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps") + while True: + for step, batch in enumerate(train_dataloader): + batch = move_batch_to_device(batch, device) + loss_dict = model.allsplit_step('train', batch) + + diff_loss = loss_dict['diff_loss'] + router_loss = loss_dict['router_loss'] + loss = loss_dict['loss'] + loss.backward() + torch.nn.utils.clip_grad_norm_(model.denoiser.parameters(), cfg.TRAIN.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + progress_bar.update(1) + global_step += 1 + + if cfg.TRAIN.model_ema and global_step % cfg.TRAIN.model_ema_steps == 0: + model_ema.update_parameters(model) + + if global_step % cfg.TRAIN.checkpointing_steps == 0: + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-{global_step}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path}") + + if cfg.TRAIN.model_ema: + save_path = os.path.join(cfg.output_dir, 'checkpoints_ema', f"checkpoint-{global_step}.ckpt") + ckpt = dict(state_dict=model_ema.module.state_dict()) + model_ema.module.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved EMA state to {save_path}") + + if global_step % cfg.TRAIN.validation_steps == 0: + cur_rp1, cur_fid = validation(model) + if cfg.TRAIN.model_ema: + validation(model_ema.module, ema=True) + + if cur_rp1 > max_rp1: + max_rp1 = cur_rp1 + save_path = os.path.join(cfg.output_dir, 'checkpoints', + f"checkpoint-{global_step}-rp1-{round(cur_rp1, 3)}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with rp1:{round(cur_rp1, 3)}") + + if cur_fid < min_fid: + min_fid = cur_fid + save_path = os.path.join(cfg.output_dir, 'checkpoints', + f"checkpoint-{global_step}-fid-{round(cur_fid, 3)}.ckpt") + ckpt 
= dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with fid:{round(cur_fid, 3)}") + + logs = {"loss": loss.item(), + "diff_loss": diff_loss.item(), + "router_loss": router_loss.item(), + "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + for k, v in logs.items(): + if cfg.vis == "tb": + writer.add_scalar(f"Train/{k}", v, global_step=global_step) + elif cfg.vis == "swanlab": + writer.log({f"Train/{k}": v}, step=global_step) + + if global_step >= cfg.TRAIN.max_train_steps: + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-last.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + + if cfg.TRAIN.model_ema: + save_path = os.path.join(cfg.output_dir, 'checkpoints_ema', f"checkpoint-last.ckpt") + ckpt = dict(state_dict=model_ema.module.state_dict()) + model_ema.module.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + + exit(0) + + +if __name__ == "__main__": + main() diff --git a/train_motion_control.py b/train_motion_control.py new file mode 100644 index 0000000000000000000000000000000000000000..bddaee00b6cca8f7c50f62fcb88893a1331c172d --- /dev/null +++ b/train_motion_control.py @@ -0,0 +1,217 @@ +import os +import sys +import logging +import datetime +import os.path as osp + +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +import swanlab +import diffusers +import transformers +from torch.utils.tensorboard import SummaryWriter +from diffusers.optimization import get_scheduler + +from mld.config import parse_args +from mld.data.get_data import get_dataset +from mld.models.modeltype.mld import MLD +from mld.utils.utils import print_table, set_seed, move_batch_to_device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + cfg.output_dir = osp.join(cfg.FOLDER, name_time_str) + os.makedirs(cfg.output_dir, exist_ok=False) + os.makedirs(f"{cfg.output_dir}/checkpoints", exist_ok=False) + + if cfg.vis == "tb": + writer = SummaryWriter(cfg.output_dir) + elif cfg.vis == "swanlab": + writer = swanlab.init(project="MotionLCM", + experiment_name=os.path.normpath(cfg.output_dir).replace(os.path.sep, "-"), + suffix=None, config=dict(**cfg), logdir=cfg.output_dir) + else: + raise ValueError(f"Invalid vis method: {cfg.vis}") + + stream_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log')) + handlers = [file_handler, stream_handler] + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=handlers) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml')) + + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + assert cfg.model.is_controlnet, "cfg.model.is_controlnet must be true for controlling!" 
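+ # Summary of the setup that follows: the pretrained text-to-motion checkpoint is
+ # loaded and frozen (vae, text_encoder, denoiser), the controlnet branch is
+ # initialized from the denoiser's weights, and only the controlnet and the trajectory
+ # encoder (traj_encoder) receive gradients, each parameter group with its own
+ # learning rate (learning_rate vs. learning_rate_spatial).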
+ + dataset = get_dataset(cfg) + train_dataloader = dataset.train_dataloader() + val_dataloader = dataset.val_dataloader() + + logger.info(f"Loading pretrained model: {cfg.TRAIN.PRETRAINED}") + state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"] + lcm_key = 'denoiser.time_embedding.cond_proj.weight' + is_lcm = False + if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim + logger.info(f'Is LCM: {is_lcm}') + + model = MLD(cfg, dataset) + logger.info(model.load_state_dict(state_dict, strict=False)) + logger.info(model.controlnet.load_state_dict(model.denoiser.state_dict(), strict=False)) + + model.vae.requires_grad_(False) + model.text_encoder.requires_grad_(False) + model.denoiser.requires_grad_(False) + model.vae.eval() + model.text_encoder.eval() + model.denoiser.eval() + model.to(device) + + controlnet_params = list(model.controlnet.parameters()) + traj_encoder_params = list(model.traj_encoder.parameters()) + params = controlnet_params + traj_encoder_params + params_to_optimize = [{'params': controlnet_params, 'lr': cfg.TRAIN.learning_rate}, + {'params': traj_encoder_params, 'lr': cfg.TRAIN.learning_rate_spatial}] + + logger.info("learning_rate: {}, learning_rate_spatial: {}". + format(cfg.TRAIN.learning_rate, cfg.TRAIN.learning_rate_spatial)) + + optimizer = torch.optim.AdamW( + params_to_optimize, + betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2), + weight_decay=cfg.TRAIN.adam_weight_decay, + eps=cfg.TRAIN.adam_epsilon) + + if cfg.TRAIN.max_train_steps == -1: + assert cfg.TRAIN.max_train_epochs != -1 + cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader) + + if cfg.TRAIN.checkpointing_steps == -1: + assert cfg.TRAIN.checkpointing_epochs != -1 + cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader) + + if cfg.TRAIN.validation_steps == -1: + assert cfg.TRAIN.validation_epochs != -1 + cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader) + + lr_scheduler = get_scheduler( + cfg.TRAIN.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.TRAIN.lr_warmup_steps, + num_training_steps=cfg.TRAIN.max_train_steps) + + # Train! 
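+ # The validation below keeps the checkpoints that minimize two control metrics,
+ # 'Metrics/kps_mean_err(m)' (mean keypoint error, in meters) and
+ # 'Metrics/traj_fail_50cm' (judging from its name, the fraction of trajectories whose
+ # error exceeds 50 cm); lower is better for both, hence the min_km / min_tj tracking.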
+ logger.info("***** Running training *****") + logging.info(f" Num examples = {len(train_dataloader.dataset)}") + logging.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}") + logging.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}") + logging.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}") + + global_step = 0 + + @torch.no_grad() + def validation(): + model.controlnet.eval() + model.traj_encoder.eval() + val_loss_list = [] + for val_batch in tqdm(val_dataloader): + val_batch = move_batch_to_device(val_batch, device) + val_loss_dict = model.allsplit_step(split='val', batch=val_batch) + val_loss_list.append(val_loss_dict) + metrics = model.allsplit_epoch_end() + for loss_k in val_loss_list[0].keys(): + metrics[f"Val/{loss_k}"] = sum([d[loss_k] for d in val_loss_list]).item() / len(val_dataloader) + min_val_km = metrics['Metrics/kps_mean_err(m)'] + min_val_tj = metrics['Metrics/traj_fail_50cm'] + print_table(f'Validation@Step-{global_step}', metrics) + for mk, mv in metrics.items(): + if cfg.vis == "tb": + writer.add_scalar(mk, mv, global_step=global_step) + elif cfg.vis == "swanlab": + writer.log({mk: mv}, step=global_step) + model.controlnet.train() + model.traj_encoder.train() + return min_val_km, min_val_tj + + min_km, min_tj = validation() + + progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps") + while True: + for step, batch in enumerate(train_dataloader): + batch = move_batch_to_device(batch, device) + loss_dict = model.allsplit_step('train', batch) + + diff_loss = loss_dict['diff_loss'] + cond_loss = loss_dict['cond_loss'] + rot_loss = loss_dict['rot_loss'] + loss = loss_dict['loss'] + + loss.backward() + torch.nn.utils.clip_grad_norm_(params, cfg.TRAIN.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + progress_bar.update(1) + global_step += 1 + + if global_step % cfg.TRAIN.checkpointing_steps == 0: + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-{global_step}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % cfg.TRAIN.validation_steps == 0: + cur_km, cur_tj = validation() + if cur_km < min_km: + min_km = cur_km + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-{global_step}-km-{round(cur_km, 3)}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with km:{round(cur_km, 3)}") + + if cur_tj < min_tj: + min_tj = cur_tj + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-{global_step}-tj-{round(cur_tj, 3)}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with tj:{round(cur_tj, 3)}") + + logs = {"loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0], + "diff_loss": diff_loss.item(), 'cond_loss': cond_loss.item(), 'rot_loss': rot_loss.item()} + progress_bar.set_postfix(**logs) + for k, v in logs.items(): + if cfg.vis == "tb": + writer.add_scalar(f"Train/{k}", v, global_step=global_step) + elif cfg.vis == "swanlab": + writer.log({f"Train/{k}": v}, step=global_step) + + if global_step >= cfg.TRAIN.max_train_steps: + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-last.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + 
torch.save(ckpt, save_path) + exit(0) + + +if __name__ == "__main__": + main() diff --git a/train_motionlcm.py b/train_motionlcm.py new file mode 100644 index 0000000000000000000000000000000000000000..40f56645a802bf3674446d50c79b5b0fb8b3080d --- /dev/null +++ b/train_motionlcm.py @@ -0,0 +1,443 @@ +import os +import sys +import logging +import datetime +import os.path as osp +from typing import Generator + +import numpy as np +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +import swanlab +import diffusers +import transformers +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from diffusers.optimization import get_scheduler + +from mld.config import parse_args, instantiate_from_config +from mld.data.get_data import get_dataset +from mld.models.modeltype.mld import MLD +from mld.utils.utils import print_table, set_seed, move_batch_to_device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def scalings_for_boundary_conditions(timestep: torch.Tensor, sigma_data: float = 0.5, + timestep_scaling: float = 10.0) -> tuple: + c_skip = sigma_data ** 2 / ((timestep * timestep_scaling) ** 2 + sigma_data ** 2) + c_out = (timestep * timestep_scaling) / ((timestep * timestep_scaling) ** 2 + sigma_data ** 2) ** 0.5 + return c_skip, c_out + + +def predicted_origin( + model_output: torch.Tensor, + timesteps: torch.Tensor, + sample: torch.Tensor, + prediction_type: str, + alphas: torch.Tensor, + sigmas: torch.Tensor +) -> torch.Tensor: + if prediction_type == "epsilon": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "v_prediction": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError(f"Prediction type {prediction_type} currently not supported.") + + return pred_x_0 + + +def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class DDIMSolver: + def __init__(self, alpha_cumprods: np.ndarray, timesteps: int = 1000, ddim_timesteps: int = 50) -> None: + # DDIM sampling parameters + step_ratio = timesteps // ddim_timesteps + self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1 + self.ddim_alpha_cumprods = 
alpha_cumprods[self.ddim_timesteps] + self.ddim_alpha_cumprods_prev = np.asarray( + [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() + ) + # convert to torch tensors + self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() + self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) + self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) + + def to(self, device: torch.device) -> "DDIMSolver": + self.ddim_timesteps = self.ddim_timesteps.to(device) + self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) + self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) + return self + + def ddim_step(self, pred_x0: torch.Tensor, pred_noise: torch.Tensor, + timestep_index: torch.Tensor) -> torch.Tensor: + alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape) + dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise + x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt + return x_prev + + +@torch.no_grad() +def update_ema(target_params: Generator, source_params: Generator, rate: float = 0.99) -> None: + for tgt, src in zip(target_params, source_params): + tgt.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + cfg.output_dir = osp.join(cfg.FOLDER, name_time_str) + os.makedirs(cfg.output_dir, exist_ok=False) + os.makedirs(f"{cfg.output_dir}/checkpoints", exist_ok=False) + + if cfg.vis == "tb": + writer = SummaryWriter(cfg.output_dir) + elif cfg.vis == "swanlab": + writer = swanlab.init(project="MotionLCM", + experiment_name=os.path.normpath(cfg.output_dir).replace(os.path.sep, "-"), + suffix=None, config=dict(**cfg), logdir=cfg.output_dir) + else: + raise ValueError(f"Invalid vis method: {cfg.vis}") + + stream_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log')) + handlers = [file_handler, stream_handler] + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=handlers) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml')) + + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + logger.info(f'Training guidance scale range (w): [{cfg.TRAIN.w_min}, {cfg.TRAIN.w_max}]') + logger.info(f'EMA rate (mu): {cfg.TRAIN.ema_decay}') + logger.info(f'Skipping interval (k): {cfg.model.scheduler.params.num_train_timesteps / cfg.TRAIN.num_ddim_timesteps}') + logger.info(f'Loss type (huber or l2): {cfg.TRAIN.loss_type}') + + dataset = get_dataset(cfg) + train_dataloader = dataset.train_dataloader() + val_dataloader = dataset.val_dataloader() + + state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"] + base_model = MLD(cfg, dataset) + logger.info(f"Loading pretrained model: {cfg.TRAIN.PRETRAINED}") + logger.info(base_model.load_state_dict(state_dict)) + + scheduler = base_model.scheduler + alpha_schedule = torch.sqrt(scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - scheduler.alphas_cumprod) + solver = DDIMSolver( + scheduler.alphas_cumprod.numpy(), + timesteps=scheduler.config.num_train_timesteps, + ddim_timesteps=cfg.TRAIN.num_ddim_timesteps) + + 
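+ # DDIMSolver precomputes a coarse schedule of M = cfg.TRAIN.num_ddim_timesteps steps
+ # out of the N = num_train_timesteps training steps, t_i = (i + 1) * (N // M) - 1,
+ # together with alpha_bar at those steps. Its ddim_step() applies the deterministic
+ # (eta = 0) DDIM update used during distillation:
+ #   z_prev = sqrt(alpha_bar_prev) * x0_hat + sqrt(1 - alpha_bar_prev) * eps_hat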
base_model.to(device) + + vae = base_model.vae + text_encoder = base_model.text_encoder + teacher_unet = base_model.denoiser + base_model.denoiser = None + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + teacher_unet.requires_grad_(False) + + # Apply CFG here (Important!!!) + cfg.model.denoiser.params.time_cond_proj_dim = cfg.TRAIN.unet_time_cond_proj_dim + unet = instantiate_from_config(cfg.model.denoiser) + logger.info(f'Loading pretrained model for [unet]') + logger.info(unet.load_state_dict(teacher_unet.state_dict(), strict=False)) + target_unet = instantiate_from_config(cfg.model.denoiser) + logger.info(f'Loading pretrained model for [target_unet]') + logger.info(target_unet.load_state_dict(teacher_unet.state_dict(), strict=False)) + + unet = unet.to(device) + target_unet = target_unet.to(device) + target_unet.requires_grad_(False) + + # Also move the alpha and sigma noise schedules to device + alpha_schedule = alpha_schedule.to(device) + sigma_schedule = sigma_schedule.to(device) + solver = solver.to(device) + + optimizer = torch.optim.AdamW( + unet.parameters(), + lr=cfg.TRAIN.learning_rate, + betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2), + weight_decay=cfg.TRAIN.adam_weight_decay, + eps=cfg.TRAIN.adam_epsilon) + + if cfg.TRAIN.max_train_steps == -1: + assert cfg.TRAIN.max_train_epochs != -1 + cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader) + + if cfg.TRAIN.checkpointing_steps == -1: + assert cfg.TRAIN.checkpointing_epochs != -1 + cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader) + + if cfg.TRAIN.validation_steps == -1: + assert cfg.TRAIN.validation_epochs != -1 + cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader) + + lr_scheduler = get_scheduler( + cfg.TRAIN.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.TRAIN.lr_warmup_steps, + num_training_steps=cfg.TRAIN.max_train_steps) + + uncond_prompt_embeds = text_encoder([""] * cfg.TRAIN.BATCH_SIZE) + + # Train! 
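+ # Latent consistency distillation, as implemented in the loop below (summary):
+ #   1. encode the motion to a latent z_0 with the frozen VAE and add noise at a
+ #      solver timestep t_{n+k};
+ #   2. the online student predicts x0_hat there and forms the consistency output
+ #      c_skip * z_{t_{n+k}} + c_out * x0_hat;
+ #   3. the frozen teacher runs classifier-free guidance with a sampled scale w and
+ #      one DDIM step of the solver to estimate z_{t_n}, where the EMA target network
+ #      produces the target consistency output;
+ #   4. the student is trained with an L2 or pseudo-Huber loss between the two
+ #      outputs, and the target network is updated by EMA (rate cfg.TRAIN.ema_decay).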
+ logger.info("***** Running training *****") + logging.info(f" Num examples = {len(train_dataloader.dataset)}") + logging.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}") + logging.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}") + logging.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}") + + global_step = 0 + + @torch.no_grad() + def validation(ema: bool = False) -> tuple: + base_model.denoiser = target_unet if ema else unet + base_model.eval() + for val_batch in tqdm(val_dataloader): + val_batch = move_batch_to_device(val_batch, device) + base_model.allsplit_step(split='test', batch=val_batch) + metrics = base_model.allsplit_epoch_end() + max_val_rp1 = metrics['Metrics/R_precision_top_1'] + min_val_fid = metrics['Metrics/FID'] + print_table(f'Validation@Step-{global_step}', metrics) + for k, v in metrics.items(): + k = k + '_EMA' if ema else k + if cfg.vis == "tb": + writer.add_scalar(k, v, global_step=global_step) + elif cfg.vis == "swanlab": + writer.log({k: v}, step=global_step) + base_model.train() + base_model.denoiser = unet + return max_val_rp1, min_val_fid + + max_rp1, min_fid = validation() + # validation(ema=True) + + progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps") + while True: + for step, batch in enumerate(train_dataloader): + batch = move_batch_to_device(batch, device) + feats_ref = batch["motion"] + text = batch['text'] + mask = batch['mask'] + + # Encode motions to latents + with torch.no_grad(): + latents, _ = vae.encode(feats_ref, mask) + + prompt_embeds = text_encoder(text) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + topk = scheduler.config.num_train_timesteps // cfg.TRAIN.num_ddim_timesteps + index = torch.randint(0, cfg.TRAIN.num_ddim_timesteps, (bsz,), device=latents.device).long() + start_timesteps = solver.ddim_timesteps[index] + timesteps = start_timesteps - topk + timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) + + # Get boundary scalings for start_timesteps and (end) timesteps. 
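+ # With sigma_data = 0.5 and timestep_scaling s = 10 (the defaults above), these are
+ #   c_skip(t) = sigma_data^2 / ((s * t)^2 + sigma_data^2)
+ #   c_out(t)  = (s * t) / sqrt((s * t)^2 + sigma_data^2)
+ # so c_skip -> 1 and c_out -> 0 as t -> 0, enforcing the consistency-model boundary
+ # condition f(x, 0) ~ x.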
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] + c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noisy_model_input = scheduler.add_noise(latents, noise, start_timesteps) + + # Sample a random guidance scale w from U[w_min, w_max] and embed it + w = (cfg.TRAIN.w_max - cfg.TRAIN.w_min) * torch.rand((bsz,)) + cfg.TRAIN.w_min + w_embedding = guidance_scale_embedding(w, embedding_dim=cfg.TRAIN.unet_time_cond_proj_dim) + w = append_dims(w, latents.ndim) + # Move to U-Net device and dtype + w = w.to(device=latents.device, dtype=latents.dtype) + w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) + + # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + noise_pred = unet( + noisy_model_input, + start_timesteps, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds)[0] + + pred_x_0 = predicted_origin( + noise_pred, + start_timesteps, + noisy_model_input, + scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + + model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 + + # Use the ODE solver to predict the k-th step in the augmented PF-ODE trajectory after + # noisy_latents with both the conditioning embedding c and unconditional embedding 0 + # Get teacher model prediction on noisy_latents and conditional embedding + with torch.no_grad(): + cond_teacher_output = teacher_unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=prompt_embeds)[0] + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_teacher_output = teacher_unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds[:bsz])[0] + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + + # Perform "CFG" to get z_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) + + # Get target LCM prediction on z_prev, w, c, t_n + with torch.no_grad(): + target_noise_pred = target_unet( + x_prev.float(), + timesteps, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds)[0] + pred_x_0 = predicted_origin( + target_noise_pred, + timesteps, + x_prev, + scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + target = c_skip * x_prev + c_out * pred_x_0 + + # Calculate loss + if cfg.TRAIN.loss_type == "l2": + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + elif cfg.TRAIN.loss_type == "huber": + loss = torch.mean( + torch.sqrt( + (model_pred.float() - target.float()) ** 2 + cfg.TRAIN.huber_c ** 2) - cfg.TRAIN.huber_c + ) + else: + raise ValueError(f'Unknown loss type: {cfg.TRAIN.loss_type}.') + + # Back propagate on the online student model (`unet`) + loss.backward() + 
torch.nn.utils.clip_grad_norm_(unet.parameters(), cfg.TRAIN.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Make EMA update to target student model parameters + update_ema(target_unet.parameters(), unet.parameters(), cfg.TRAIN.ema_decay) + progress_bar.update(1) + global_step += 1 + + if global_step % cfg.TRAIN.checkpointing_steps == 0: + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-{global_step}.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % cfg.TRAIN.validation_steps == 0: + cur_rp1, cur_fid = validation() + # validation(ema=True) + if cur_rp1 > max_rp1: + max_rp1 = cur_rp1 + save_path = os.path.join(cfg.output_dir, 'checkpoints', + f"checkpoint-{global_step}-rp1-{round(cur_rp1, 3)}.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with rp1:{round(cur_rp1, 3)}") + + if cur_fid < min_fid: + min_fid = cur_fid + save_path = os.path.join(cfg.output_dir, 'checkpoints', + f"checkpoint-{global_step}-fid-{round(cur_fid, 3)}.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with fid:{round(cur_fid, 3)}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + if cfg.vis == "tb": + writer.add_scalar('loss', logs['loss'], global_step=global_step) + writer.add_scalar('lr', logs['lr'], global_step=global_step) + elif cfg.vis == "swanlab": + writer.log({'loss': logs['loss'], 'lr': logs['lr']}, step=global_step) + + if global_step >= cfg.TRAIN.max_train_steps: + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-last.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + exit(0) + + +if __name__ == "__main__": + main() diff --git a/train_vae.py b/train_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..0b9c557a58af9b3cfd4889ec1a73947e8a7d11ac --- /dev/null +++ b/train_vae.py @@ -0,0 +1,210 @@ +import os +import sys +import logging +import datetime +import os.path as osp + +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +import swanlab +import diffusers +import transformers +from torch.utils.tensorboard import SummaryWriter +from diffusers.optimization import get_scheduler + +from mld.config import parse_args +from mld.data.get_data import get_dataset +from mld.models.modeltype.vae import VAE +from mld.utils.utils import print_table, set_seed, move_batch_to_device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + cfg.output_dir = osp.join(cfg.FOLDER, name_time_str) + os.makedirs(cfg.output_dir, exist_ok=False) + os.makedirs(f"{cfg.output_dir}/checkpoints", exist_ok=False) + + if cfg.vis == "tb": + writer = SummaryWriter(cfg.output_dir) + elif cfg.vis == "swanlab": + writer = swanlab.init(project="MotionLCM", + experiment_name=os.path.normpath(cfg.output_dir).replace(os.path.sep, "-"), + suffix=None, 
config=dict(**cfg), logdir=cfg.output_dir) + else: + raise ValueError(f"Invalid vis method: {cfg.vis}") + + stream_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log')) + handlers = [file_handler, stream_handler] + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=handlers) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml')) + + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + dataset = get_dataset(cfg, motion_only=cfg.TRAIN.get('MOTION_ONLY', False)) + train_dataloader = dataset.train_dataloader() + val_dataloader = dataset.val_dataloader() + dataset = get_dataset(cfg, motion_only=False) + test_dataloader = dataset.test_dataloader() + + model = VAE(cfg, dataset) + model.to(device) + + if cfg.TRAIN.PRETRAINED: + logger.info(f"Loading pre-trained model: {cfg.TRAIN.PRETRAINED}") + state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"] + logger.info(model.load_state_dict(state_dict)) + + logger.info("learning_rate: {}".format(cfg.TRAIN.learning_rate)) + optimizer = torch.optim.AdamW( + model.vae.parameters(), + lr=cfg.TRAIN.learning_rate, + betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2), + weight_decay=cfg.TRAIN.adam_weight_decay, + eps=cfg.TRAIN.adam_epsilon) + + if cfg.TRAIN.max_train_steps == -1: + assert cfg.TRAIN.max_train_epochs != -1 + cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader) + + if cfg.TRAIN.checkpointing_steps == -1: + assert cfg.TRAIN.checkpointing_epochs != -1 + cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader) + + if cfg.TRAIN.validation_steps == -1: + assert cfg.TRAIN.validation_epochs != -1 + cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader) + + lr_scheduler = get_scheduler( + cfg.TRAIN.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.TRAIN.lr_warmup_steps, + num_training_steps=cfg.TRAIN.max_train_steps) + + # Train! 
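# Editorial sketch (assumed, not the repository's actual code): move_batch_to_device,
# imported from mld.utils.utils above and used in every loop below, is assumed to be a
# small utility along these lines; the real helper may also handle nested containers.
import torch


def move_batch_to_device(batch: dict, device: torch.device) -> dict:
    # Move tensor entries onto the target device; leave non-tensor entries (e.g. text) as-is.
    return {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}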
+ logger.info("***** Running training *****") + logging.info(f" Num examples = {len(train_dataloader.dataset)}") + logging.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}") + logging.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}") + logging.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}") + + global_step = 0 + + @torch.no_grad() + def validation(): + model.vae.eval() + + val_loss_list = [] + for val_batch in tqdm(val_dataloader): + val_batch = move_batch_to_device(val_batch, device) + val_loss_dict = model.allsplit_step(split='val', batch=val_batch) + val_loss_list.append(val_loss_dict) + + for val_batch in tqdm(test_dataloader): + val_batch = move_batch_to_device(val_batch, device) + model.allsplit_step(split='test', batch=val_batch) + metrics = model.allsplit_epoch_end() + + for loss_k in val_loss_list[0].keys(): + metrics[f"Val/{loss_k}"] = sum([d[loss_k] for d in val_loss_list]).item() / len(val_dataloader) + + max_val_mpjpe = metrics['Metrics/MPJPE'] + min_val_fid = metrics['Metrics/FID'] + print_table(f'Validation@Step-{global_step}', metrics) + for mk, mv in metrics.items(): + if cfg.vis == "tb": + writer.add_scalar(mk, mv, global_step=global_step) + elif cfg.vis == "swanlab": + writer.log({mk: mv}, step=global_step) + + model.vae.train() + return max_val_mpjpe, min_val_fid + + min_mpjpe, min_fid = validation() + + progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps") + while True: + for step, batch in enumerate(train_dataloader): + batch = move_batch_to_device(batch, device) + loss_dict = model.allsplit_step('train', batch) + + rec_feats_loss = loss_dict['rec_feats_loss'] + rec_joints_loss = loss_dict['rec_joints_loss'] + rec_velocity_loss = loss_dict['rec_velocity_loss'] + kl_loss = loss_dict['kl_loss'] + loss = loss_dict['loss'] + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.vae.parameters(), cfg.TRAIN.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + progress_bar.update(1) + global_step += 1 + + if global_step % cfg.TRAIN.checkpointing_steps == 0: + save_path = os.path.join(cfg.output_dir, 'checkpoints', f"checkpoint-{global_step}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % cfg.TRAIN.validation_steps == 0: + cur_mpjpe, cur_fid = validation() + if cur_mpjpe < min_mpjpe: + min_mpjpe = cur_mpjpe + save_path = os.path.join(cfg.output_dir, 'checkpoints', + f"checkpoint-{global_step}-mpjpe-{round(cur_mpjpe, 5)}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with mpjpe: {round(cur_mpjpe, 5)}") + + if cur_fid < min_fid: + min_fid = cur_fid + save_path = os.path.join(cfg.output_dir, 'checkpoints', + f"checkpoint-{global_step}-fid-{round(cur_fid, 3)}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with fid: {round(cur_fid, 3)}") + + logs = {"loss": loss.item(), + "lr": lr_scheduler.get_last_lr()[0], + "rec_feats_loss": rec_feats_loss.item(), + 'rec_joints_loss': rec_joints_loss.item(), + 'rec_velocity_loss': rec_velocity_loss.item(), + 'kl_loss': kl_loss.item()} + + progress_bar.set_postfix(**logs) + for k, v in logs.items(): + if cfg.vis == "tb": + writer.add_scalar(f"Train/{k}", v, global_step=global_step) + elif 
cfg.vis == "swanlab": + writer.log({f"Train/{k}": v}, step=global_step) + + if global_step >= cfg.TRAIN.max_train_steps: + save_path = os.path.join(cfg.output_dir, 'checkpoints', "checkpoint-last.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + exit(0) + + +if __name__ == "__main__": + main()