Elle McFarlane committed on
Commit a02a7e6 · 0 Parent(s)

add trainers

text2motion/trainers/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .ddpm_trainer import DDPMTrainer
+
+
+ __all__ = ['DDPMTrainer']
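With this package init in place, downstream code can import the trainer as "from trainers import DDPMTrainer" (assuming the text2motion directory is on the import path); __all__ keeps a wildcard import limited to the trainer class.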
text2motion/trainers/ddpm_trainer.py ADDED
@@ -0,0 +1,240 @@
+ import time
+ from collections import OrderedDict
+ from os.path import join as pjoin
+
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ import wandb  # needed below when opt.use_wandb is set
+ from mmcv.runner import get_dist_info
+ from torch.nn.utils import clip_grad_norm_
+
+ from datasets import build_dataloader
+ from models.gaussian_diffusion import (GaussianDiffusion, LossType,
+                                        ModelMeanType, ModelVarType,
+                                        create_named_schedule_sampler,
+                                        get_named_beta_schedule)
+ from utils.utils import print_current_loss
+
+
+ class DDPMTrainer(object):
+
+     def __init__(self, args, encoder):
+         self.opt = args
+         self.device = args.device
+         self.encoder = encoder  # MotionTransformer from train.build_models
+         self.diffusion_steps = args.diffusion_steps
+         sampler = 'uniform'
+         beta_scheduler = 'linear'
+         betas = get_named_beta_schedule(beta_scheduler, self.diffusion_steps)
+         self.diffusion = GaussianDiffusion(
+             betas=betas,
+             model_mean_type=ModelMeanType.EPSILON,
+             model_var_type=ModelVarType.FIXED_SMALL,
+             loss_type=LossType.MSE
+         )
+         self.sampler = create_named_schedule_sampler(sampler, self.diffusion)
+         self.sampler_name = sampler
+
+         if args.is_train:
+             self.mse_criterion = torch.nn.MSELoss(reduction='none')
+         self.to(self.device)
+
+     @staticmethod
+     def zero_grad(opt_list):
+         for opt in opt_list:
+             opt.zero_grad()
+
+     @staticmethod
+     def clip_norm(network_list):
+         for network in network_list:
+             clip_grad_norm_(network.parameters(), 0.5)
+
+     @staticmethod
+     def step(opt_list):
+         for opt in opt_list:
+             opt.step()
+
+     def forward(self, batch_data, eval_mode=False):
+         caption, motions, m_lens = batch_data
+         motions = motions.detach().to(self.device).float()
+
+         self.caption = caption
+         self.motions = motions
+         x_start = motions
+         B, T = x_start.shape[:2]
+         cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
+         t, _ = self.sampler.sample(B, x_start.device)
+         output = self.diffusion.training_losses(
+             model=self.encoder,  # the MotionTransformer acts as the denoising model
+             x_start=x_start,
+             t=t,
+             model_kwargs={"text": caption, "length": cur_len}
+         )
+         self.real_noise = output['target']
+         self.fake_noise = output['pred']
+         # under DistributedDataParallel the model is wrapped in .module
+         try:
+             self.src_mask = self.encoder.module.generate_src_mask(T, cur_len).to(x_start.device)
+         except AttributeError:
+             self.src_mask = self.encoder.generate_src_mask(T, cur_len).to(x_start.device)
+
+     def generate_batch(self, caption, m_lens, dim_pose):
+         # xf_proj is explained in https://github.com/mingyuan-zhang/MotionDiffuse/issues/10:
+         # it is an overall semantic feature representing the given language
+         # description. A common choice in NLP, motion generation, and GLIDE is
+         # to use the last token's feature to represent the whole sequence.
+         xf_proj, xf_out = self.encoder.encode_text(caption, self.device)
+
+         B = len(caption)
+         T = min(m_lens.max(), self.encoder.num_frames)
+         output = self.diffusion.p_sample_loop(
+             self.encoder,
+             (B, T, dim_pose),
+             clip_denoised=False,
+             progress=True,
+             model_kwargs={
+                 'xf_proj': xf_proj,
+                 'xf_out': xf_out,
+                 'length': m_lens
+             })
+         return output
+
+     def generate(self, caption, m_lens, dim_pose, batch_size=1024):
+         N = len(caption)
+         cur_idx = 0
+         self.encoder.eval()
+         all_output = []
+         while cur_idx < N:
+             if cur_idx + batch_size >= N:
+                 batch_caption = caption[cur_idx:]
+                 batch_m_lens = m_lens[cur_idx:]
+             else:
+                 batch_caption = caption[cur_idx: cur_idx + batch_size]
+                 batch_m_lens = m_lens[cur_idx: cur_idx + batch_size]
+             output = self.generate_batch(batch_caption, batch_m_lens, dim_pose)
+             B = output.shape[0]
+
+             for i in range(B):
+                 all_output.append(output[i])
+             cur_idx += batch_size
+         return all_output
+
+     def backward_G(self):
+         # per-frame MSE between predicted and target noise, averaged over the
+         # feature dimension, then mean-pooled over valid (unpadded) frames only
+         loss_mot_rec = self.mse_criterion(self.fake_noise, self.real_noise).mean(dim=-1)
+         loss_mot_rec = (loss_mot_rec * self.src_mask).sum() / self.src_mask.sum()
+         self.loss_mot_rec = loss_mot_rec
+         loss_logs = OrderedDict({})
+         loss_logs['loss_mot_rec'] = self.loss_mot_rec.item()
+         return loss_logs
+
+     def update(self):
+         self.zero_grad([self.opt_encoder])
+         loss_logs = self.backward_G()
+         self.loss_mot_rec.backward()
+         self.clip_norm([self.encoder])
+         self.step([self.opt_encoder])
+
+         return loss_logs
+
+     def to(self, device):
+         if self.opt.is_train:
+             self.mse_criterion.to(device)
+         self.encoder = self.encoder.to(device)
+
+     def train_mode(self):
+         self.encoder.train()
+
+     def eval_mode(self):
+         self.encoder.eval()
+
+     def save(self, file_name, ep, total_it):
+         state = {
+             'opt_encoder': self.opt_encoder.state_dict(),
+             'ep': ep,
+             'total_it': total_it
+         }
+         # unwrap DistributedDataParallel if present so checkpoints load cleanly
+         try:
+             state['encoder'] = self.encoder.module.state_dict()
+         except AttributeError:
+             state['encoder'] = self.encoder.state_dict()
+         torch.save(state, file_name)
+
+     def load(self, model_dir):
+         print(f'{self.__class__.__name__} loading model {model_dir}')
+         checkpoint = torch.load(model_dir, map_location=self.device)
+         if self.opt.is_train:
+             self.opt_encoder.load_state_dict(checkpoint['opt_encoder'])
+         self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
+         return checkpoint['ep'], checkpoint.get('total_it', 0)
+
+     def train(self, train_dataset):
+         rank, world_size = get_dist_info()
+         self.to(self.device)
+         self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
+         it = 0
+         cur_epoch = 0
+         if self.opt.is_continue:
+             model_dir = pjoin(self.opt.model_dir, f'{self.opt.model_name}.tar')
+             cur_epoch, it = self.load(model_dir)
+
+         start_time = time.time()
+
+         train_loader = build_dataloader(
+             train_dataset,
+             samples_per_gpu=self.opt.batch_size,
+             drop_last=True,
+             workers_per_gpu=4,
+             shuffle=True,
+             dist=self.opt.distributed,
+             num_gpus=len(self.opt.gpu_id))
+
+         logs = OrderedDict()
+         for epoch in range(cur_epoch, self.opt.num_epochs):
+             print(f"epoch {epoch}, logging to wandb every {self.opt.log_every} iters")
+             self.train_mode()
+             for i, batch_data in enumerate(train_loader):
+                 print(f"epoch {epoch}, batch {i}")
+                 self.forward(batch_data)
+                 log_dict = self.update()
+                 for k, v in log_dict.items():
+                     if k not in logs:
+                         logs[k] = v
+                     else:
+                         logs[k] += v
+                 it += 1
+                 if it % self.opt.log_every == 0 and rank == 0:
+                     mean_loss = OrderedDict({})
+                     for tag, value in logs.items():
+                         mean_loss[tag] = value / self.opt.log_every
+                     logs = OrderedDict()
+                     print_current_loss(start_time, it, mean_loss, epoch, inner_iter=i)
+                     if self.opt.use_wandb:
+                         print(f"logging loss w wandb {mean_loss['loss_mot_rec']:.4f}")
+                         perf_dict = {
+                             'loss_mot_rec': mean_loss['loss_mot_rec']
+                         }
+                         wandb.log(perf_dict)
+                 # TODO (elmc): evaluate!
+                 # if it % self.opt.eval_every_e == 0 and rank == 0:
+                 #     self.eval_mode()
+                 #     noise_path = f"{self.opt.noise_dir}/{epoch}_{i}.npy"
+                 #     np.save(noise_path, self.real_noise.cpu().numpy())
+                 if it % self.opt.save_latest == 0 and rank == 0:
+                     self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
+
+             if rank == 0:
+                 self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
+
+             if epoch % self.opt.save_every_e == 0 and rank == 0:
+                 self.save(pjoin(self.opt.model_dir, 'ckpt_e%03d.tar' % epoch),
+                           epoch, total_it=it)
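
For orientation, a minimal sketch of how the trainer added in this commit might be driven for inference. The fields on args mirror what the class reads above (device, diffusion_steps, is_train); build_motion_transformer is a hypothetical stand-in for however the repo's train.build_models constructs the MotionTransformer encoder, and dim_pose=263 assumes the HumanML3D pose representation. None of this is part of the commit itself.

# hypothetical driver; build_motion_transformer and the option values below
# are assumptions for illustration, not part of this commit
from argparse import Namespace

import torch

from trainers import DDPMTrainer

args = Namespace(
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    diffusion_steps=1000,  # length of the DDPM noise schedule
    is_train=False,        # inference only: skips optimizer/criterion setup
)

encoder = build_motion_transformer(args)  # hypothetical encoder factory
trainer = DDPMTrainer(args, encoder)
trainer.eval_mode()

captions = ['a person walks forward and waves']
m_lens = torch.LongTensor([60])  # target motion lengths, in frames
motions = trainer.generate(captions, m_lens, dim_pose=263)
print(motions[0].shape)  # e.g. torch.Size([60, 263]) for the first caption

generate batches the captions (up to batch_size per call), runs the full reverse diffusion chain via p_sample_loop for each batch, and returns a list of per-sample motion tensors.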