# MotionDiffuse — trainers/ddpm_trainer.py
import torch
import torch.nn.functional as F
import random
import time
from models.transformer import MotionTransformer
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from collections import OrderedDict
from utils.utils import print_current_loss
from os.path import join as pjoin
import codecs as cs
import torch.distributed as dist
from mmcv.runner import get_dist_info
from models.gaussian_diffusion import (
GaussianDiffusion,
get_named_beta_schedule,
create_named_schedule_sampler,
ModelMeanType,
ModelVarType,
LossType
)
from datasets import build_dataloader
class DDPMTrainer(object):
    """Trainer for a text-conditioned motion diffusion model (DDPM).

    Wraps a denoising network (``encoder``) with a Gaussian diffusion
    process: during training the network regresses the injected noise
    (epsilon prediction, ``ModelMeanType.EPSILON``); at inference time
    ``generate``/``generate_batch`` run the reverse diffusion loop to
    sample motion sequences from text prompts.
    """

    def __init__(self, args, encoder):
        """
        Args:
            args: option namespace; must expose ``device``,
                ``diffusion_steps`` and ``is_train`` (plus the options read
                by ``train()`` when training: lr, batch_size, num_epochs,
                log_every, save_latest, save_every_e, model_dir,
                is_continue).
            encoder: denoising network; must provide ``encode_text``,
                ``generate_src_mask`` and ``num_frames`` (possibly behind a
                DataParallel/DDP ``.module`` wrapper).
        """
        self.opt = args
        self.device = args.device
        self.encoder = encoder
        self.diffusion_steps = args.diffusion_steps
        sampler = 'uniform'
        beta_scheduler = 'linear'
        betas = get_named_beta_schedule(beta_scheduler, self.diffusion_steps)
        self.diffusion = GaussianDiffusion(
            betas=betas,
            model_mean_type=ModelMeanType.EPSILON,    # network predicts the noise
            model_var_type=ModelVarType.FIXED_SMALL,
            loss_type=LossType.MSE
        )
        self.sampler = create_named_schedule_sampler(sampler, self.diffusion)
        self.sampler_name = sampler

        if args.is_train:
            # reduction='none' keeps per-element losses so that padded
            # frames can be masked out in backward_G().
            self.mse_criterion = torch.nn.MSELoss(reduction='none')
        self.to(self.device)

    @staticmethod
    def zero_grad(opt_list):
        """Zero the gradients of every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip each network's gradient norm to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Advance every optimizer in ``opt_list`` by one step."""
        for opt in opt_list:
            opt.step()

    def forward(self, batch_data, eval_mode=False):
        """Run one diffusion forward pass and stash the loss ingredients.

        Samples a timestep per sequence, computes the diffusion training
        losses, and stores ``real_noise`` (target), ``fake_noise``
        (prediction) and ``src_mask`` (valid-frame mask) on ``self`` for
        ``backward_G()``.

        Args:
            batch_data: tuple ``(caption, motions, m_lens)`` — text prompts,
                padded motion tensors of shape (B, T, D), and true lengths.
            eval_mode: unused; kept for interface compatibility.
        """
        caption, motions, m_lens = batch_data
        motions = motions.detach().to(self.device).float()

        self.caption = caption
        self.motions = motions
        x_start = motions
        B, T = x_start.shape[:2]
        # Clamp each sequence length to the padded length T.
        cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
        t, _ = self.sampler.sample(B, x_start.device)
        output = self.diffusion.training_losses(
            model=self.encoder,
            x_start=x_start,
            t=t,
            model_kwargs={"text": caption, "length": cur_len}
        )

        self.real_noise = output['target']
        self.fake_noise = output['pred']
        # DataParallel/DDP wraps the model behind `.module`; fall back to
        # the bare model otherwise. Catch AttributeError only — the
        # original bare `except:` would also swallow real bugs and
        # KeyboardInterrupt.
        try:
            self.src_mask = self.encoder.module.generate_src_mask(T, cur_len).to(x_start.device)
        except AttributeError:
            self.src_mask = self.encoder.generate_src_mask(T, cur_len).to(x_start.device)

    def generate_batch(self, caption, m_lens, dim_pose):
        """Sample one batch of motions from text via the reverse loop.

        Args:
            caption: list of B text prompts.
            m_lens: tensor of desired motion lengths, one per prompt.
            dim_pose: pose feature dimension D of the output.

        Returns:
            Tensor of shape (B, T, dim_pose) with T = min(max length,
            encoder.num_frames).
        """
        xf_proj, xf_out = self.encoder.encode_text(caption, self.device)
        B = len(caption)
        T = min(m_lens.max(), self.encoder.num_frames)
        output = self.diffusion.p_sample_loop(
            self.encoder,
            (B, T, dim_pose),
            clip_denoised=False,
            progress=True,
            model_kwargs={
                'xf_proj': xf_proj,
                'xf_out': xf_out,
                'length': m_lens
            })
        return output

    def generate(self, caption, m_lens, dim_pose, batch_size=1024):
        """Sample motions for all prompts, chunked into batches.

        Runs under ``torch.no_grad()``: sampling needs no autograd graph,
        and tracking it wastes memory over the full reverse loop.

        Returns:
            List of N motion tensors (one per prompt).
        """
        N = len(caption)
        cur_idx = 0
        self.encoder.eval()
        all_output = []
        with torch.no_grad():
            while cur_idx < N:
                if cur_idx + batch_size >= N:
                    batch_caption = caption[cur_idx:]
                    batch_m_lens = m_lens[cur_idx:]
                else:
                    batch_caption = caption[cur_idx: cur_idx + batch_size]
                    batch_m_lens = m_lens[cur_idx: cur_idx + batch_size]
                output = self.generate_batch(batch_caption, batch_m_lens, dim_pose)
                B = output.shape[0]
                for i in range(B):
                    all_output.append(output[i])
                cur_idx += batch_size
        return all_output

    def backward_G(self):
        """Compute the masked noise-reconstruction loss.

        Averages per-element MSE over the feature dimension, then over the
        valid (unpadded) frames selected by ``src_mask``. Stores the scalar
        loss on ``self.loss_mot_rec`` for ``update()`` to backprop.

        Returns:
            OrderedDict with the scalar ``loss_mot_rec`` for logging.
        """
        loss_mot_rec = self.mse_criterion(self.fake_noise, self.real_noise).mean(dim=-1)
        loss_mot_rec = (loss_mot_rec * self.src_mask).sum() / self.src_mask.sum()
        self.loss_mot_rec = loss_mot_rec
        loss_logs = OrderedDict({})
        loss_logs['loss_mot_rec'] = self.loss_mot_rec.item()
        return loss_logs

    def update(self):
        """One optimization step: zero grads, backprop, clip, step.

        Returns:
            The loss log dict from ``backward_G()``.
        """
        self.zero_grad([self.opt_encoder])
        loss_logs = self.backward_G()
        self.loss_mot_rec.backward()
        self.clip_norm([self.encoder])
        self.step([self.opt_encoder])
        return loss_logs

    def to(self, device):
        """Move the criterion (if training) and encoder to ``device``."""
        if self.opt.is_train:
            self.mse_criterion.to(device)
        self.encoder = self.encoder.to(device)

    def train_mode(self):
        """Put the encoder in training mode."""
        self.encoder.train()

    def eval_mode(self):
        """Put the encoder in evaluation mode."""
        self.encoder.eval()

    def save(self, file_name, ep, total_it):
        """Checkpoint optimizer + encoder weights and training progress.

        Unwraps the DataParallel/DDP ``.module`` when present so the saved
        state_dict keys are wrapper-independent.
        """
        state = {
            'opt_encoder': self.opt_encoder.state_dict(),
            'ep': ep,
            'total_it': total_it
        }
        try:
            state['encoder'] = self.encoder.module.state_dict()
        except AttributeError:
            state['encoder'] = self.encoder.state_dict()
        torch.save(state, file_name)
        return

    def load(self, model_dir):
        """Load a checkpoint; also restores the optimizer when training.

        Returns:
            Tuple ``(epoch, total_it)``; ``total_it`` defaults to 0 for
            older checkpoints that lack it.
        """
        checkpoint = torch.load(model_dir, map_location=self.device)
        if self.opt.is_train:
            self.opt_encoder.load_state_dict(checkpoint['opt_encoder'])
        self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
        return checkpoint['ep'], checkpoint.get('total_it', 0)

    def train(self, train_dataset):
        """Full training loop with periodic logging and checkpointing.

        Only rank 0 logs and saves; checkpoints go to ``opt.model_dir``.
        Resumes from ``latest.tar`` when ``opt.is_continue`` is set.
        """
        rank, world_size = get_dist_info()
        self.to(self.device)
        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        it = 0
        cur_epoch = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            cur_epoch, it = self.load(model_dir)

        start_time = time.time()

        train_loader = build_dataloader(
            train_dataset,
            samples_per_gpu=self.opt.batch_size,
            drop_last=True,
            workers_per_gpu=4,
            shuffle=True)

        logs = OrderedDict()
        for epoch in range(cur_epoch, self.opt.num_epochs):
            self.train_mode()
            for i, batch_data in enumerate(train_loader):
                self.forward(batch_data)
                log_dict = self.update()
                # Accumulate losses between log points; reset on each log.
                for k, v in log_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v
                it += 1
                if it % self.opt.log_every == 0 and rank == 0:
                    mean_loss = OrderedDict({})
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(start_time, it, mean_loss, epoch, inner_iter=i)

                if it % self.opt.save_latest == 0 and rank == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            if rank == 0:
                self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            if epoch % self.opt.save_every_e == 0 and rank == 0:
                self.save(pjoin(self.opt.model_dir, 'ckpt_e%03d.tar'%(epoch)),
                            epoch, total_it=it)