wxDai committed
Commit • c64dfa4 • 0 Parent(s)
[Init]
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +12 -0
- LICENSE +25 -0
- README.md +11 -0
- app.py +234 -0
- configs/mld_t2m.yaml +104 -0
- configs/modules/denoiser.yaml +28 -0
- configs/modules/motion_vae.yaml +18 -0
- configs/modules/noise_optimizer.yaml +15 -0
- configs/modules/scheduler_ddim.yaml +14 -0
- configs/modules/scheduler_lcm.yaml +19 -0
- configs/modules/text_encoder.yaml +5 -0
- configs/modules/traj_encoder.yaml +17 -0
- configs/motionlcm_control_s.yaml +113 -0
- configs/motionlcm_control_t.yaml +111 -0
- configs/motionlcm_t2m.yaml +109 -0
- configs/motionlcm_t2m_clt.yaml +69 -0
- configs/vae.yaml +103 -0
- configs_v1/modules/denoiser.yaml +28 -0
- configs_v1/modules/motion_vae.yaml +18 -0
- configs_v1/modules/scheduler_lcm.yaml +11 -0
- configs_v1/modules/text_encoder.yaml +5 -0
- configs_v1/modules/traj_encoder.yaml +17 -0
- configs_v1/motionlcm_control_t.yaml +114 -0
- configs_v1/motionlcm_t2m.yaml +109 -0
- demo.py +196 -0
- fit.py +136 -0
- mld/__init__.py +0 -0
- mld/config.py +52 -0
- mld/data/__init__.py +0 -0
- mld/data/base.py +58 -0
- mld/data/data.py +73 -0
- mld/data/get_data.py +79 -0
- mld/data/humanml/__init__.py +0 -0
- mld/data/humanml/common/quaternion.py +29 -0
- mld/data/humanml/dataset.py +348 -0
- mld/data/humanml/scripts/motion_process.py +51 -0
- mld/data/humanml/utils/__init__.py +0 -0
- mld/data/humanml/utils/paramUtil.py +62 -0
- mld/data/humanml/utils/plot_script.py +98 -0
- mld/data/humanml/utils/word_vectorizer.py +82 -0
- mld/data/utils.py +52 -0
- mld/launch/__init__.py +0 -0
- mld/launch/blender.py +23 -0
- mld/models/__init__.py +0 -0
- mld/models/architectures/__init__.py +0 -0
- mld/models/architectures/dno.py +79 -0
- mld/models/architectures/mld_clip.py +72 -0
- mld/models/architectures/mld_denoiser.py +200 -0
- mld/models/architectures/mld_traj_encoder.py +64 -0
- mld/models/architectures/mld_vae.py +136 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
+**/*.pyc
+.idea/
+__pycache__/
+
+deps/
+datasets/
+experiments_t2m/
+experiments_t2m_test/
+experiments_control/
+experiments_control_test/
+experiments_recons/
+experiments_recons_test/
LICENSE
ADDED
@@ -0,0 +1,25 @@
+Copyright Tsinghua University and Shanghai AI Laboratory. All Rights Reserved.
+
+License for Non-commercial Scientific Research Purposes.
+
+For more information see <https://github.com/Dai-Wenxun/MotionLCM>.
+If you use this software, please cite the corresponding publications
+listed on the above website.
+
+Permission to use, copy, modify, and distribute this software and its
+documentation for educational, research, and non-profit purposes only.
+Any modification based on this work must be open-source and prohibited
+for commercial, pornographic, military, or surveillance use.
+
+The authors grant you a non-exclusive, worldwide, non-transferable,
+non-sublicensable, revocable, royalty-free, and limited license under
+our copyright interests to reproduce, distribute, and create derivative
+works of the text, videos, and codes solely for your non-commercial
+research purposes.
+
+You must retain, in the source form of any derivative works that you
+distribute, all copyright, patent, trademark, and attribution notices
+from the source form of this work.
+
+For commercial uses of this software, please send email to all people
+in the author list.
README.md
ADDED
@@ -0,0 +1,11 @@
+---
+title: MotionLCM
+emoji: 🏎️💨
+colorFrom: yellow
+colorTo: pink
+sdk: gradio
+sdk_version: 4.44.1
+app_file: app.py
+pinned: false
+python_version: 3.10.12
+---
app.py
ADDED
@@ -0,0 +1,234 @@
+import os
+import time
+import random
+import datetime
+import os.path as osp
+from functools import partial
+
+import tqdm
+from omegaconf import OmegaConf
+
+import torch
+import gradio as gr
+
+from mld.config import get_module_config
+from mld.data.get_data import get_dataset
+from mld.models.modeltype.mld import MLD
+from mld.utils.utils import set_seed
+from mld.data.humanml.utils.plot_script import plot_3d_motion
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+WEBSITE = """
+<div class="embed_hidden">
+<h1 style='text-align: center'> MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model </h1>
+<h2 style='text-align: center'>
+<a href="https://github.com/Dai-Wenxun/" target="_blank"><nobr>Wenxun Dai</nobr><sup>1</sup></a>
+<a href="https://lhchen.top/" target="_blank"><nobr>Ling-Hao Chen</nobr></a><sup>1</sup>
+<a href="https://wangjingbo1219.github.io/" target="_blank"><nobr>Jingbo Wang</nobr></a><sup>2</sup>
+<a href="https://moonsliu.github.io/" target="_blank"><nobr>Jinpeng Liu</nobr></a><sup>1</sup>
+<a href="https://daibo.info/" target="_blank"><nobr>Bo Dai</nobr></a><sup>2</sup>
+<a href="https://andytang15.github.io/" target="_blank"><nobr>Yansong Tang</nobr></a><sup>1</sup>
+</h2>
+<h2 style='text-align: center'>
+<nobr><sup>1</sup>Tsinghua University</nobr>
+<nobr><sup>2</sup>Shanghai AI Laboratory</nobr>
+</h2>
+</div>
+"""
+
+WEBSITE_bottom = """
+<div class="embed_hidden">
+<p>
+Space adapted from <a href="https://huggingface.co/spaces/Mathux/TMR" target="_blank">TMR</a>
+and <a href="https://huggingface.co/spaces/MeYourHint/MoMask" target="_blank">MoMask</a>.
+</p>
+</div>
+"""
+
+EXAMPLES = [
+    "a person does a jump",
+    "a person waves both arms in the air.",
+    "The person takes 4 steps backwards.",
+    "this person bends forward as if to bow.",
+    "The person was pushed but did not fall.",
+    "a man walks forward in a snake like pattern.",
+    "a man paces back and forth along the same line.",
+    "with arms out to the sides a person walks forward",
+    "A man bends down and picks something up with his right hand.",
+    "The man walked forward, spun right on one foot and walked back to his original position.",
+    "a person slightly bent over with right hand pressing against the air walks forward slowly"
+]
+
+if not os.path.exists("./experiments_t2m/"):
+    os.system("bash prepare/download_pretrained_models.sh")
+if not os.path.exists('./deps/glove/'):
+    os.system("bash prepare/download_glove.sh")
+if not os.path.exists('./deps/sentence-t5-large/'):
+    os.system("bash prepare/prepare_t5.sh")
+if not os.path.exists('./deps/t2m/'):
+    os.system("bash prepare/download_t2m_evaluators.sh")
+if not os.path.exists('./datasets/humanml3d/'):
+    os.system("bash prepare/prepare_tiny_humanml3d.sh")
+
+DEFAULT_TEXT = "A person is "
+MAX_VIDEOS = 8
+NUM_ROWS = 2
+NUM_COLS = MAX_VIDEOS // NUM_ROWS
+EXAMPLES_PER_PAGE = 12
+T2M_CFG = "./configs_v1/motionlcm_t2m.yaml"
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+print("device: ", device)
+
+cfg = OmegaConf.load(T2M_CFG)
+cfg_root = os.path.dirname(T2M_CFG)
+cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
+cfg = OmegaConf.merge(cfg, cfg_model)
+set_seed(cfg.SEED_VALUE)
+
+name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
+cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
+vis_dir = osp.join(cfg.output_dir, 'samples')
+os.makedirs(cfg.output_dir, exist_ok=False)
+os.makedirs(vis_dir, exist_ok=False)
+
+state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
+print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
+
+is_lcm = False
+lcm_key = 'denoiser.time_embedding.cond_proj.weight'  # unique key for CFG
+if lcm_key in state_dict:
+    is_lcm = True
+    time_cond_proj_dim = state_dict[lcm_key].shape[1]
+    cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
+print(f'Is LCM: {is_lcm}')
+
+dataset = get_dataset(cfg)
+model = MLD(cfg, dataset)
+model.to(device)
+model.eval()
+model.requires_grad_(False)
+model.load_state_dict(state_dict)
+
+FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE")
+
+
+@torch.no_grad()
+def generate(text_, motion_len_):
+    batch = {"text": [text_] * MAX_VIDEOS, "length": [motion_len_] * MAX_VIDEOS}
+
+    s = time.time()
+    joints = model(batch)[0]
+    runtime_infer = round(time.time() - s, 3)
+
+    s = time.time()
+    path = []
+    for i in tqdm.tqdm(range(len(joints))):
+        uid = random.randrange(999999999)
+        video_path = osp.join(vis_dir, f"sample_{uid}.mp4")
+        plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=FPS)
+        path.append(video_path)
+    runtime_draw = round(time.time() - s, 3)
+
+    runtime_info = f'Inference {len(joints)} motions, Runtime (Inference): {runtime_infer}s, ' \
+                   f'Runtime (Draw Skeleton): {runtime_draw}s, device: {device} '
+
+    return path, runtime_info
+
+
+def generate_component(generate_function, text_, motion_len_, num_inference_steps_, guidance_scale_):
+    if text_ == DEFAULT_TEXT or text_ == "" or text_ is None:
+        return [None] * MAX_VIDEOS + ["Please modify the default text prompt."]
+
+    model.cfg.model.scheduler.num_inference_steps = num_inference_steps_
+    model.guidance_scale = guidance_scale_
+    motion_len_ = max(36, min(int(float(motion_len_) * FPS), 196))
+    paths, info = generate_function(text_, motion_len_)
+    paths = paths + [None] * (MAX_VIDEOS - len(paths))
+    return paths + [info]
+
+theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray")
+generate_and_show = partial(generate_component, generate)
+
+with gr.Blocks(theme=theme) as demo:
+    gr.HTML(WEBSITE)
+    videos = []
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            text = gr.Textbox(
+                show_label=True,
+                label="Text prompt",
+                value=DEFAULT_TEXT,
+            )
+
+            with gr.Row():
+                with gr.Column(scale=1):
+                    motion_len = gr.Slider(
+                        minimum=1.8,
+                        maximum=9.8,
+                        step=0.2,
+                        value=5.0,
+                        label="Motion length",
+                        info="Motion duration in seconds: [1.8s, 9.8s] (FPS = 20)."
+                    )
+
+                with gr.Column(scale=1):
+                    num_inference_steps = gr.Slider(
+                        minimum=1,
+                        maximum=4,
+                        step=1,
+                        value=1,
+                        label="Inference steps",
+                        info="Number of inference steps.",
+                    )
+
+            cfg = gr.Slider(
+                minimum=1,
+                maximum=15,
+                step=0.5,
+                value=7.5,
+                label="CFG",
+                info="Classifier-free diffusion guidance.",
+            )
+
+            gen_btn = gr.Button("Generate", variant="primary")
+            clear = gr.Button("Clear", variant="secondary")
+
+            results = gr.Textbox(show_label=True,
+                                 label='Inference info (runtime and device)',
+                                 info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.',
+                                 interactive=False)
+
+        with gr.Column(scale=2):
+            examples = gr.Examples(
+                examples=EXAMPLES,
+                inputs=[text],
+                examples_per_page=EXAMPLES_PER_PAGE)
+
+    for i in range(NUM_ROWS):
+        with gr.Row():
+            for j in range(NUM_COLS):
+                video = gr.Video(autoplay=True, loop=True)
+                videos.append(video)
+
+    # gr.HTML(WEBSITE_bottom)
+
+    gen_btn.click(
+        fn=generate_and_show,
+        inputs=[text, motion_len, num_inference_steps, cfg],
+        outputs=videos+[results],
+    )
+    text.submit(
+        fn=generate_and_show,
+        inputs=[text, motion_len, num_inference_steps, cfg],
+        outputs=videos+[results],
+    )
+
+    def clear_videos():
+        return [None] * MAX_VIDEOS + [DEFAULT_TEXT] + [None]
+
+    clear.click(fn=clear_videos, outputs=videos + [text] + [results])
+
+demo.launch()
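Note on the sliders above: the motion length is entered in seconds, while the model consumes frame counts, so generate_component converts and clamps the value with max(36, min(int(float(motion_len_) * FPS), 196)). A small standalone worked example of that clamping (not part of this commit), assuming the HumanML3D frame rate of 20 FPS used throughout the configs:

# Worked example of the seconds-to-frames clamping in generate_component.
FPS = 20  # DATASET.HUMANML3D.FRAME_RATE in the configs below
for seconds in (1.8, 5.0, 9.8):
    frames = max(36, min(int(float(seconds) * FPS), 196))
    print(f"{seconds}s -> {frames} frames")  # 1.8s -> 36, 5.0s -> 100, 9.8s -> 196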
configs/mld_t2m.yaml
ADDED
@@ -0,0 +1,104 @@
+FOLDER: './experiments_t2m'
+TEST_FOLDER: './experiments_t2m_test'
+
+NAME: 'mld_humanml'
+
+SEED_VALUE: 1234
+
+TRAIN:
+  BATCH_SIZE: 64
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+
+  PRETRAINED: 'experiments_recons/vae_humanml/vae_humanml.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 3000
+  learning_rate: 1e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+  model_ema: false
+  model_ema_steps: 32
+  model_ema_decay: 0.999
+
+VAL:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+TEST:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: true
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: false
+      TEMPORAL: false
+      TRAIN_JOINTS: [0]
+      TEST_JOINTS: [0]
+      TRAIN_DENSITY: 'random'
+      TEST_DENSITY: 100
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: false
+  WINDOW_SIZE: null
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics']
+
+model:
+  target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_ddim', 'noise_optimizer']
+  latent_dim: [16, 32]
+  guidance_scale: 7.5
+  guidance_uncondp: 0.1
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
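The model.target list names the module configs that get merged into this file at load time. A sketch of that composition, mirroring the loading code in app.py above; the exact lookup inside get_module_config is an assumption, since mld/config.py is not shown in this 50-file view:

import os
from omegaconf import OmegaConf
from mld.config import get_module_config

cfg_path = "configs/mld_t2m.yaml"
cfg = OmegaConf.load(cfg_path)
# model.target lists 'motion_vae', 'text_encoder', 'denoiser', 'scheduler_ddim',
# 'noise_optimizer'; get_module_config presumably resolves each name to a file
# under configs/modules/ relative to cfg_root and returns a mergeable config.
cfg_model = get_module_config(cfg.model, cfg.model.target, os.path.dirname(cfg_path))
cfg = OmegaConf.merge(cfg, cfg_model)
print(cfg.model.denoiser.target)  # expected: mld.models.architectures.mld_denoiser.MldDenoiser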
configs/modules/denoiser.yaml
ADDED
@@ -0,0 +1,28 @@
+denoiser:
+  target: mld.models.architectures.mld_denoiser.MldDenoiser
+  params:
+    latent_dim: ${model.latent_dim}
+    hidden_dim: 256
+    text_dim: 768
+    time_dim: 768
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    norm_eps: 1e-5
+    activation: 'gelu'
+    norm_post: true
+    activation_post: null
+    flip_sin_to_cos: true
+    freq_shift: 0
+    time_act_fn: 'silu'
+    time_post_act_fn: null
+    position_embedding: 'learned'
+    arch: 'trans_enc'
+    add_mem_pos: true
+    force_pre_post_proj: true
+    text_act_fn: null
+    zero_init_cond: true
+    controlnet_embed_dim: 256
+    controlnet_act_fn: 'silu'
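Each module file follows the same target (import path) plus params (constructor kwargs) layout. Below is a hedged sketch of the instantiation pattern that layout implies; the repo's own helper lives in mld/config.py, which is outside this view, and the ${...} interpolations only resolve after the file is merged into a full experiment config:

import importlib
from omegaconf import OmegaConf

node = OmegaConf.load("configs/modules/denoiser.yaml").denoiser
module_path, cls_name = node.target.rsplit(".", 1)
denoiser_cls = getattr(importlib.import_module(module_path), cls_name)
# denoiser = denoiser_cls(**node.params)  # only valid once ${model.latent_dim} etc. are resolvable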
configs/modules/motion_vae.yaml
ADDED
@@ -0,0 +1,18 @@
+motion_vae:
+  target: mld.models.architectures.mld_vae.MldVae
+  params:
+    nfeats: ${DATASET.NFEATS}
+    latent_dim: ${model.latent_dim}
+    hidden_dim: 256
+    force_pre_post_proj: true
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    arch: 'encoder_decoder'
+    normalize_before: false
+    norm_eps: 1e-5
+    activation: 'gelu'
+    norm_post: true
+    activation_post: null
+    position_embedding: 'learned'
configs/modules/noise_optimizer.yaml
ADDED
@@ -0,0 +1,15 @@
+noise_optimizer:
+  target: mld.models.architectures.dno.DNO
+  params:
+    optimize: false
+    max_train_steps: 400
+    learning_rate: 0.1
+    lr_scheduler: 'cosine'
+    lr_warmup_steps: 50
+    clip_grad: true
+    loss_hint_type: 'l2'
+    loss_diff_penalty: 0.000
+    loss_correlate_penalty: 100
+    visualize_samples: 0
+    visualize_ske_steps: []
+    output_dir: ${output_dir}
configs/modules/scheduler_ddim.yaml
ADDED
@@ -0,0 +1,14 @@
+scheduler:
+  target: diffusers.DDIMScheduler
+  num_inference_steps: 50
+  eta: 0.0
+  params:
+    num_train_timesteps: 1000
+    beta_start: 0.00085
+    beta_end: 0.012
+    beta_schedule: 'scaled_linear'
+    prediction_type: 'epsilon'
+    clip_sample: false
+    # below are for ddim
+    set_alpha_to_one: false
+    steps_offset: 1
configs/modules/scheduler_lcm.yaml
ADDED
@@ -0,0 +1,19 @@
+scheduler:
+  target: mld.models.schedulers.scheduling_lcm.LCMScheduler
+  num_inference_steps: 1
+  cfg_step_map:
+    1: 8.0
+    2: 12.5
+    4: 13.5
+  params:
+    num_train_timesteps: 1000
+    beta_start: 0.00085
+    beta_end: 0.012
+    beta_schedule: 'scaled_linear'
+    clip_sample: false
+    set_alpha_to_one: false
+    original_inference_steps: 10
+    timesteps_step_map:
+      1: [799]
+      2: [699, 299]
+      4: [699, 399, 299, 299]
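cfg_step_map and timesteps_step_map key per-step settings by the number of inference steps (1, 2, or 4), matching the step slider in app.py. An illustrative read of those maps with plain YAML (the actual consumer is the custom mld.models.schedulers.scheduling_lcm.LCMScheduler, which is not part of this view):

import yaml

with open("configs/modules/scheduler_lcm.yaml") as f:
    sched = yaml.safe_load(f)["scheduler"]

steps = 2
print(sched["cfg_step_map"][steps])                   # 12.5
print(sched["params"]["timesteps_step_map"][steps])   # [699, 299]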
configs/modules/text_encoder.yaml
ADDED
@@ -0,0 +1,5 @@
+text_encoder:
+  target: mld.models.architectures.mld_clip.MldTextEncoder
+  params:
+    last_hidden_state: false
+    modelpath: ${model.t5_path}
configs/modules/traj_encoder.yaml
ADDED
@@ -0,0 +1,17 @@
+traj_encoder:
+  target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
+  params:
+    nfeats: ${DATASET.NJOINTS}
+    latent_dim: ${model.latent_dim}
+    hidden_dim: 256
+    force_post_proj: true
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    norm_eps: 1e-5
+    activation: 'gelu'
+    norm_post: true
+    activation_post: null
+    position_embedding: 'learned'
configs/motionlcm_control_s.yaml
ADDED
@@ -0,0 +1,113 @@
+FOLDER: './experiments_control/spatial'
+TEST_FOLDER: './experiments_control_test/spatial'
+
+NAME: 'motionlcm_humanml'
+
+SEED_VALUE: 1234
+
+TRAIN:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 128
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+
+  PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 1e-4
+  learning_rate_spatial: 1e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+VAL:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+TEST:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_pelvis.ckpt'
+  # CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_all.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 1
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: false
+  MAX_NUM_SAMPLES: 1024
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: true
+      TEMPORAL: false
+      TRAIN_JOINTS: [0]
+      TEST_JOINTS: [0]
+      TRAIN_DENSITY: 'random'
+      TEST_DENSITY: 100
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: false
+  WINDOW_SIZE: null
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics', 'ControlMetrics']
+
+model:
+  target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer']
+  latent_dim: [16, 32]
+  guidance_scale: 'dynamic'
+
+  # ControlNet Args
+  is_controlnet: true
+  vaeloss: true
+  vaeloss_type: 'mask'
+  cond_ratio: 1.0
+  control_loss_func: 'l1_smooth'
+  use_3d: true
+  lcm_w_min_nax: [5, 15]
+  lcm_num_ddim_timesteps: 10
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
configs/motionlcm_control_t.yaml
ADDED
@@ -0,0 +1,111 @@
+FOLDER: './experiments_control/temporal'
+TEST_FOLDER: './experiments_control_test/temporal'
+
+NAME: 'motionlcm_humanml'
+
+SEED_VALUE: 1234
+
+TRAIN:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 128
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+
+  PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 1e-4
+  learning_rate_spatial: 1e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+VAL:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+TEST:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: false
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: true
+      TEMPORAL: true
+      TRAIN_JOINTS: [0, 10, 11, 15, 20, 21]
+      TEST_JOINTS: [0, 10, 11, 15, 20, 21]
+      TRAIN_DENSITY: [25, 25]
+      TEST_DENSITY: 25
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: false
+  WINDOW_SIZE: null
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics', 'ControlMetrics']
+
+model:
+  target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer']
+  latent_dim: [16, 32]
+  guidance_scale: 'dynamic'
+
+  # ControlNet Args
+  is_controlnet: true
+  vaeloss: true
+  vaeloss_type: 'sum'
+  cond_ratio: 1.0
+  control_loss_func: 'l2'
+  use_3d: false
+  lcm_w_min_nax: [5, 15]
+  lcm_num_ddim_timesteps: 10
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
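The joint indices in TRAIN_JOINTS/TEST_JOINTS follow the standard 22-joint SMPL skeleton used by HumanML3D; the mapping below is stated only for readability and is an assumption not spelled out in these files (joint 0 being the pelvis is at least consistent with the "_s_pelvis" checkpoint name in the spatial config above):

# Assumed SMPL joint names for the control indices used above.
CONTROL_JOINT_NAMES = {
    0: "pelvis",
    10: "left_foot",
    11: "right_foot",
    15: "head",
    20: "left_wrist",
    21: "right_wrist",
}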
configs/motionlcm_t2m.yaml
ADDED
@@ -0,0 +1,109 @@
+FOLDER: './experiments_t2m'
+TEST_FOLDER: './experiments_t2m_test'
+
+NAME: 'motionlcm_humanml'
+
+SEED_VALUE: 1234
+
+TRAIN:
+  BATCH_SIZE: 128
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+
+  PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 2e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+  # Latent Consistency Distillation Specific Arguments
+  w_min: 5.0
+  w_max: 15.0
+  num_ddim_timesteps: 10
+  loss_type: 'huber'
+  huber_c: 0.5
+  unet_time_cond_proj_dim: 256
+  ema_decay: 0.95
+
+VAL:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+TEST:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: true
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: false
+      TEMPORAL: false
+      TRAIN_JOINTS: [0]
+      TEST_JOINTS: [0]
+      TRAIN_DENSITY: 'random'
+      TEST_DENSITY: 100
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: false
+  WINDOW_SIZE: null
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics']
+
+model:
+  target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer']
+  latent_dim: [16, 32]
+  guidance_scale: 'dynamic'
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
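w_min/w_max bound the guidance weight used during latent consistency distillation. The training script itself is not included in this view; as a hedged sketch of the standard recipe these arguments suggest, a per-sample weight is drawn uniformly from that range and used for the teacher's classifier-free guidance:

import torch

w_min, w_max = 5.0, 15.0            # values from the TRAIN block above
batch_size = 128                    # TRAIN.BATCH_SIZE above
w = w_min + (w_max - w_min) * torch.rand(batch_size)  # w ~ U[w_min, w_max], one per sample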
configs/motionlcm_t2m_clt.yaml
ADDED
@@ -0,0 +1,69 @@
+FOLDER: './experiments_t2m'
+TEST_FOLDER: './experiments_t2m_test'
+
+NAME: 'motionlcm_humanml'
+
+SEED_VALUE: 1234
+
+TEST:
+  BATCH_SIZE: 1
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 1
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: false
+  MAX_NUM_SAMPLES: 1024
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: true
+      TEMPORAL: false
+      TRAIN_JOINTS: [0]
+      TEST_JOINTS: [0]
+      TRAIN_DENSITY: 'random'
+      TEST_DENSITY: 100
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: false
+  WINDOW_SIZE: null
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics', 'ControlMetrics']
+
+model:
+  target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer']
+  latent_dim: [16, 32]
+  guidance_scale: 'dynamic'
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
configs/vae.yaml
ADDED
@@ -0,0 +1,103 @@
+FOLDER: './experiments_recons'
+TEST_FOLDER: './experiments_recons_test'
+
+NAME: 'vae_humanml'
+
+SEED_VALUE: 1234
+
+TRAIN:
+  BATCH_SIZE: 128
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+  PRETRAINED: ''
+
+  validation_steps: -1
+  validation_epochs: 100
+  checkpointing_steps: -1
+  checkpointing_epochs: 100
+  max_train_steps: -1
+  max_train_epochs: 6000
+  learning_rate: 2e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+VAL:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+TEST:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_recons/vae_humanml/vae_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: false
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: false
+      TEMPORAL: false
+      TRAIN_JOINTS: [0]
+      TEST_JOINTS: [0]
+      TRAIN_DENSITY: 'random'
+      TEST_DESITY: 100
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: true
+  WINDOW_SIZE: 64
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics', "PosMetrics"]
+
+model:
+  target: ['motion_vae']
+  latent_dim: [16, 32]
+
+  # VAE Args
+  rec_feats_ratio: 1.0
+  rec_joints_ratio: 1.0
+  rec_velocity_ratio: 0.0
+  kl_ratio: 1e-4
+
+  rec_feats_loss: 'l1_smooth'
+  rec_joints_loss: 'l1_smooth'
+  rec_velocity_loss: 'l1_smooth'
+  mask_loss: true
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  t2m_path: './deps/t2m/'
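The rec_*_ratio and kl_ratio entries weight the VAE training objective. The loss code itself is not part of this 50-file view; a hedged sketch of how such ratios are typically combined:

def vae_objective(rec_feats, rec_joints, rec_velocity, kl,
                  rec_feats_ratio=1.0, rec_joints_ratio=1.0,
                  rec_velocity_ratio=0.0, kl_ratio=1e-4):
    # Weighted sum of the reconstruction terms plus a small KL penalty,
    # using the default ratios from the model block above.
    return (rec_feats_ratio * rec_feats
            + rec_joints_ratio * rec_joints
            + rec_velocity_ratio * rec_velocity
            + kl_ratio * kl)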
configs_v1/modules/denoiser.yaml
ADDED
@@ -0,0 +1,28 @@
+denoiser:
+  target: mld.models.architectures.mld_denoiser.MldDenoiser
+  params:
+    latent_dim: ${model.latent_dim}
+    hidden_dim: null
+    text_dim: 768
+    time_dim: 768
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    norm_eps: 1e-5
+    activation: 'gelu'
+    norm_post: true
+    activation_post: null
+    flip_sin_to_cos: true
+    freq_shift: 0
+    time_act_fn: 'silu'
+    time_post_act_fn: null
+    position_embedding: 'learned'
+    arch: 'trans_enc'
+    add_mem_pos: true
+    force_pre_post_proj: false
+    text_act_fn: 'relu'
+    zero_init_cond: true
+    controlnet_embed_dim: 256
+    controlnet_act_fn: null
configs_v1/modules/motion_vae.yaml
ADDED
@@ -0,0 +1,18 @@
+motion_vae:
+  target: mld.models.architectures.mld_vae.MldVae
+  params:
+    nfeats: ${DATASET.NFEATS}
+    latent_dim: ${model.latent_dim}
+    hidden_dim: null
+    force_pre_post_proj: false
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    arch: 'encoder_decoder'
+    normalize_before: false
+    norm_eps: 1e-5
+    activation: 'gelu'
+    norm_post: true
+    activation_post: null
+    position_embedding: 'learned'
configs_v1/modules/scheduler_lcm.yaml
ADDED
@@ -0,0 +1,11 @@
+scheduler:
+  target: diffusers.LCMScheduler
+  num_inference_steps: 1
+  params:
+    num_train_timesteps: 1000
+    beta_start: 0.00085
+    beta_end: 0.012
+    beta_schedule: 'scaled_linear'
+    clip_sample: false
+    set_alpha_to_one: false
+    original_inference_steps: 50
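Unlike the v2 config, this v1 scheduler points at the stock diffusers.LCMScheduler. A minimal sketch (not part of this commit) of constructing it from the params block above and selecting the single inference step implied by num_inference_steps: 1:

from diffusers import LCMScheduler

scheduler = LCMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    original_inference_steps=50,
)
scheduler.set_timesteps(num_inference_steps=1)
print(scheduler.timesteps)  # the single timestep used for one-step sampling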
configs_v1/modules/text_encoder.yaml
ADDED
@@ -0,0 +1,5 @@
+text_encoder:
+  target: mld.models.architectures.mld_clip.MldTextEncoder
+  params:
+    last_hidden_state: false
+    modelpath: ${model.t5_path}
configs_v1/modules/traj_encoder.yaml
ADDED
@@ -0,0 +1,17 @@
+traj_encoder:
+  target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
+  params:
+    nfeats: ${DATASET.NJOINTS}
+    latent_dim: ${model.latent_dim}
+    hidden_dim: null
+    force_post_proj: false
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    norm_eps: 1e-5
+    activation: 'gelu'
+    norm_post: true
+    activation_post: null
+    position_embedding: 'learned'
configs_v1/motionlcm_control_t.yaml
ADDED
@@ -0,0 +1,114 @@
+FOLDER: './experiments_control/temporal'
+TEST_FOLDER: './experiments_control_test/temporal'
+
+NAME: 'motionlcm_humanml'
+
+SEED_VALUE: 1234
+
+TRAIN:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 128
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+
+  PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 1e-4
+  learning_rate_spatial: 1e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+VAL:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+TEST:
+  DATASET: 'humanml3d'
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t_v1.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: false
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: true
+      TEMPORAL: true
+      TRAIN_JOINTS: [0, 10, 11, 15, 20, 21]
+      TEST_JOINTS: [0, 10, 11, 15, 20, 21]
+      TRAIN_DENSITY: [25, 25]
+      TEST_DENSITY: 25
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: false
+  WINDOW_SIZE: null
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics', 'ControlMetrics']
+
+model:
+  target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder']
+  latent_dim: [1, 256]
+  guidance_scale: 7.5
+
+  # ControlNet Args
+  is_controlnet: true
+  vaeloss: true
+  vaeloss_type: 'sum'
+  cond_ratio: 1.0
+  control_loss_func: 'l2'
+  use_3d: false
+  lcm_w_min_nax: null
+  lcm_num_ddim_timesteps: null
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
configs_v1/motionlcm_t2m.yaml
ADDED
@@ -0,0 +1,109 @@
+FOLDER: './experiments_t2m'
+TEST_FOLDER: './experiments_t2m_test'
+
+NAME: 'motionlcm_humanml'
+
+SEED_VALUE: 1234
+
+TRAIN:
+  BATCH_SIZE: 256
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+
+  PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml_v1.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 2e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+  # Latent Consistency Distillation Specific Arguments
+  w_min: 5.0
+  w_max: 15.0
+  num_ddim_timesteps: 50
+  loss_type: 'huber'
+  huber_c: 0.001
+  unet_time_cond_proj_dim: 256
+  ema_decay: 0.95
+
+VAL:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+TEST:
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+  PERSISTENT_WORKERS: true
+
+  CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+  DO_MM_TEST: true
+
+DATASET:
+  NAME: 'humanml3d'
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    CONTROL_ARGS:
+      CONTROL: false
+      TEMPORAL: false
+      TRAIN_JOINTS: [0]
+      TEST_JOINTS: [0]
+      TRAIN_DENSITY: 'random'
+      TEST_DENSITY: 100
+      MEAN_STD_PATH: './datasets/humanml_spatial_norm'
+  SAMPLER:
+    MAX_LEN: 200
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+  PADDING_TO_MAX: false
+  WINDOW_SIZE: null
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics']
+
+model:
+  target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm']
+  latent_dim: [1, 256]
+  guidance_scale: 7.5
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
demo.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import sys
|
4 |
+
import datetime
|
5 |
+
import logging
|
6 |
+
import os.path as osp
|
7 |
+
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from mld.config import parse_args
|
13 |
+
from mld.data.get_data import get_dataset
|
14 |
+
from mld.models.modeltype.mld import MLD
|
15 |
+
from mld.models.modeltype.vae import VAE
|
16 |
+
from mld.utils.utils import set_seed, move_batch_to_device
|
17 |
+
from mld.data.humanml.utils.plot_script import plot_3d_motion
|
18 |
+
from mld.utils.temos_utils import remove_padding
|
19 |
+
|
20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
21 |
+
|
22 |
+
|
23 |
+
def load_example_hint_input(text_path: str) -> tuple:
|
24 |
+
with open(text_path, "r") as f:
|
25 |
+
lines = f.readlines()
|
26 |
+
|
27 |
+
n_frames, control_type_ids, control_hint_ids = [], [], []
|
28 |
+
for line in lines:
|
29 |
+
s = line.strip()
|
30 |
+
n_frame, control_type_id, control_hint_id = s.split(' ')
|
31 |
+
n_frames.append(int(n_frame))
|
32 |
+
control_type_ids.append(int(control_type_id))
|
33 |
+
control_hint_ids.append(int(control_hint_id))
|
34 |
+
|
35 |
+
return n_frames, control_type_ids, control_hint_ids
|
36 |
+
|
37 |
+
|
38 |
+
def load_example_input(text_path: str) -> tuple:
|
39 |
+
with open(text_path, "r") as f:
|
40 |
+
lines = f.readlines()
|
41 |
+
|
42 |
+
texts, lens = [], []
|
43 |
+
for line in lines:
|
44 |
+
s = line.strip()
|
45 |
+
s_l = s.split(" ")[0]
|
46 |
+
s_t = s[(len(s_l) + 1):]
|
47 |
+
lens.append(int(s_l))
|
48 |
+
texts.append(s_t)
|
49 |
+
return texts, lens
|
50 |
+
|
51 |
+
|
52 |
+
def main():
|
53 |
+
cfg = parse_args()
|
54 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
55 |
+
set_seed(cfg.SEED_VALUE)
|
56 |
+
|
57 |
+
name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
|
58 |
+
cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
|
59 |
+
vis_dir = osp.join(cfg.output_dir, 'samples')
|
60 |
+
os.makedirs(cfg.output_dir, exist_ok=False)
|
61 |
+
os.makedirs(vis_dir, exist_ok=False)
|
62 |
+
|
63 |
+
steam_handler = logging.StreamHandler(sys.stdout)
|
64 |
+
file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log'))
|
65 |
+
logging.basicConfig(level=logging.INFO,
|
66 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
67 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
68 |
+
handlers=[steam_handler, file_handler])
|
69 |
+
logger = logging.getLogger(__name__)
|
70 |
+
|
71 |
+
OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml'))
|
72 |
+
|
73 |
+
state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
|
74 |
+
logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
|
75 |
+
|
76 |
+
# Step 1: Check if the checkpoint is VAE-based.
|
77 |
+
is_vae = False
|
78 |
+
vae_key = 'vae.skel_embedding.weight'
|
79 |
+
if vae_key in state_dict:
|
80 |
+
is_vae = True
|
81 |
+
logger.info(f'Is VAE: {is_vae}')
|
82 |
+
|
83 |
+
# Step 2: Check if the checkpoint is MLD-based.
|
84 |
+
is_mld = False
|
85 |
+
mld_key = 'denoiser.time_embedding.linear_1.weight'
|
86 |
+
if mld_key in state_dict:
|
87 |
+
is_mld = True
|
88 |
+
logger.info(f'Is MLD: {is_mld}')
|
89 |
+
|
90 |
+
# Step 3: Check if the checkpoint is LCM-based.
|
91 |
+
is_lcm = False
|
92 |
+
lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG
|
93 |
+
if lcm_key in state_dict:
|
94 |
+
is_lcm = True
|
95 |
+
time_cond_proj_dim = state_dict[lcm_key].shape[1]
|
96 |
+
cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
|
97 |
+
logger.info(f'Is LCM: {is_lcm}')
|
98 |
+
|
99 |
+
# Step 4: Check if the checkpoint is Controlnet-based.
|
100 |
+
cn_key = "controlnet.controlnet_cond_embedding.0.weight"
|
101 |
+
is_controlnet = True if cn_key in state_dict else False
|
102 |
+
cfg.model.is_controlnet = is_controlnet
|
103 |
+
logger.info(f'Is Controlnet: {is_controlnet}')
|
104 |
+
|
105 |
+
if is_mld or is_lcm or is_controlnet:
|
106 |
+
target_model_class = MLD
|
107 |
+
else:
|
108 |
+
target_model_class = VAE
|
109 |
+
|
110 |
+
if cfg.optimize:
|
111 |
+
assert cfg.model.get('noise_optimizer') is not None
|
112 |
+
cfg.model.noise_optimizer.params.optimize = True
|
113 |
+
logger.info('Optimization enabled. Set the batch size to 1.')
|
114 |
+
logger.info(f'Original batch size: {cfg.TEST.BATCH_SIZE}')
|
115 |
+
cfg.TEST.BATCH_SIZE = 1
|
116 |
+
|
117 |
+
dataset = get_dataset(cfg)
|
118 |
+
model = target_model_class(cfg, dataset)
|
119 |
+
model.to(device)
|
120 |
+
model.eval()
|
121 |
+
model.requires_grad_(False)
|
122 |
+
logger.info(model.load_state_dict(state_dict))
|
123 |
+
|
124 |
+
FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE")
|
125 |
+
|
126 |
+
if cfg.example is not None and not is_controlnet:
|
127 |
+
text, length = load_example_input(cfg.example)
|
128 |
+
for t, l in zip(text, length):
|
129 |
+
logger.info(f"{l}: {t}")
|
130 |
+
|
131 |
+
batch = {"length": length, "text": text}
|
132 |
+
|
133 |
+
for rep_i in range(cfg.replication):
|
134 |
+
with torch.no_grad():
|
135 |
+
joints = model(batch)[0]
|
136 |
+
|
137 |
+
num_samples = len(joints)
|
138 |
+
for i in range(num_samples):
|
139 |
+
res = dict()
|
140 |
+
pkl_path = osp.join(vis_dir, f"sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
|
141 |
+
res['joints'] = joints[i].detach().cpu().numpy()
|
142 |
+
res['text'] = text[i]
|
143 |
+
res['length'] = length[i]
|
144 |
+
res['hint'] = None
|
145 |
+
with open(pkl_path, 'wb') as f:
|
146 |
+
pickle.dump(res, f)
|
147 |
+
logger.info(f"Motions are generated here:\n{pkl_path}")
|
148 |
+
|
149 |
+
if not cfg.no_plot:
|
150 |
+
plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), text[i], fps=FPS)
|
151 |
+
|
152 |
+
else:
|
153 |
+
test_dataloader = dataset.test_dataloader()
|
154 |
+
for rep_i in range(cfg.replication):
|
155 |
+
for batch_id, batch in enumerate(test_dataloader):
|
156 |
+
batch = move_batch_to_device(batch, device)
|
157 |
+
with torch.no_grad():
|
158 |
+
joints, joints_ref = model(batch)
|
159 |
+
|
160 |
+
num_samples = len(joints)
|
161 |
+
text = batch['text']
|
162 |
+
length = batch['length']
|
163 |
+
if 'hint' in batch:
|
164 |
+
hint, hint_mask = batch['hint'], batch['hint_mask']
|
165 |
+
hint = dataset.denorm_spatial(hint) * hint_mask
|
166 |
+
hint = remove_padding(hint, lengths=length)
|
167 |
+
else:
|
168 |
+
hint = None
|
169 |
+
|
170 |
+
for i in range(num_samples):
|
171 |
+
res = dict()
|
172 |
+
pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
|
173 |
+
res['joints'] = joints[i].detach().cpu().numpy()
|
174 |
+
res['text'] = text[i]
|
175 |
+
res['length'] = length[i]
|
176 |
+
res['hint'] = hint[i].detach().cpu().numpy() if hint is not None else None
|
177 |
+
with open(pkl_path, 'wb') as f:
|
178 |
+
pickle.dump(res, f)
|
179 |
+
logger.info(f"Motions are generated here:\n{pkl_path}")
|
180 |
+
|
181 |
+
if not cfg.no_plot:
|
182 |
+
plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(),
|
183 |
+
text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
|
184 |
+
|
185 |
+
if rep_i == 0:
|
186 |
+
res['joints'] = joints_ref[i].detach().cpu().numpy()
|
187 |
+
with open(pkl_path.replace('.pkl', '_ref.pkl'), 'wb') as f:
|
188 |
+
pickle.dump(res, f)
|
189 |
+
logger.info(f"Motions are generated here:\n{pkl_path.replace('.pkl', '_ref.pkl')}")
|
190 |
+
if not cfg.no_plot:
|
191 |
+
plot_3d_motion(pkl_path.replace('.pkl', '_ref.mp4'), joints_ref[i].detach().cpu().numpy(),
|
192 |
+
text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
|
193 |
+
|
194 |
+
|
195 |
+
if __name__ == "__main__":
|
196 |
+
main()
|
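Each generated sample above is saved as a pickle with the keys joints, text, length and hint. A minimal inspection sketch, assuming a generated file following the naming pattern used above (the file name is a placeholder):

import pickle

with open("sample_id_0_length_196_rep_0.pkl", "rb") as f:
    res = pickle.load(f)

print(res["text"], res["length"])   # caption and motion length in frames
print(res["joints"].shape)          # (num_frames, num_joints, 3) joint positions
print(res["hint"])                  # None unless a spatial control signal was used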
fit.py
ADDED
@@ -0,0 +1,136 @@
|
1 |
+
# borrowed from the joint-to-SMPL optimization in https://github.com/wangsen1312/joints2smpl
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
import natsort
|
8 |
+
import smplx
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from mld.transforms.joints2rots import config
|
13 |
+
from mld.transforms.joints2rots.smplify import SMPLify3D
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument("--pkl", type=str, default=None, help="pkl motion file")
|
17 |
+
parser.add_argument("--dir", type=str, default=None, help="pkl motion folder")
|
18 |
+
parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters")
|
19 |
+
parser.add_argument("--cuda", type=bool, default=True, help="enables cuda")
|
20 |
+
parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids")
|
21 |
+
parser.add_argument("--num_joints", type=int, default=22, help="joint number")
|
22 |
+
parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence")
|
23 |
+
parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not")
|
24 |
+
opt = parser.parse_args()
|
25 |
+
print(opt)
|
26 |
+
|
27 |
+
if opt.pkl:
|
28 |
+
paths = [opt.pkl]
|
29 |
+
elif opt.dir:
|
30 |
+
paths = []
|
31 |
+
file_list = natsort.natsorted(os.listdir(opt.dir))
|
32 |
+
for item in file_list:
|
33 |
+
if item.endswith('.pkl') and not item.endswith("_mesh.pkl"):
|
34 |
+
paths.append(os.path.join(opt.dir, item))
|
35 |
+
else:
|
36 |
+
raise ValueError(f'{opt.pkl} and {opt.dir} are both None!')
|
37 |
+
|
38 |
+
for path in paths:
|
39 |
+
# load joints
|
40 |
+
if os.path.exists(path.replace('.pkl', '_mesh.pkl')):
|
41 |
+
print(f"{path} is rendered! skip!")
|
42 |
+
continue
|
43 |
+
|
44 |
+
with open(path, 'rb') as f:
|
45 |
+
data = pickle.load(f)
|
46 |
+
|
47 |
+
joints = data['joints']
|
48 |
+
# load the predefined SMPL model and mean parameters
|
49 |
+
device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu")
|
50 |
+
print(config.SMPL_MODEL_DIR)
|
51 |
+
smplxmodel = smplx.create(
|
52 |
+
config.SMPL_MODEL_DIR,
|
53 |
+
model_type="smpl",
|
54 |
+
gender="neutral",
|
55 |
+
ext="pkl",
|
56 |
+
batch_size=joints.shape[0],
|
57 |
+
).to(device)
|
58 |
+
|
59 |
+
# load the mean pose as original
|
60 |
+
smpl_mean_file = config.SMPL_MEAN_FILE
|
61 |
+
|
62 |
+
file = h5py.File(smpl_mean_file, "r")
|
63 |
+
init_mean_pose = (
|
64 |
+
torch.from_numpy(file["pose"][:])
|
65 |
+
.unsqueeze(0).repeat(joints.shape[0], 1)
|
66 |
+
.float()
|
67 |
+
.to(device)
|
68 |
+
)
|
69 |
+
init_mean_shape = (
|
70 |
+
torch.from_numpy(file["shape"][:])
|
71 |
+
.unsqueeze(0).repeat(joints.shape[0], 1)
|
72 |
+
.float()
|
73 |
+
.to(device)
|
74 |
+
)
|
75 |
+
cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)
|
76 |
+
|
77 |
+
# initialize SMPLify
|
78 |
+
smplify = SMPLify3D(
|
79 |
+
smplxmodel=smplxmodel,
|
80 |
+
batch_size=joints.shape[0],
|
81 |
+
joints_category=opt.joint_category,
|
82 |
+
num_iters=opt.num_smplify_iters,
|
83 |
+
device=device,
|
84 |
+
)
|
85 |
+
print("initialize SMPLify3D done!")
|
86 |
+
|
87 |
+
print("Start SMPLify!")
|
88 |
+
keypoints_3d = torch.Tensor(joints).to(device).float()
|
89 |
+
|
90 |
+
if opt.joint_category == "AMASS":
|
91 |
+
confidence_input = torch.ones(opt.num_joints)
|
92 |
+
# up-weight the feet and ankles so they are fitted more tightly
|
93 |
+
if opt.fix_foot == "True":  # --fix_foot is parsed as a string, so compare explicitly
|
94 |
+
confidence_input[7] = 1.5
|
95 |
+
confidence_input[8] = 1.5
|
96 |
+
confidence_input[10] = 1.5
|
97 |
+
confidence_input[11] = 1.5
|
98 |
+
else:
|
99 |
+
print("Such category not settle down!")
|
100 |
+
|
101 |
+
# ----- from initial to fitting -------
|
102 |
+
(
|
103 |
+
new_opt_vertices,
|
104 |
+
new_opt_joints,
|
105 |
+
new_opt_pose,
|
106 |
+
new_opt_betas,
|
107 |
+
new_opt_cam_t,
|
108 |
+
new_opt_joint_loss,
|
109 |
+
) = smplify(
|
110 |
+
init_mean_pose.detach(),
|
111 |
+
init_mean_shape.detach(),
|
112 |
+
cam_trans_zero.detach(),
|
113 |
+
keypoints_3d,
|
114 |
+
conf_3d=confidence_input.to(device)
|
115 |
+
)
|
116 |
+
|
117 |
+
# fix shape
|
118 |
+
betas = torch.zeros_like(new_opt_betas)
|
119 |
+
root = keypoints_3d[:, 0, :]
|
120 |
+
|
121 |
+
output = smplxmodel(
|
122 |
+
betas=betas,
|
123 |
+
global_orient=new_opt_pose[:, :3],
|
124 |
+
body_pose=new_opt_pose[:, 3:],
|
125 |
+
transl=root,
|
126 |
+
return_verts=True
|
127 |
+
)
|
128 |
+
vertices = output.vertices.detach().cpu().numpy()
|
129 |
+
floor_height = vertices[..., 1].min()
|
130 |
+
vertices[..., 1] -= floor_height
|
131 |
+
data['vertices'] = vertices
|
132 |
+
|
133 |
+
save_file = path.replace('.pkl', '_mesh.pkl')
|
134 |
+
with open(save_file, 'wb') as f:
|
135 |
+
pickle.dump(data, f)
|
136 |
+
print(f'vertices saved in {save_file}')
|
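fit.py consumes the .pkl files written by demo.py and writes a companion *_mesh.pkl that keeps the original keys and adds a 'vertices' array (SMPL meshes, floor-aligned on the y axis). A minimal sketch, assuming a fitted file exists (the path below is a placeholder):

import pickle

# e.g. produced by: python fit.py --pkl sample_id_0_length_196_rep_0.pkl
with open("sample_id_0_length_196_rep_0_mesh.pkl", "rb") as f:
    data = pickle.load(f)

print(data["vertices"].shape)   # (num_frames, 6890, 3) for the SMPL body model
print(data["joints"].shape)     # the original joints are kept alongside the mesh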
mld/__init__.py
ADDED
File without changes
|
mld/config.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
import os
|
2 |
+
import importlib
|
3 |
+
from typing import Type, TypeVar
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
|
6 |
+
from omegaconf import OmegaConf, DictConfig
|
7 |
+
|
8 |
+
|
9 |
+
def get_module_config(cfg_model: DictConfig, paths: list[str], cfg_root: str) -> DictConfig:
|
10 |
+
files = [os.path.join(cfg_root, 'modules', p+'.yaml') for p in paths]
|
11 |
+
for file in files:
|
12 |
+
assert os.path.exists(file), f'{file} does not exist.'
|
13 |
+
with open(file, 'r') as f:
|
14 |
+
cfg_model.merge_with(OmegaConf.load(f))
|
15 |
+
return cfg_model
|
16 |
+
|
17 |
+
|
18 |
+
def get_obj_from_str(string: str, reload: bool = False) -> Type:
|
19 |
+
module, cls = string.rsplit(".", 1)
|
20 |
+
if reload:
|
21 |
+
module_imp = importlib.import_module(module)
|
22 |
+
importlib.reload(module_imp)
|
23 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
24 |
+
|
25 |
+
|
26 |
+
def instantiate_from_config(config: DictConfig) -> TypeVar:
|
27 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
28 |
+
|
29 |
+
|
30 |
+
def parse_args() -> DictConfig:
|
31 |
+
parser = ArgumentParser()
|
32 |
+
parser.add_argument("--cfg", type=str, required=True, help="The main config file")
|
33 |
+
parser.add_argument('--example', type=str, required=False, help="The input texts and lengths with txt format")
|
34 |
+
parser.add_argument('--example_hint', type=str, required=False, help="The input hint ids and lengths with txt format")
|
35 |
+
parser.add_argument('--no-plot', action="store_true", required=False, help="Whether to plot the skeleton-based motion")
|
36 |
+
parser.add_argument('--replication', type=int, default=1, help="The number of replications of sampling")
|
37 |
+
parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="The visualization backends: tensorboard or swanlab")
|
38 |
+
parser.add_argument('--optimize', action='store_true', help="Enable optimization for motion control")
|
39 |
+
args = parser.parse_args()
|
40 |
+
|
41 |
+
cfg = OmegaConf.load(args.cfg)
|
42 |
+
cfg_root = os.path.dirname(args.cfg)
|
43 |
+
cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
|
44 |
+
cfg = OmegaConf.merge(cfg, cfg_model)
|
45 |
+
|
46 |
+
cfg.example = args.example
|
47 |
+
cfg.example_hint = args.example_hint
|
48 |
+
cfg.no_plot = args.no_plot
|
49 |
+
cfg.replication = args.replication
|
50 |
+
cfg.vis = args.vis
|
51 |
+
cfg.optimize = args.optimize
|
52 |
+
return cfg
|
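get_module_config merges one YAML per entry of cfg.model.target into the model config. A minimal sketch of that merge with OmegaConf; the keys here are illustrative placeholders, not the repository's actual configs:

from omegaconf import OmegaConf

base = OmegaConf.create({"model": {"target": ["motion_vae"], "latent_dim": [1, 256]}})
module = OmegaConf.create({"motion_vae": {"target": "mld.models.architectures.mld_vae.MldVae"}})
base.model.merge_with(module)          # same call as in get_module_config
print(OmegaConf.to_yaml(base.model))   # latent_dim plus the merged motion_vae block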
mld/data/__init__.py
ADDED
File without changes
|
mld/data/base.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
import copy
|
2 |
+
from os.path import join as pjoin
|
3 |
+
from typing import Any, Callable
|
4 |
+
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
|
8 |
+
class BaseDataModule:
|
9 |
+
def __init__(self, collate_fn: Callable) -> None:
|
10 |
+
super(BaseDataModule, self).__init__()
|
11 |
+
self.collate_fn = collate_fn
|
12 |
+
self.is_mm = False
|
13 |
+
|
14 |
+
def get_sample_set(self, overrides: dict) -> Any:
|
15 |
+
sample_params = copy.deepcopy(self.hparams)
|
16 |
+
sample_params.update(overrides)
|
17 |
+
split_file = pjoin(
|
18 |
+
eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"),
|
19 |
+
self.cfg.TEST.SPLIT + ".txt"
|
20 |
+
)
|
21 |
+
return self.Dataset(split_file=split_file, **sample_params)
|
22 |
+
|
23 |
+
def __getattr__(self, item: str) -> Any:
|
24 |
+
if item.endswith("_dataset") and not item.startswith("_"):
|
25 |
+
subset = item[:-len("_dataset")].upper()
|
26 |
+
item_c = "_" + item
|
27 |
+
if item_c not in self.__dict__:
|
28 |
+
split_file = pjoin(
|
29 |
+
eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"),
|
30 |
+
eval(f"self.cfg.{subset}.SPLIT") + ".txt"
|
31 |
+
)
|
32 |
+
self.__dict__[item_c] = self.Dataset(split_file=split_file, **self.hparams)
|
33 |
+
return getattr(self, item_c)
|
34 |
+
classname = self.__class__.__name__
|
35 |
+
raise AttributeError(f"'{classname}' object has no attribute '{item}'")
|
36 |
+
|
37 |
+
def get_dataloader_options(self, stage: str) -> dict:
|
38 |
+
stage_args = eval(f"self.cfg.{stage.upper()}")
|
39 |
+
dataloader_options = {
|
40 |
+
"batch_size": stage_args.BATCH_SIZE,
|
41 |
+
"num_workers": stage_args.NUM_WORKERS,
|
42 |
+
"collate_fn": self.collate_fn,
|
43 |
+
"persistent_workers": stage_args.PERSISTENT_WORKERS,
|
44 |
+
}
|
45 |
+
return dataloader_options
|
46 |
+
|
47 |
+
def train_dataloader(self) -> DataLoader:
|
48 |
+
dataloader_options = self.get_dataloader_options('TRAIN')
|
49 |
+
return DataLoader(self.train_dataset, shuffle=True, **dataloader_options)
|
50 |
+
|
51 |
+
def val_dataloader(self) -> DataLoader:
|
52 |
+
dataloader_options = self.get_dataloader_options('VAL')
|
53 |
+
return DataLoader(self.val_dataset, shuffle=False, **dataloader_options)
|
54 |
+
|
55 |
+
def test_dataloader(self) -> DataLoader:
|
56 |
+
dataloader_options = self.get_dataloader_options('TEST')
|
57 |
+
dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
|
58 |
+
return DataLoader(self.test_dataset, shuffle=False, **dataloader_options)
|
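The __getattr__ above builds each split lazily on first access and caches it under a leading-underscore name. A self-contained toy version of the same pattern:

class Lazy:
    def __getattr__(self, item):
        # mirror BaseDataModule: build "<split>_dataset" on first access, cache as "_<split>_dataset"
        if item.endswith("_dataset") and not item.startswith("_"):
            cached = "_" + item
            if cached not in self.__dict__:
                self.__dict__[cached] = f"built {item}"   # stands in for self.Dataset(...)
            return getattr(self, cached)
        raise AttributeError(item)

d = Lazy()
print(d.test_dataset)                      # built on first access
print(d.test_dataset is d._test_dataset)   # True: the cached instance is reused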
mld/data/data.py
ADDED
@@ -0,0 +1,73 @@
1 |
+
import copy
|
2 |
+
from typing import Callable, Optional
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from .base import BaseDataModule
|
10 |
+
from .humanml.dataset import Text2MotionDataset, MotionDataset
|
11 |
+
from .humanml.scripts.motion_process import recover_from_ric
|
12 |
+
|
13 |
+
|
14 |
+
# (nfeats, njoints)
|
15 |
+
dataset_map = {'humanml3d': (263, 22), 'kit': (251, 21)}
|
16 |
+
|
17 |
+
|
18 |
+
class DataModule(BaseDataModule):
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
name: str,
|
22 |
+
cfg: DictConfig,
|
23 |
+
motion_only: bool,
|
24 |
+
collate_fn: Optional[Callable] = None,
|
25 |
+
**kwargs) -> None:
|
26 |
+
super().__init__(collate_fn=collate_fn)
|
27 |
+
self.cfg = cfg
|
28 |
+
self.name = name
|
29 |
+
self.nfeats, self.njoints = dataset_map[name]
|
30 |
+
self.hparams = copy.deepcopy({**kwargs, 'njoints': self.njoints})
|
31 |
+
self.Dataset = MotionDataset if motion_only else Text2MotionDataset
|
32 |
+
sample_overrides = {"tiny": True, "progress_bar": False}
|
33 |
+
self._sample_set = self.get_sample_set(overrides=sample_overrides)
|
34 |
+
|
35 |
+
def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
|
36 |
+
raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
|
37 |
+
raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
|
38 |
+
hint = hint * raw_std + raw_mean
|
39 |
+
return hint
|
40 |
+
|
41 |
+
def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
|
42 |
+
raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
|
43 |
+
raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
|
44 |
+
hint = (hint - raw_mean) / raw_std
|
45 |
+
return hint
|
46 |
+
|
47 |
+
def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
|
48 |
+
mean = torch.tensor(self.hparams['mean']).to(features)
|
49 |
+
std = torch.tensor(self.hparams['std']).to(features)
|
50 |
+
features = features * std + mean
|
51 |
+
return recover_from_ric(features, self.njoints)
|
52 |
+
|
53 |
+
def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
|
54 |
+
# re-normalize to the t2m statistics so the t2m evaluators can be used
|
55 |
+
ori_mean = torch.tensor(self.hparams['mean']).to(features)
|
56 |
+
ori_std = torch.tensor(self.hparams['std']).to(features)
|
57 |
+
eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
|
58 |
+
eval_std = torch.tensor(self.hparams['std_eval']).to(features)
|
59 |
+
features = features * ori_std + ori_mean
|
60 |
+
features = (features - eval_mean) / eval_std
|
61 |
+
return features
|
62 |
+
|
63 |
+
def mm_mode(self, mm_on: bool = True) -> None:
|
64 |
+
if mm_on:
|
65 |
+
self.is_mm = True
|
66 |
+
self.name_list = self.test_dataset.name_list
|
67 |
+
self.mm_list = np.random.choice(self.name_list,
|
68 |
+
self.cfg.TEST.MM_NUM_SAMPLES,
|
69 |
+
replace=False)
|
70 |
+
self.test_dataset.name_list = self.mm_list
|
71 |
+
else:
|
72 |
+
self.is_mm = False
|
73 |
+
self.test_dataset.name_list = self.name_list
|
mld/data/get_data.py
ADDED
@@ -0,0 +1,79 @@
1 |
+
from typing import Optional
|
2 |
+
from os.path import join as pjoin
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from omegaconf import DictConfig
|
7 |
+
|
8 |
+
from .data import DataModule
|
9 |
+
from .base import BaseDataModule
|
10 |
+
from .utils import mld_collate, mld_collate_motion_only
|
11 |
+
from .humanml.utils.word_vectorizer import WordVectorizer
|
12 |
+
|
13 |
+
|
14 |
+
def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
|
15 |
+
name = "t2m" if dataset_name == "humanml3d" else dataset_name
|
16 |
+
assert name in ["t2m", "kit"]
|
17 |
+
if phase in ["val"]:
|
18 |
+
if name == 't2m':
|
19 |
+
data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta")
|
20 |
+
elif name == 'kit':
|
21 |
+
data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta")
|
22 |
+
else:
|
23 |
+
raise ValueError("Only support t2m and kit")
|
24 |
+
mean = np.load(pjoin(data_root, "mean.npy"))
|
25 |
+
std = np.load(pjoin(data_root, "std.npy"))
|
26 |
+
else:
|
27 |
+
data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
|
28 |
+
mean = np.load(pjoin(data_root, "Mean.npy"))
|
29 |
+
std = np.load(pjoin(data_root, "Std.npy"))
|
30 |
+
|
31 |
+
return mean, std
|
32 |
+
|
33 |
+
|
34 |
+
def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]:
|
35 |
+
if dataset_name.lower() in ["humanml3d", "kit"]:
|
36 |
+
return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")
|
37 |
+
else:
|
38 |
+
raise ValueError("Only support WordVectorizer for HumanML3D and KIT")
|
39 |
+
|
40 |
+
|
41 |
+
dataset_module_map = {"humanml3d": DataModule, "kit": DataModule}
|
42 |
+
motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}
|
43 |
+
|
44 |
+
|
45 |
+
def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule:
|
46 |
+
dataset_name = cfg.DATASET.NAME
|
47 |
+
if dataset_name.lower() in ["humanml3d", "kit"]:
|
48 |
+
data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
|
49 |
+
mean, std = get_mean_std('train', cfg, dataset_name)
|
50 |
+
mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)
|
51 |
+
wordVectorizer = None if motion_only else get_WordVectorizer(cfg, dataset_name)
|
52 |
+
collate_fn = mld_collate_motion_only if motion_only else mld_collate
|
53 |
+
dataset = dataset_module_map[dataset_name.lower()](
|
54 |
+
name=dataset_name.lower(),
|
55 |
+
cfg=cfg,
|
56 |
+
motion_only=motion_only,
|
57 |
+
collate_fn=collate_fn,
|
58 |
+
mean=mean,
|
59 |
+
std=std,
|
60 |
+
mean_eval=mean_eval,
|
61 |
+
std_eval=std_eval,
|
62 |
+
w_vectorizer=wordVectorizer,
|
63 |
+
text_dir=pjoin(data_root, "texts"),
|
64 |
+
motion_dir=pjoin(data_root, motion_subdir[dataset_name]),
|
65 |
+
max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
|
66 |
+
min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
|
67 |
+
max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
|
68 |
+
unit_length=eval(f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"),
|
69 |
+
fps=eval(f"cfg.DATASET.{dataset_name.upper()}.FRAME_RATE"),
|
70 |
+
padding_to_max=cfg.DATASET.PADDING_TO_MAX,
|
71 |
+
window_size=cfg.DATASET.WINDOW_SIZE,
|
72 |
+
control_args=eval(f"cfg.DATASET.{dataset_name.upper()}.CONTROL_ARGS"))
|
73 |
+
|
74 |
+
cfg.DATASET.NFEATS = dataset.nfeats
|
75 |
+
cfg.DATASET.NJOINTS = dataset.njoints
|
76 |
+
return dataset
|
77 |
+
|
78 |
+
elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]:
|
79 |
+
raise NotImplementedError
|
mld/data/humanml/__init__.py
ADDED
File without changes
|
mld/data/humanml/common/quaternion.py
ADDED
@@ -0,0 +1,29 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def qinv(q: torch.Tensor) -> torch.Tensor:
|
5 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
6 |
+
mask = torch.ones_like(q)
|
7 |
+
mask[..., 1:] = -mask[..., 1:]
|
8 |
+
return q * mask
|
9 |
+
|
10 |
+
|
11 |
+
def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
12 |
+
"""
|
13 |
+
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
14 |
+
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
15 |
+
where * denotes any number of dimensions.
|
16 |
+
Returns a tensor of shape (*, 3).
|
17 |
+
"""
|
18 |
+
assert q.shape[-1] == 4
|
19 |
+
assert v.shape[-1] == 3
|
20 |
+
assert q.shape[:-1] == v.shape[:-1]
|
21 |
+
|
22 |
+
original_shape = list(v.shape)
|
23 |
+
q = q.contiguous().view(-1, 4)
|
24 |
+
v = v.contiguous().view(-1, 3)
|
25 |
+
|
26 |
+
qvec = q[:, 1:]
|
27 |
+
uv = torch.cross(qvec, v, dim=1)
|
28 |
+
uuv = torch.cross(qvec, uv, dim=1)
|
29 |
+
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
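A self-contained sanity check for qrot/qinv (quaternions in (w, x, y, z) order, as used above): rotating by q and then by its inverse recovers the original vector.

import torch

from mld.data.humanml.common.quaternion import qinv, qrot

q = torch.tensor([[0.7071, 0.0, 0.7071, 0.0]])   # ~90 degrees about the Y axis
v = torch.tensor([[1.0, 0.0, 0.0]])
rotated = qrot(q, v)                             # approximately (0, 0, -1)
restored = qrot(qinv(q), rotated)
print(torch.allclose(restored, v, atol=1e-3))    # True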
mld/data/humanml/dataset.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import logging
|
4 |
+
import codecs as cs
|
5 |
+
from os.path import join as pjoin
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from rich.progress import track
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
|
13 |
+
from .scripts.motion_process import recover_from_ric
|
14 |
+
from .utils.word_vectorizer import WordVectorizer
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class MotionDataset(Dataset):
|
20 |
+
def __init__(self, mean: np.ndarray, std: np.ndarray,
|
21 |
+
split_file: str, motion_dir: str, window_size: int,
|
22 |
+
tiny: bool = False, progress_bar: bool = True, **kwargs) -> None:
|
23 |
+
self.data = []
|
24 |
+
self.lengths = []
|
25 |
+
id_list = []
|
26 |
+
with cs.open(split_file, "r") as f:
|
27 |
+
for line in f.readlines():
|
28 |
+
id_list.append(line.strip())
|
29 |
+
|
30 |
+
maxdata = 10 if tiny else 1e10
|
31 |
+
if progress_bar:
|
32 |
+
enumerator = enumerate(
|
33 |
+
track(
|
34 |
+
id_list,
|
35 |
+
f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
|
36 |
+
))
|
37 |
+
else:
|
38 |
+
enumerator = enumerate(id_list)
|
39 |
+
|
40 |
+
count = 0
|
41 |
+
for i, name in enumerator:
|
42 |
+
if count > maxdata:
|
43 |
+
break
|
44 |
+
try:
|
45 |
+
motion = np.load(pjoin(motion_dir, name + '.npy'))
|
46 |
+
if motion.shape[0] < window_size:
|
47 |
+
continue
|
48 |
+
self.lengths.append(motion.shape[0] - window_size)
|
49 |
+
self.data.append(motion)
|
50 |
+
except Exception as e:
|
51 |
+
print(e)
|
52 |
+
pass
|
53 |
+
|
54 |
+
self.cumsum = np.cumsum([0] + self.lengths)
|
55 |
+
if not tiny:
|
56 |
+
logger.info("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
|
57 |
+
|
58 |
+
self.mean = mean
|
59 |
+
self.std = std
|
60 |
+
self.window_size = window_size
|
61 |
+
|
62 |
+
def __len__(self) -> int:
|
63 |
+
return self.cumsum[-1]
|
64 |
+
|
65 |
+
def __getitem__(self, item: int) -> tuple:
|
66 |
+
if item != 0:
|
67 |
+
motion_id = np.searchsorted(self.cumsum, item) - 1
|
68 |
+
idx = item - self.cumsum[motion_id] - 1
|
69 |
+
else:
|
70 |
+
motion_id = 0
|
71 |
+
idx = 0
|
72 |
+
motion = self.data[motion_id][idx:idx + self.window_size]
|
73 |
+
"Z Normalization"
|
74 |
+
motion = (motion - self.mean) / self.std
|
75 |
+
return motion, self.window_size
|
76 |
+
|
77 |
+
|
78 |
+
class Text2MotionDataset(Dataset):
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
mean: np.ndarray,
|
83 |
+
std: np.ndarray,
|
84 |
+
split_file: str,
|
85 |
+
w_vectorizer: WordVectorizer,
|
86 |
+
max_motion_length: int,
|
87 |
+
min_motion_length: int,
|
88 |
+
max_text_len: int,
|
89 |
+
unit_length: int,
|
90 |
+
motion_dir: str,
|
91 |
+
text_dir: str,
|
92 |
+
fps: int,
|
93 |
+
padding_to_max: bool,
|
94 |
+
njoints: int,
|
95 |
+
tiny: bool = False,
|
96 |
+
progress_bar: bool = True,
|
97 |
+
**kwargs,
|
98 |
+
) -> None:
|
99 |
+
self.w_vectorizer = w_vectorizer
|
100 |
+
self.max_motion_length = max_motion_length
|
101 |
+
self.min_motion_length = min_motion_length
|
102 |
+
self.max_text_len = max_text_len
|
103 |
+
self.unit_length = unit_length
|
104 |
+
self.padding_to_max = padding_to_max
|
105 |
+
self.njoints = njoints
|
106 |
+
|
107 |
+
data_dict = {}
|
108 |
+
id_list = []
|
109 |
+
with cs.open(split_file, "r") as f:
|
110 |
+
for line in f.readlines():
|
111 |
+
id_list.append(line.strip())
|
112 |
+
self.id_list = id_list
|
113 |
+
|
114 |
+
maxdata = 10 if tiny else 1e10
|
115 |
+
if progress_bar:
|
116 |
+
enumerator = enumerate(
|
117 |
+
track(
|
118 |
+
id_list,
|
119 |
+
f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
|
120 |
+
))
|
121 |
+
else:
|
122 |
+
enumerator = enumerate(id_list)
|
123 |
+
count = 0
|
124 |
+
bad_count = 0
|
125 |
+
new_name_list = []
|
126 |
+
length_list = []
|
127 |
+
for i, name in enumerator:
|
128 |
+
if count > maxdata:
|
129 |
+
break
|
130 |
+
try:
|
131 |
+
motion = np.load(pjoin(motion_dir, name + ".npy"))
|
132 |
+
if len(motion) < self.min_motion_length or len(motion) >= self.max_motion_length:
|
133 |
+
bad_count += 1
|
134 |
+
continue
|
135 |
+
text_data = []
|
136 |
+
flag = False
|
137 |
+
with cs.open(pjoin(text_dir, name + ".txt")) as f:
|
138 |
+
for line in f.readlines():
|
139 |
+
text_dict = {}
|
140 |
+
line_split = line.strip().split("#")
|
141 |
+
caption = line_split[0]
|
142 |
+
tokens = line_split[1].split(" ")
|
143 |
+
f_tag = float(line_split[2])
|
144 |
+
to_tag = float(line_split[3])
|
145 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
146 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
147 |
+
|
148 |
+
text_dict["caption"] = caption
|
149 |
+
text_dict["tokens"] = tokens
|
150 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
151 |
+
flag = True
|
152 |
+
text_data.append(text_dict)
|
153 |
+
else:
|
154 |
+
try:
|
155 |
+
n_motion = motion[int(f_tag * fps): int(to_tag * fps)]
|
156 |
+
if (len(n_motion)) < self.min_motion_length or \
|
157 |
+
len(n_motion) >= self.max_motion_length:
|
158 |
+
continue
|
159 |
+
new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
|
160 |
+
while new_name in data_dict:
|
161 |
+
new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
|
162 |
+
data_dict[new_name] = {
|
163 |
+
"motion": n_motion,
|
164 |
+
"length": len(n_motion),
|
165 |
+
"text": [text_dict],
|
166 |
+
}
|
167 |
+
new_name_list.append(new_name)
|
168 |
+
length_list.append(len(n_motion))
|
169 |
+
except ValueError:
|
170 |
+
print(line_split)
|
171 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
172 |
+
|
173 |
+
if flag:
|
174 |
+
data_dict[name] = {
|
175 |
+
"motion": motion,
|
176 |
+
"length": len(motion),
|
177 |
+
"text": text_data,
|
178 |
+
}
|
179 |
+
new_name_list.append(name)
|
180 |
+
length_list.append(len(motion))
|
181 |
+
count += 1
|
182 |
+
except Exception as e:
|
183 |
+
print(e)
|
184 |
+
pass
|
185 |
+
|
186 |
+
name_list, length_list = zip(
|
187 |
+
*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
188 |
+
|
189 |
+
if not tiny:
|
190 |
+
logger.info(f"Reading {len(self.id_list)} motions from {split_file}.")
|
191 |
+
logger.info(f"Total {len(name_list)} motions are used.")
|
192 |
+
logger.info(f"{bad_count} motion sequences not within the length range of "
|
193 |
+
f"[{self.min_motion_length}, {self.max_motion_length}) are filtered out.")
|
194 |
+
|
195 |
+
self.mean = mean
|
196 |
+
self.std = std
|
197 |
+
|
198 |
+
control_args = kwargs['control_args']
|
199 |
+
self.control_mode = None
|
200 |
+
if os.path.exists(control_args.MEAN_STD_PATH):
|
201 |
+
self.raw_mean = np.load(pjoin(control_args.MEAN_STD_PATH, 'Mean_raw.npy'))
|
202 |
+
self.raw_std = np.load(pjoin(control_args.MEAN_STD_PATH, 'Std_raw.npy'))
|
203 |
+
else:
|
204 |
+
self.raw_mean = self.raw_std = None
|
205 |
+
if not tiny and control_args.CONTROL:
|
206 |
+
self.t_ctrl = control_args.TEMPORAL
|
207 |
+
self.training_control_joints = np.array(control_args.TRAIN_JOINTS)
|
208 |
+
self.testing_control_joints = np.array(control_args.TEST_JOINTS)
|
209 |
+
self.training_density = control_args.TRAIN_DENSITY
|
210 |
+
self.testing_density = control_args.TEST_DENSITY
|
211 |
+
|
212 |
+
self.control_mode = 'val' if ('test' in split_file or 'val' in split_file) else 'train'
|
213 |
+
if self.control_mode == 'train':
|
214 |
+
logger.info(f'Training Control Joints: {self.training_control_joints}')
|
215 |
+
logger.info(f'Training Control Density: {self.training_density}')
|
216 |
+
else:
|
217 |
+
logger.info(f'Testing Control Joints: {self.testing_control_joints}')
|
218 |
+
logger.info(f'Testing Control Density: {self.testing_density}')
|
219 |
+
logger.info(f"Temporal Control: {self.t_ctrl}")
|
220 |
+
|
221 |
+
self.data_dict = data_dict
|
222 |
+
self.name_list = name_list
|
223 |
+
|
224 |
+
def __len__(self) -> int:
|
225 |
+
return len(self.name_list)
|
226 |
+
|
227 |
+
def random_mask(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
228 |
+
choose_joint = self.testing_control_joints
|
229 |
+
|
230 |
+
length = joints.shape[0]
|
231 |
+
density = self.testing_density
|
232 |
+
if density in [1, 2, 5]:
|
233 |
+
choose_seq_num = density
|
234 |
+
else:
|
235 |
+
choose_seq_num = int(length * density / 100)
|
236 |
+
|
237 |
+
if self.t_ctrl:
|
238 |
+
choose_seq = np.arange(0, choose_seq_num)
|
239 |
+
else:
|
240 |
+
choose_seq = np.random.choice(length, choose_seq_num, replace=False)
|
241 |
+
choose_seq.sort()
|
242 |
+
|
243 |
+
mask_seq = np.zeros((length, self.njoints, 3))
|
244 |
+
for cj in choose_joint:
|
245 |
+
mask_seq[choose_seq, cj] = 1.0
|
246 |
+
|
247 |
+
joints = (joints - self.raw_mean) / self.raw_std
|
248 |
+
joints = joints * mask_seq
|
249 |
+
return joints, mask_seq
|
250 |
+
|
251 |
+
def random_mask_train(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
252 |
+
if self.t_ctrl:
|
253 |
+
choose_joint = self.training_control_joints
|
254 |
+
else:
|
255 |
+
num_joints = len(self.training_control_joints)
|
256 |
+
num_joints_control = 1
|
257 |
+
choose_joint = np.random.choice(num_joints, num_joints_control, replace=False)
|
258 |
+
choose_joint = self.training_control_joints[choose_joint]
|
259 |
+
|
260 |
+
length = joints.shape[0]
|
261 |
+
|
262 |
+
if self.training_density == 'random':
|
263 |
+
choose_seq_num = np.random.choice(length - 1, 1) + 1
|
264 |
+
else:
|
265 |
+
choose_seq_num = int(length * random.uniform(self.training_density[0], self.training_density[1]) / 100)
|
266 |
+
|
267 |
+
if self.t_ctrl:
|
268 |
+
choose_seq = np.arange(0, choose_seq_num)
|
269 |
+
else:
|
270 |
+
choose_seq = np.random.choice(length, choose_seq_num, replace=False)
|
271 |
+
choose_seq.sort()
|
272 |
+
|
273 |
+
mask_seq = np.zeros((length, self.njoints, 3))
|
274 |
+
for cj in choose_joint:
|
275 |
+
mask_seq[choose_seq, cj] = 1
|
276 |
+
|
277 |
+
joints = (joints - self.raw_mean) / self.raw_std
|
278 |
+
joints = joints * mask_seq
|
279 |
+
return joints, mask_seq
|
280 |
+
|
281 |
+
def __getitem__(self, idx: int) -> tuple:
|
282 |
+
data = self.data_dict[self.name_list[idx]]
|
283 |
+
motion, m_length, text_list = data["motion"], data["length"], data["text"]
|
284 |
+
# Randomly select a caption
|
285 |
+
text_data = random.choice(text_list)
|
286 |
+
caption, tokens = text_data["caption"], text_data["tokens"]
|
287 |
+
|
288 |
+
if len(tokens) < self.max_text_len:
|
289 |
+
# pad with "unk"
|
290 |
+
tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
|
291 |
+
sent_len = len(tokens)
|
292 |
+
tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len)
|
293 |
+
else:
|
294 |
+
# crop
|
295 |
+
tokens = tokens[:self.max_text_len]
|
296 |
+
tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
|
297 |
+
sent_len = len(tokens)
|
298 |
+
pos_one_hots = []
|
299 |
+
word_embeddings = []
|
300 |
+
for token in tokens:
|
301 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
302 |
+
pos_one_hots.append(pos_oh[None, :])
|
303 |
+
word_embeddings.append(word_emb[None, :])
|
304 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
305 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
306 |
+
|
307 |
+
# Crop the motion length to a multiple of unit_length, and introduce small variations
|
308 |
+
if self.unit_length < 10:
|
309 |
+
coin2 = np.random.choice(["single", "single", "double"])
|
310 |
+
else:
|
311 |
+
coin2 = "single"
|
312 |
+
|
313 |
+
if coin2 == "double":
|
314 |
+
m_length = (m_length // self.unit_length - 1) * self.unit_length
|
315 |
+
elif coin2 == "single":
|
316 |
+
m_length = (m_length // self.unit_length) * self.unit_length
|
317 |
+
idx = random.randint(0, len(motion) - m_length)
|
318 |
+
motion = motion[idx:idx + m_length]
|
319 |
+
|
320 |
+
hint, hint_mask = None, None
|
321 |
+
if self.control_mode is not None:
|
322 |
+
joints = recover_from_ric(torch.from_numpy(motion).float(), self.njoints)
|
323 |
+
joints = joints.numpy()
|
324 |
+
if self.control_mode == 'train':
|
325 |
+
hint, hint_mask = self.random_mask_train(joints)
|
326 |
+
else:
|
327 |
+
hint, hint_mask = self.random_mask(joints)
|
328 |
+
|
329 |
+
if self.padding_to_max:
|
330 |
+
padding = np.zeros((self.max_motion_length - m_length, *hint.shape[1:]))
|
331 |
+
hint = np.concatenate([hint, padding], axis=0)
|
332 |
+
hint_mask = np.concatenate([hint_mask, padding], axis=0)
|
333 |
+
|
334 |
+
"Z Normalization"
|
335 |
+
motion = (motion - self.mean) / self.std
|
336 |
+
|
337 |
+
if self.padding_to_max:
|
338 |
+
padding = np.zeros((self.max_motion_length - m_length, motion.shape[1]))
|
339 |
+
motion = np.concatenate([motion, padding], axis=0)
|
340 |
+
|
341 |
+
return (word_embeddings,
|
342 |
+
pos_one_hots,
|
343 |
+
caption,
|
344 |
+
sent_len,
|
345 |
+
motion,
|
346 |
+
m_length,
|
347 |
+
"_".join(tokens),
|
348 |
+
(hint, hint_mask))
|
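MotionDataset above indexes overlapping windows through a cumulative-sum table. A self-contained illustration of how a flat snippet index is mapped back to a motion id and a start offset, mirroring __getitem__:

import numpy as np

lengths = [5, 3, 8]                              # snippets available per motion
cumsum = np.cumsum([0] + lengths)                # array([ 0,  5,  8, 16])
item = 6                                         # flat snippet index (non-zero case)
motion_id = np.searchsorted(cumsum, item) - 1    # 1 -> the second motion
offset = item - cumsum[motion_id] - 1            # 0 -> its first window
print(motion_id, offset)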
mld/data/humanml/scripts/motion_process.py
ADDED
@@ -0,0 +1,51 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..common.quaternion import qinv, qrot
|
4 |
+
|
5 |
+
|
6 |
+
# Recover global angle and positions for rotation dataset
|
7 |
+
# root_rot_velocity (B, seq_len, 1)
|
8 |
+
# root_linear_velocity (B, seq_len, 2)
|
9 |
+
# root_y (B, seq_len, 1)
|
10 |
+
# ric_data (B, seq_len, (joint_num - 1)*3)
|
11 |
+
# rot_data (B, seq_len, (joint_num - 1)*6)
|
12 |
+
# local_velocity (B, seq_len, joint_num*3)
|
13 |
+
# foot contact (B, seq_len, 4)
|
14 |
+
def recover_root_rot_pos(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
15 |
+
rot_vel = data[..., 0]
|
16 |
+
r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
|
17 |
+
'''Get Y-axis rotation from rotation velocity'''
|
18 |
+
r_rot_ang[..., 1:] = rot_vel[..., :-1]
|
19 |
+
r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
|
20 |
+
|
21 |
+
r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
|
22 |
+
r_rot_quat[..., 0] = torch.cos(r_rot_ang)
|
23 |
+
r_rot_quat[..., 2] = torch.sin(r_rot_ang)
|
24 |
+
|
25 |
+
r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
|
26 |
+
r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
|
27 |
+
'''Add Y-axis rotation to root position'''
|
28 |
+
r_pos = qrot(qinv(r_rot_quat), r_pos)
|
29 |
+
|
30 |
+
r_pos = torch.cumsum(r_pos, dim=-2)
|
31 |
+
|
32 |
+
r_pos[..., 1] = data[..., 3]
|
33 |
+
return r_rot_quat, r_pos
|
34 |
+
|
35 |
+
|
36 |
+
def recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor:
|
37 |
+
r_rot_quat, r_pos = recover_root_rot_pos(data)
|
38 |
+
positions = data[..., 4:(joints_num - 1) * 3 + 4]
|
39 |
+
positions = positions.view(positions.shape[:-1] + (-1, 3))
|
40 |
+
|
41 |
+
'''Add Y-axis rotation to local joints'''
|
42 |
+
positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
|
43 |
+
|
44 |
+
'''Add root XZ to joints'''
|
45 |
+
positions[..., 0] += r_pos[..., 0:1]
|
46 |
+
positions[..., 2] += r_pos[..., 2:3]
|
47 |
+
|
48 |
+
'''Concat root and joints'''
|
49 |
+
positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
|
50 |
+
|
51 |
+
return positions
|
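A shape-level sketch of recover_from_ric on HumanML3D features (random values, only the shapes matter): 263 features per frame decode to 22 joints in 3D.

import torch

from mld.data.humanml.scripts.motion_process import recover_from_ric

feats = torch.randn(196, 263)           # (seq_len, nfeats) for HumanML3D
joints = recover_from_ric(feats, 22)    # joints_num = 22
print(joints.shape)                     # torch.Size([196, 22, 3])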
mld/data/humanml/utils/__init__.py
ADDED
File without changes
|
mld/data/humanml/utils/paramUtil.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
# Define a kinematic tree for the skeletal structure
|
4 |
+
kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
|
5 |
+
|
6 |
+
kit_raw_offsets = np.array(
|
7 |
+
[
|
8 |
+
[0, 0, 0],
|
9 |
+
[0, 1, 0],
|
10 |
+
[0, 1, 0],
|
11 |
+
[0, 1, 0],
|
12 |
+
[0, 1, 0],
|
13 |
+
[1, 0, 0],
|
14 |
+
[0, -1, 0],
|
15 |
+
[0, -1, 0],
|
16 |
+
[-1, 0, 0],
|
17 |
+
[0, -1, 0],
|
18 |
+
[0, -1, 0],
|
19 |
+
[1, 0, 0],
|
20 |
+
[0, -1, 0],
|
21 |
+
[0, -1, 0],
|
22 |
+
[0, 0, 1],
|
23 |
+
[0, 0, 1],
|
24 |
+
[-1, 0, 0],
|
25 |
+
[0, -1, 0],
|
26 |
+
[0, -1, 0],
|
27 |
+
[0, 0, 1],
|
28 |
+
[0, 0, 1]
|
29 |
+
]
|
30 |
+
)
|
31 |
+
|
32 |
+
t2m_raw_offsets = np.array([[0, 0, 0],
|
33 |
+
[1, 0, 0],
|
34 |
+
[-1, 0, 0],
|
35 |
+
[0, 1, 0],
|
36 |
+
[0, -1, 0],
|
37 |
+
[0, -1, 0],
|
38 |
+
[0, 1, 0],
|
39 |
+
[0, -1, 0],
|
40 |
+
[0, -1, 0],
|
41 |
+
[0, 1, 0],
|
42 |
+
[0, 0, 1],
|
43 |
+
[0, 0, 1],
|
44 |
+
[0, 1, 0],
|
45 |
+
[1, 0, 0],
|
46 |
+
[-1, 0, 0],
|
47 |
+
[0, 0, 1],
|
48 |
+
[0, -1, 0],
|
49 |
+
[0, -1, 0],
|
50 |
+
[0, -1, 0],
|
51 |
+
[0, -1, 0],
|
52 |
+
[0, -1, 0],
|
53 |
+
[0, -1, 0]])
|
54 |
+
|
55 |
+
t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
|
56 |
+
[9, 13, 16, 18, 20]]
|
57 |
+
t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
|
58 |
+
t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
|
59 |
+
|
60 |
+
kit_tgt_skel_id = '03950'
|
61 |
+
|
62 |
+
t2m_tgt_skel_id = '000021'
|
mld/data/humanml/utils/plot_script.py
ADDED
@@ -0,0 +1,98 @@
1 |
+
from textwrap import wrap
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import mpl_toolkits.mplot3d.axes3d as p3
|
8 |
+
from matplotlib.animation import FuncAnimation
|
9 |
+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
10 |
+
|
11 |
+
import mld.data.humanml.utils.paramUtil as paramUtil
|
12 |
+
|
13 |
+
skeleton = paramUtil.t2m_kinematic_chain
|
14 |
+
|
15 |
+
|
16 |
+
def plot_3d_motion(save_path: str, joints: np.ndarray, title: str,
|
17 |
+
figsize: tuple[int, int] = (3, 3),
|
18 |
+
fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton,
|
19 |
+
hint: Optional[np.ndarray] = None) -> None:
|
20 |
+
|
21 |
+
title = '\n'.join(wrap(title, 20))
|
22 |
+
|
23 |
+
def init():
|
24 |
+
ax.set_xlim3d([-radius / 2, radius / 2])
|
25 |
+
ax.set_ylim3d([0, radius])
|
26 |
+
ax.set_zlim3d([-radius / 3., radius * 2 / 3.])
|
27 |
+
fig.suptitle(title, fontsize=10)
|
28 |
+
ax.grid(b=False)
|
29 |
+
|
30 |
+
def plot_xzPlane(minx, maxx, miny, minz, maxz):
|
31 |
+
# Plot a plane XZ
|
32 |
+
verts = [
|
33 |
+
[minx, miny, minz],
|
34 |
+
[minx, miny, maxz],
|
35 |
+
[maxx, miny, maxz],
|
36 |
+
[maxx, miny, minz]
|
37 |
+
]
|
38 |
+
xz_plane = Poly3DCollection([verts])
|
39 |
+
xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
|
40 |
+
ax.add_collection3d(xz_plane)
|
41 |
+
|
42 |
+
# (seq_len, joints_num, 3)
|
43 |
+
data = joints.copy().reshape(len(joints), -1, 3)
|
44 |
+
|
45 |
+
data *= 1.3 # scale for visualization
|
46 |
+
if hint is not None:
|
47 |
+
mask = hint.sum(-1) != 0
|
48 |
+
hint = hint[mask]
|
49 |
+
hint *= 1.3
|
50 |
+
|
51 |
+
fig = plt.figure(figsize=figsize)
|
52 |
+
plt.tight_layout()
|
53 |
+
ax = p3.Axes3D(fig)
|
54 |
+
init()
|
55 |
+
MINS = data.min(axis=0).min(axis=0)
|
56 |
+
MAXS = data.max(axis=0).max(axis=0)
|
57 |
+
colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00",
|
58 |
+
"#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00",
|
59 |
+
"#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ]
|
60 |
+
|
61 |
+
frame_number = data.shape[0]
|
62 |
+
|
63 |
+
height_offset = MINS[1]
|
64 |
+
data[:, :, 1] -= height_offset
|
65 |
+
if hint is not None:
|
66 |
+
hint[..., 1] -= height_offset
|
67 |
+
trajec = data[:, 0, [0, 2]]
|
68 |
+
|
69 |
+
data[..., 0] -= data[:, 0:1, 0]
|
70 |
+
data[..., 2] -= data[:, 0:1, 2]
|
71 |
+
|
72 |
+
def update(index):
|
73 |
+
ax.lines = []
|
74 |
+
ax.collections = []
|
75 |
+
ax.view_init(elev=120, azim=-90)
|
76 |
+
ax.dist = 7.5
|
77 |
+
plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
|
78 |
+
MAXS[2] - trajec[index, 1])
|
79 |
+
|
80 |
+
if hint is not None:
|
81 |
+
ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1], hint[..., 2] - trajec[index, 1], color="#80B79A")
|
82 |
+
|
83 |
+
for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
|
84 |
+
if i < 5:
|
85 |
+
linewidth = 4.0
|
86 |
+
else:
|
87 |
+
linewidth = 2.0
|
88 |
+
ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
|
89 |
+
color=color)
|
90 |
+
|
91 |
+
plt.axis('off')
|
92 |
+
ax.set_xticklabels([])
|
93 |
+
ax.set_yticklabels([])
|
94 |
+
ax.set_zticklabels([])
|
95 |
+
|
96 |
+
ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
|
97 |
+
ani.save(save_path, fps=fps)
|
98 |
+
plt.close()
|
mld/data/humanml/utils/word_vectorizer.py
ADDED
@@ -0,0 +1,82 @@
1 |
+
import pickle
|
2 |
+
from os.path import join as pjoin
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
POS_enumerator = {
|
8 |
+
'VERB': 0,
|
9 |
+
'NOUN': 1,
|
10 |
+
'DET': 2,
|
11 |
+
'ADP': 3,
|
12 |
+
'NUM': 4,
|
13 |
+
'AUX': 5,
|
14 |
+
'PRON': 6,
|
15 |
+
'ADJ': 7,
|
16 |
+
'ADV': 8,
|
17 |
+
'Loc_VIP': 9,
|
18 |
+
'Body_VIP': 10,
|
19 |
+
'Obj_VIP': 11,
|
20 |
+
'Act_VIP': 12,
|
21 |
+
'Desc_VIP': 13,
|
22 |
+
'OTHER': 14
|
23 |
+
}
|
24 |
+
|
25 |
+
Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
|
26 |
+
'up', 'down', 'straight', 'curve')
|
27 |
+
|
28 |
+
Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
|
29 |
+
|
30 |
+
Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
|
31 |
+
|
32 |
+
Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
|
33 |
+
'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
|
34 |
+
'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
|
35 |
+
|
36 |
+
Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
|
37 |
+
'angrily', 'sadly')
|
38 |
+
|
39 |
+
VIP_dict = {
|
40 |
+
'Loc_VIP': Loc_list,
|
41 |
+
'Body_VIP': Body_list,
|
42 |
+
'Obj_VIP': Obj_List,
|
43 |
+
'Act_VIP': Act_list,
|
44 |
+
'Desc_VIP': Desc_list,
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
class WordVectorizer(object):
|
49 |
+
def __init__(self, meta_root: str, prefix: str) -> None:
|
50 |
+
vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix))
|
51 |
+
words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb'))
|
52 |
+
word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb'))
|
53 |
+
self.word2vec = {w: vectors[word2idx[w]] for w in words}
|
54 |
+
|
55 |
+
def _get_pos_ohot(self, pos: str) -> np.ndarray:
|
56 |
+
pos_vec = np.zeros(len(POS_enumerator))
|
57 |
+
if pos in POS_enumerator:
|
58 |
+
pos_vec[POS_enumerator[pos]] = 1
|
59 |
+
else:
|
60 |
+
pos_vec[POS_enumerator['OTHER']] = 1
|
61 |
+
return pos_vec
|
62 |
+
|
63 |
+
def __len__(self) -> int:
|
64 |
+
return len(self.word2vec)
|
65 |
+
|
66 |
+
def __getitem__(self, item: str) -> tuple:
|
67 |
+
word, pos = item.split('/')
|
68 |
+
if word in self.word2vec:
|
69 |
+
word_vec = self.word2vec[word]
|
70 |
+
vip_pos = None
|
71 |
+
for key, values in VIP_dict.items():
|
72 |
+
if word in values:
|
73 |
+
vip_pos = key
|
74 |
+
break
|
75 |
+
if vip_pos is not None:
|
76 |
+
pos_vec = self._get_pos_ohot(vip_pos)
|
77 |
+
else:
|
78 |
+
pos_vec = self._get_pos_ohot(pos)
|
79 |
+
else:
|
80 |
+
word_vec = self.word2vec['unk']
|
81 |
+
pos_vec = self._get_pos_ohot('OTHER')
|
82 |
+
return word_vec, pos_vec
|
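A usage sketch for WordVectorizer, assuming the GloVe files referenced by DATASET.WORD_VERTILIZER_PATH have been downloaded (the path below is a placeholder). Tokens are "word/POS" strings, and VIP words such as "walk" are mapped to their VIP slot instead of the raw POS tag.

from mld.data.humanml.utils.word_vectorizer import WordVectorizer

w_vectorizer = WordVectorizer("deps/glove", "our_vab")
word_emb, pos_oh = w_vectorizer["walk/VERB"]
print(word_emb.shape)    # the GloVe embedding dimension
print(pos_oh.argmax())   # 12 -> the Act_VIP slot, since "walk" is in Act_list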
mld/data/utils.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
from mld.utils.temos_utils import lengths_to_mask
|
4 |
+
|
5 |
+
|
6 |
+
def collate_tensors(batch: list) -> torch.Tensor:
|
7 |
+
dims = batch[0].dim()
|
8 |
+
max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
|
9 |
+
size = (len(batch), ) + tuple(max_size)
|
10 |
+
canvas = batch[0].new_zeros(size=size)
|
11 |
+
for i, b in enumerate(batch):
|
12 |
+
sub_tensor = canvas[i]
|
13 |
+
for d in range(dims):
|
14 |
+
sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
|
15 |
+
sub_tensor.add_(b)
|
16 |
+
return canvas
|
17 |
+
|
18 |
+
|
19 |
+
def mld_collate(batch: list) -> dict:
|
20 |
+
notnone_batches = [b for b in batch if b is not None]
|
21 |
+
notnone_batches.sort(key=lambda x: x[3], reverse=True)
|
22 |
+
adapted_batch = {
|
23 |
+
"motion":
|
24 |
+
collate_tensors([torch.tensor(b[4]).float() for b in notnone_batches]),
|
25 |
+
"text": [b[2] for b in notnone_batches],
|
26 |
+
"length": [b[5] for b in notnone_batches],
|
27 |
+
"word_embs":
|
28 |
+
collate_tensors([torch.tensor(b[0]).float() for b in notnone_batches]),
|
29 |
+
"pos_ohot":
|
30 |
+
collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]),
|
31 |
+
"text_len":
|
32 |
+
collate_tensors([torch.tensor(b[3]) for b in notnone_batches]),
|
33 |
+
"tokens": [b[6] for b in notnone_batches]
|
34 |
+
}
|
35 |
+
|
36 |
+
mask = lengths_to_mask(adapted_batch['length'], adapted_batch['motion'].device, adapted_batch['motion'].shape[1])
|
37 |
+
adapted_batch['mask'] = mask
|
38 |
+
|
39 |
+
# collate trajectory
|
40 |
+
if notnone_batches[0][-1][0] is not None:
|
41 |
+
adapted_batch['hint'] = collate_tensors([torch.tensor(b[-1][0]).float() for b in notnone_batches])
|
42 |
+
adapted_batch['hint_mask'] = collate_tensors([torch.tensor(b[-1][1]).float() for b in notnone_batches])
|
43 |
+
|
44 |
+
return adapted_batch
|
45 |
+
|
46 |
+
|
47 |
+
def mld_collate_motion_only(batch: list) -> dict:
|
48 |
+
batch = {
|
49 |
+
"motion": collate_tensors([torch.tensor(b[0]).float() for b in batch]),
|
50 |
+
"length": [b[1] for b in batch]
|
51 |
+
}
|
52 |
+
return batch
|
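A self-contained example of collate_tensors: variable-length tensors are zero-padded to the per-dimension maximum and stacked into one batch tensor.

import torch

from mld.data.utils import collate_tensors

a = torch.ones(3, 2)
b = torch.ones(5, 2)
batch = collate_tensors([a, b])
print(batch.shape)          # torch.Size([2, 5, 2])
print(batch[0, 3:].sum())   # tensor(0.) -> the shorter item is zero-padded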
mld/launch/__init__.py
ADDED
File without changes
|
mld/launch/blender.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
# Fix blender path
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
|
6 |
+
sys.path.append(os.path.expanduser("~/.local/lib/python3.9/site-packages"))
|
7 |
+
|
8 |
+
|
9 |
+
# Monkey patch argparse such that
|
10 |
+
# blender / python parsing works
|
11 |
+
def parse_args(self, args=None, namespace=None):
|
12 |
+
if args is not None:
|
13 |
+
return self.parse_args_bak(args=args, namespace=namespace)
|
14 |
+
try:
|
15 |
+
idx = sys.argv.index("--")
|
16 |
+
args = sys.argv[idx + 1:] # the list after '--'
|
17 |
+
except ValueError as e: # '--' not in the list:
|
18 |
+
args = []
|
19 |
+
return self.parse_args_bak(args=args, namespace=namespace)
|
20 |
+
|
21 |
+
|
22 |
+
setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args)
|
23 |
+
setattr(ArgumentParser, 'parse_args', parse_args)
|
mld/models/__init__.py
ADDED
File without changes
|
mld/models/architectures/__init__.py
ADDED
File without changes
|
mld/models/architectures/dno.py
ADDED
@@ -0,0 +1,79 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.tensorboard import SummaryWriter
|
6 |
+
|
7 |
+
|
8 |
+
class DNO(object):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
optimize: bool,
|
12 |
+
max_train_steps: int,
|
13 |
+
learning_rate: float,
|
14 |
+
lr_scheduler: str,
|
15 |
+
lr_warmup_steps: int,
|
16 |
+
clip_grad: bool,
|
17 |
+
loss_hint_type: str,
|
18 |
+
loss_diff_penalty: float,
|
19 |
+
loss_correlate_penalty: float,
|
20 |
+
visualize_samples: int,
|
21 |
+
visualize_ske_steps: list[int],
|
22 |
+
output_dir: str
|
23 |
+
) -> None:
|
24 |
+
|
25 |
+
self.optimize = optimize
|
26 |
+
self.max_train_steps = max_train_steps
|
27 |
+
self.learning_rate = learning_rate
|
28 |
+
self.lr_scheduler = lr_scheduler
|
29 |
+
self.lr_warmup_steps = lr_warmup_steps
|
30 |
+
self.clip_grad = clip_grad
|
31 |
+
self.loss_hint_type = loss_hint_type
|
32 |
+
self.loss_diff_penalty = loss_diff_penalty
|
33 |
+
self.loss_correlate_penalty = loss_correlate_penalty
|
34 |
+
|
35 |
+
if loss_hint_type == 'l1':
|
36 |
+
self.loss_hint_func = F.l1_loss
|
37 |
+
elif loss_hint_type == 'l1_smooth':
|
38 |
+
self.loss_hint_func = F.smooth_l1_loss
|
39 |
+
elif loss_hint_type == 'l2':
|
40 |
+
self.loss_hint_func = F.mse_loss
|
41 |
+
else:
|
42 |
+
raise ValueError(f'Invalid loss type: {loss_hint_type}')
|
43 |
+
|
44 |
+
self.visualize_samples = float('inf') if visualize_samples == 'inf' else visualize_samples
|
45 |
+
assert self.visualize_samples >= 0
|
46 |
+
self.visualize_samples_done = 0
|
47 |
+
self.visualize_ske_steps = visualize_ske_steps
|
48 |
+
if len(visualize_ske_steps) > 0:
|
49 |
+
self.vis_dir = os.path.join(output_dir, 'vis_optimize')
|
50 |
+
os.makedirs(self.vis_dir)
|
51 |
+
|
52 |
+
self.writer = None
|
53 |
+
self.output_dir = output_dir
|
54 |
+
if self.visualize_samples > 0:
|
55 |
+
self.writer = SummaryWriter(output_dir)
|
56 |
+
|
57 |
+
@property
|
58 |
+
def do_visualize(self):
|
59 |
+
return self.visualize_samples_done < self.visualize_samples
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def noise_regularize_1d(noise: torch.Tensor, stop_at: int = 2, dim: int = 1) -> torch.Tensor:
|
63 |
+
size = noise.shape[dim]
|
64 |
+
if size & (size - 1) != 0:
|
65 |
+
new_size = 2 ** (size - 1).bit_length()
|
66 |
+
pad = new_size - size
|
67 |
+
pad_shape = list(noise.shape)
|
68 |
+
pad_shape[dim] = pad
|
69 |
+
pad_noise = torch.randn(*pad_shape, device=noise.device)
|
70 |
+
noise = torch.cat([noise, pad_noise], dim=dim)
|
71 |
+
size = noise.shape[dim]
|
72 |
+
|
73 |
+
loss = torch.zeros(noise.shape[0], device=noise.device)
|
74 |
+
while size > stop_at:
|
75 |
+
rolled_noise = torch.roll(noise, shifts=1, dims=dim)
|
76 |
+
loss += (noise * rolled_noise).mean(dim=tuple(range(1, noise.ndim))).pow(2)
|
77 |
+
noise = noise.view(*noise.shape[:dim], size // 2, 2, *noise.shape[dim + 1:]).mean(dim=dim + 1)
|
78 |
+
size //= 2
|
79 |
+
return loss
|
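The staticmethod noise_regularize_1d penalizes autocorrelation of the latent noise along the time dimension at multiple scales. A small sketch showing that i.i.d. Gaussian noise scores near zero while a fully correlated signal scores much higher:

import torch

from mld.models.architectures.dno import DNO

iid = torch.randn(4, 64, 32)                   # (batch, time, channels)
const = torch.ones(4, 64, 32)
print(DNO.noise_regularize_1d(iid).mean())     # close to 0
print(DNO.noise_regularize_1d(const).mean())   # noticeably larger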
mld/models/architectures/mld_clip.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import AutoModel, AutoTokenizer
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
|
7 |
+
|
8 |
+
class MldTextEncoder(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None:
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
if 't5' in modelpath:
|
14 |
+
self.text_model = SentenceTransformer(modelpath)
|
15 |
+
self.tokenizer = self.text_model.tokenizer
|
16 |
+
else:
|
17 |
+
self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
|
18 |
+
self.text_model = AutoModel.from_pretrained(modelpath)
|
19 |
+
|
20 |
+
self.max_length = self.tokenizer.model_max_length
|
21 |
+
if "clip" in modelpath:
|
22 |
+
self.text_encoded_dim = self.text_model.config.text_config.hidden_size
|
23 |
+
if last_hidden_state:
|
24 |
+
self.name = "clip_hidden"
|
25 |
+
else:
|
26 |
+
self.name = "clip"
|
27 |
+
elif "bert" in modelpath:
|
28 |
+
self.name = "bert"
|
29 |
+
self.text_encoded_dim = self.text_model.config.hidden_size
|
30 |
+
elif 't5' in modelpath:
|
31 |
+
self.name = 't5'
|
32 |
+
else:
|
33 |
+
raise ValueError(f"Model {modelpath} not supported")
|
34 |
+
|
35 |
+
def forward(self, texts: list[str]) -> torch.Tensor:
|
36 |
+
# get prompt text embeddings
|
37 |
+
if self.name in ["clip", "clip_hidden"]:
|
38 |
+
text_inputs = self.tokenizer(
|
39 |
+
texts,
|
40 |
+
padding="max_length",
|
41 |
+
truncation=True,
|
42 |
+
max_length=self.max_length,
|
43 |
+
return_tensors="pt",
|
44 |
+
)
|
45 |
+
text_input_ids = text_inputs.input_ids
|
46 |
+
# truncate to the maximum length CLIP can handle
|
47 |
+
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
48 |
+
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
|
49 |
+
elif self.name == "bert":
|
50 |
+
text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
|
51 |
+
|
52 |
+
if self.name == "clip":
|
53 |
+
# (batch_Size, text_encoded_dim)
|
54 |
+
text_embeddings = self.text_model.get_text_features(
|
55 |
+
text_input_ids.to(self.text_model.device))
|
56 |
+
# (batch_Size, 1, text_encoded_dim)
|
57 |
+
text_embeddings = text_embeddings.unsqueeze(1)
|
58 |
+
elif self.name == "clip_hidden":
|
59 |
+
# (batch_Size, seq_length , text_encoded_dim)
|
60 |
+
text_embeddings = self.text_model.text_model(
|
61 |
+
text_input_ids.to(self.text_model.device)).last_hidden_state
|
62 |
+
elif self.name == "bert":
|
63 |
+
# (batch_Size, seq_length , text_encoded_dim)
|
64 |
+
text_embeddings = self.text_model(
|
65 |
+
**text_inputs.to(self.text_model.device)).last_hidden_state
|
66 |
+
elif self.name == 't5':
|
67 |
+
text_embeddings = self.text_model.encode(texts, show_progress_bar=False, convert_to_tensor=True, batch_size=len(texts))
|
68 |
+
text_embeddings = text_embeddings.unsqueeze(1)
|
69 |
+
else:
|
70 |
+
raise NotImplementedError(f"Model {self.name} not implemented")
|
71 |
+
|
72 |
+
return text_embeddings
|
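A usage sketch for MldTextEncoder; the CLIP checkpoint name below is a placeholder (the actual model path comes from the text_encoder config). Pooled CLIP features are returned as one embedding per prompt.

from mld.models.architectures.mld_clip import MldTextEncoder

encoder = MldTextEncoder("openai/clip-vit-large-patch14")
emb = encoder(["a person walks forward and waves"])
print(emb.shape)    # (1, 1, 768): one pooled CLIP feature per prompt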
mld/models/architectures/mld_denoiser.py
ADDED
@@ -0,0 +1,200 @@
from typing import Optional, Union

import torch
import torch.nn as nn

from mld.models.operator.embeddings import TimestepEmbedding, Timesteps
from mld.models.operator.attention import (SkipTransformerEncoder,
                                           SkipTransformerDecoder,
                                           TransformerDecoder,
                                           TransformerDecoderLayer,
                                           TransformerEncoder,
                                           TransformerEncoderLayer)
from mld.models.operator.moe import MoeTransformerEncoderLayer, MoeTransformerDecoderLayer
from mld.models.operator.utils import get_clones, get_activation_fn, zero_module
from mld.models.operator.position_encoding import build_position_encoding


def load_balancing_loss_func(router_logits: tuple, num_experts: int = 4, topk: int = 2):
    router_logits = torch.cat(router_logits, dim=0)
    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
    router_prob_per_expert = torch.mean(routing_weights, dim=0)
    overall_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss


class MldDenoiser(nn.Module):

    def __init__(self,
                 latent_dim: list = [1, 256],
                 hidden_dim: Optional[int] = None,
                 text_dim: int = 768,
                 time_dim: int = 768,
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 norm_eps: float = 1e-5,
                 activation: str = "gelu",
                 norm_post: bool = True,
                 activation_post: Optional[str] = None,
                 flip_sin_to_cos: bool = True,
                 freq_shift: float = 0,
                 time_act_fn: str = 'silu',
                 time_post_act_fn: Optional[str] = None,
                 position_embedding: str = "learned",
                 arch: str = "trans_enc",
                 add_mem_pos: bool = True,
                 force_pre_post_proj: bool = False,
                 text_act_fn: str = 'relu',
                 time_cond_proj_dim: Optional[int] = None,
                 zero_init_cond: bool = True,
                 is_controlnet: bool = False,
                 controlnet_embed_dim: Optional[int] = None,
                 controlnet_act_fn: str = 'silu',
                 moe: bool = False,
                 moe_num_experts: int = 4,
                 moe_topk: int = 2,
                 moe_loss_weight: float = 1e-2,
                 moe_jitter_noise: Optional[float] = None
                 ) -> None:
        super(MldDenoiser, self).__init__()

        self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
        add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
        self.latent_pre = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity()
        self.latent_post = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity()

        self.arch = arch
        self.time_cond_proj_dim = time_cond_proj_dim

        self.moe_num_experts = moe_num_experts
        self.moe_topk = moe_topk
        self.moe_loss_weight = moe_loss_weight

        self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(time_dim, self.latent_dim, time_act_fn, post_act_fn=time_post_act_fn,
                                                cond_proj_dim=time_cond_proj_dim, zero_init_cond=zero_init_cond)
        self.emb_proj = nn.Sequential(get_activation_fn(text_act_fn), nn.Linear(text_dim, self.latent_dim))

        self.query_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
        if self.arch == "trans_enc":
            if moe:
                encoder_layer = MoeTransformerEncoderLayer(
                    self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
                    dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
            else:
                encoder_layer = TransformerEncoderLayer(
                    self.latent_dim, num_heads, ff_size, dropout,
                    activation, normalize_before, norm_eps)

            encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
            self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post,
                                                  is_controlnet=is_controlnet, is_moe=moe)

        elif self.arch == 'trans_dec':
            if add_mem_pos:
                self.mem_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
            else:
                self.mem_pos = None
            if moe:
                decoder_layer = MoeTransformerDecoderLayer(
                    self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
                    dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
            else:
                decoder_layer = TransformerDecoderLayer(
                    self.latent_dim, num_heads, ff_size, dropout,
                    activation, normalize_before, norm_eps)

            decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
            self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post,
                                                  is_controlnet=is_controlnet, is_moe=moe)
        else:
            raise ValueError(f"Not supported architecture: {self.arch}!")

        self.is_controlnet = is_controlnet
        if self.is_controlnet:
            embed_dim = controlnet_embed_dim if controlnet_embed_dim is not None else self.latent_dim
            modules = [
                nn.Linear(latent_dim[-1], embed_dim),
                get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
                nn.Linear(embed_dim, embed_dim),
                get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
                zero_module(nn.Linear(embed_dim, latent_dim[-1]))
            ]
            self.controlnet_cond_embedding = nn.Sequential(*[m for m in modules if m is not None])

            self.controlnet_down_mid_blocks = nn.ModuleList([
                zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)])

    def forward(self,
                sample: torch.Tensor,
                timestep: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                timestep_cond: Optional[torch.Tensor] = None,
                controlnet_cond: Optional[torch.Tensor] = None,
                controlnet_residuals: Optional[list[torch.Tensor]] = None
                ) -> tuple:

        # 0. check if controlnet
        if self.is_controlnet:
            sample = sample + self.controlnet_cond_embedding(controlnet_cond)

        # 1. dimension matching (pre)
        sample = sample.permute(1, 0, 2)
        sample = self.latent_pre(sample)

        # 2. time embedding
        timesteps = timestep.expand(sample.shape[1]).clone()
        time_emb = self.time_proj(timesteps)
        time_emb = time_emb.to(dtype=sample.dtype)
        # [1, bs, latent_dim] <= [bs, latent_dim]
        time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0)

        # 3. condition + time embedding
        # text_emb [seq_len, batch_size, text_dim] <= [batch_size, seq_len, text_dim]
        encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2)
        # text embedding projection
        text_emb_latent = self.emb_proj(encoder_hidden_states)
        emb_latent = torch.cat((time_emb, text_emb_latent), 0)

        # 4. transformer
        if self.arch == "trans_enc":
            xseq = torch.cat((sample, emb_latent), axis=0)
            xseq = self.query_pos(xseq)
            tokens, intermediates, router_logits = self.encoder(xseq, controlnet_residuals=controlnet_residuals)
        elif self.arch == 'trans_dec':
            sample = self.query_pos(sample)
            if self.mem_pos:
                emb_latent = self.mem_pos(emb_latent)
            tokens, intermediates, router_logits = self.decoder(sample, emb_latent,
                                                                controlnet_residuals=controlnet_residuals)
        else:
            raise TypeError(f"{self.arch} is not supported")

        router_loss = None
        if router_logits is not None:
            router_loss = load_balancing_loss_func(router_logits, self.moe_num_experts, self.moe_topk)
            router_loss = self.moe_loss_weight * router_loss

        if self.is_controlnet:
            control_res_samples = []
            for res, block in zip(intermediates, self.controlnet_down_mid_blocks):
                r = block(res)
                control_res_samples.append(r)
            return control_res_samples, router_loss
        elif self.arch == "trans_enc":
            sample = tokens[:sample.shape[0]]
        elif self.arch == 'trans_dec':
            sample = tokens
        else:
            raise TypeError(f"{self.arch} is not supported")

        # 5. dimension matching (post)
        sample = self.latent_post(sample)
        sample = sample.permute(1, 0, 2)
        return sample, router_loss
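Usage sketch for MldDenoiser (not part of the diff above): a single denoising step on random tensors with the default trans_enc architecture. The batch size, timestep value, and text embedding shape are assumptions taken from the shape comments in forward(); with moe=False the returned router_loss is expected to be None.

# Hypothetical usage sketch: one denoising step on a random latent.
import torch

from mld.models.architectures.mld_denoiser import MldDenoiser

denoiser = MldDenoiser(latent_dim=[1, 256], text_dim=768, time_dim=768)

bs = 2
sample = torch.randn(bs, 1, 256)                 # [bs, latent_dim[0], latent_dim[1]]
timestep = torch.tensor([999])                   # expanded to the batch inside forward()
encoder_hidden_states = torch.randn(bs, 1, 768)  # pooled text embedding, [bs, 1, text_dim]

noise_pred, router_loss = denoiser(sample, timestep, encoder_hidden_states)
print(noise_pred.shape)  # [bs, 1, 256]
print(router_loss)       # None unless moe=True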
mld/models/architectures/mld_traj_encoder.py
ADDED
@@ -0,0 +1,64 @@
from typing import Optional

import torch
import torch.nn as nn

from mld.models.operator.attention import SkipTransformerEncoder, TransformerEncoderLayer
from mld.models.operator.position_encoding import build_position_encoding


class MldTrajEncoder(nn.Module):

    def __init__(self,
                 nfeats: int,
                 latent_dim: list = [1, 256],
                 hidden_dim: Optional[int] = None,
                 force_post_proj: bool = False,
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 norm_eps: float = 1e-5,
                 activation: str = "gelu",
                 norm_post: bool = True,
                 activation_post: Optional[str] = None,
                 position_embedding: str = "learned") -> None:
        super(MldTrajEncoder, self).__init__()

        self.latent_size = latent_dim[0]
        self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
        add_post_proj = force_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
        self.latent_proj = nn.Linear(self.latent_dim, latent_dim[-1]) if add_post_proj else nn.Identity()

        self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim)

        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        encoder_layer = TransformerEncoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
            norm_eps
        )
        encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
        self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)
        self.global_motion_token = nn.Parameter(torch.randn(self.latent_size, self.latent_dim))

    def forward(self, features: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        bs, nframes, nfeats = features.shape
        x = self.skel_embedding(features)
        x = x.permute(1, 0, 2)
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
        dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((dist_masks, mask), 1)
        xseq = torch.cat((dist, x), 0)
        xseq = self.query_pos_encoder(xseq)
        global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
        global_token = self.latent_proj(global_token)
        global_token = global_token.permute(1, 0, 2)
        return global_token
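Usage sketch for MldTrajEncoder (not part of the diff above): the joint count 22 and sequence length 196 are assumptions; the encoder expects flattened (x, y, z) coordinates per frame, i.e. nfeats * 3 features, and returns one global control token per sequence.

# Hypothetical usage sketch: encode a joint-position trajectory into a control token.
import torch

from mld.models.architectures.mld_traj_encoder import MldTrajEncoder

traj_encoder = MldTrajEncoder(nfeats=22, latent_dim=[1, 256])

bs, nframes = 2, 196
features = torch.randn(bs, nframes, 22 * 3)        # flattened (x, y, z) per joint
mask = torch.ones(bs, nframes, dtype=torch.bool)   # all frames valid

traj_token = traj_encoder(features, mask)
print(traj_token.shape)  # [bs, 1, 256], one global token per sequence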
mld/models/architectures/mld_vae.py
ADDED
@@ -0,0 +1,136 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.distributions.distribution import Distribution

from mld.models.operator.attention import (
    SkipTransformerEncoder,
    SkipTransformerDecoder,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer
)
from mld.models.operator.position_encoding import build_position_encoding


class MldVae(nn.Module):

    def __init__(self,
                 nfeats: int,
                 latent_dim: list = [1, 256],
                 hidden_dim: Optional[int] = None,
                 force_pre_post_proj: bool = False,
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 arch: str = "encoder_decoder",
                 normalize_before: bool = False,
                 norm_eps: float = 1e-5,
                 activation: str = "gelu",
                 norm_post: bool = True,
                 activation_post: Optional[str] = None,
                 position_embedding: str = "learned") -> None:
        super(MldVae, self).__init__()

        self.latent_size = latent_dim[0]
        self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
        add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
        self.latent_pre = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity()
        self.latent_post = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity()

        self.arch = arch

        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        encoder_layer = TransformerEncoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
            norm_eps
        )
        encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
        self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)

        if self.arch == "all_encoder":
            decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
            self.decoder = SkipTransformerEncoder(encoder_layer, num_layers, decoder_norm, activation_post)
        elif self.arch == 'encoder_decoder':
            self.query_pos_decoder = build_position_encoding(
                self.latent_dim, position_embedding=position_embedding)

            decoder_layer = TransformerDecoderLayer(
                self.latent_dim,
                num_heads,
                ff_size,
                dropout,
                activation,
                normalize_before,
                norm_eps
            )
            decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
            self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post)
        else:
            raise ValueError(f"Not support architecture: {self.arch}!")

        self.global_motion_token = nn.Parameter(torch.randn(self.latent_size * 2, self.latent_dim))
        self.skel_embedding = nn.Linear(nfeats, self.latent_dim)
        self.final_layer = nn.Linear(self.latent_dim, nfeats)

    def forward(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, Distribution]:
        z, dist = self.encode(features, mask)
        feats_rst = self.decode(z, mask)
        return feats_rst, z, dist

    def encode(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, Distribution]:
        bs, nframes, nfeats = features.shape
        x = self.skel_embedding(features)
        x = x.permute(1, 0, 2)
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
        dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((dist_masks, mask), 1)
        xseq = torch.cat((dist, x), 0)

        xseq = self.query_pos_encoder(xseq)
        dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
        dist = self.latent_pre(dist)

        mu = dist[0:self.latent_size, ...]
        logvar = dist[self.latent_size:, ...]

        std = logvar.exp().pow(0.5)
        dist = torch.distributions.Normal(mu, std)
        latent = dist.rsample()
        # [latent_dim[0], batch_size, latent_dim[1]] -> [batch_size, latent_dim[0], latent_dim[1]]
        latent = latent.permute(1, 0, 2)
        return latent, dist

    def decode(self, z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # [batch_size, latent_dim[0], latent_dim[1]] -> [latent_dim[0], batch_size, latent_dim[1]]
        z = self.latent_post(z)
        z = z.permute(1, 0, 2)
        bs, nframes = mask.shape
        queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)

        if self.arch == "all_encoder":
            xseq = torch.cat((z, queries), axis=0)
            z_mask = torch.ones((bs, self.latent_size), dtype=torch.bool, device=z.device)
            aug_mask = torch.cat((z_mask, mask), axis=1)
            xseq = self.query_pos_decoder(xseq)
            output = self.decoder(xseq, src_key_padding_mask=~aug_mask)[0][z.shape[0]:]
        elif self.arch == "encoder_decoder":
            queries = self.query_pos_decoder(queries)
            output = self.decoder(tgt=queries, memory=z, tgt_key_padding_mask=~mask)[0]
        else:
            raise ValueError(f"Not support architecture: {self.arch}!")

        output = self.final_layer(output)
        output[~mask.T] = 0
        feats = output.permute(1, 0, 2)
        return feats
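Usage sketch for MldVae (not part of the diff above): an encode/decode round trip on random motion features. The feature dimension 263 (HumanML3D-style features) and the sequence length are assumptions for illustration.

# Hypothetical usage sketch: round-trip a motion feature sequence through the VAE.
import torch

from mld.models.architectures.mld_vae import MldVae

vae = MldVae(nfeats=263, latent_dim=[1, 256], arch="encoder_decoder")

bs, nframes = 2, 196
features = torch.randn(bs, nframes, 263)
mask = torch.ones(bs, nframes, dtype=torch.bool)

latent, dist = vae.encode(features, mask)  # latent: [bs, 1, 256], dist: torch Normal
recons = vae.decode(latent, mask)          # recons: [bs, nframes, 263]
print(latent.shape, recons.shape)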