# NOTE(review): removed non-Python paste artifacts that preceded the imports
# ("Spaces:" / "Runtime error" lines from a web-editor export).
import codecs as cs
import random
import time
from collections import OrderedDict
from os.path import join as pjoin

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from mmcv.runner import get_dist_info

from datasets import build_dataloader
from models.gaussian_diffusion import (
    GaussianDiffusion,
    get_named_beta_schedule,
    create_named_schedule_sampler,
    ModelMeanType,
    ModelVarType,
    LossType
)
from models.transformer import MotionTransformer
from utils.utils import print_current_loss
class DDPMTrainer(object):
    """Trainer for a text-conditioned motion encoder under Gaussian diffusion.

    Wraps ``encoder`` (a motion transformer) with a ``GaussianDiffusion``
    process (epsilon prediction, fixed-small variance, MSE loss) and provides
    the training loop, batched sampling, and checkpoint save/load.
    """

    def __init__(self, args, encoder):
        """Build the diffusion process and timestep sampler around ``encoder``.

        Args:
            args: options namespace; reads ``device``, ``diffusion_steps``,
                ``is_train`` here (other fields are read lazily elsewhere).
            encoder: motion model; may later be wrapped in DataParallel/DDP
                (``.module`` access is attempted first where needed).
        """
        self.opt = args
        self.device = args.device
        self.encoder = encoder
        self.diffusion_steps = args.diffusion_steps
        sampler = 'uniform'
        beta_scheduler = 'linear'
        betas = get_named_beta_schedule(beta_scheduler, self.diffusion_steps)
        self.diffusion = GaussianDiffusion(
            betas=betas,
            model_mean_type=ModelMeanType.EPSILON,
            model_var_type=ModelVarType.FIXED_SMALL,
            loss_type=LossType.MSE,
        )
        self.sampler = create_named_schedule_sampler(sampler, self.diffusion)
        self.sampler_name = sampler
        if args.is_train:
            # 'none' reduction: the per-element loss is masked and averaged
            # manually in backward_G using the sequence-length mask.
            self.mse_criterion = torch.nn.MSELoss(reduction='none')
        self.to(self.device)

    # The three helpers below were defined without `self` but are called as
    # `self.zero_grad(...)` etc. in update(); that bound `self` to the list
    # argument and crashed. They are pure utilities, so @staticmethod is the
    # correct, backward-compatible fix.
    @staticmethod
    def zero_grad(opt_list):
        """Zero the gradients of every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip each network's total gradient norm to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Apply a parameter update with every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.step()

    def forward(self, batch_data, eval_mode=False):
        """Run one diffusion training pass on a batch.

        ``batch_data`` is ``(caption, motions, m_lens)``. Stores the target
        noise, predicted noise, and a per-frame validity mask on ``self`` for
        backward_G to consume.
        """
        caption, motions, m_lens = batch_data
        motions = motions.detach().to(self.device).float()

        self.caption = caption
        self.motions = motions
        x_start = motions
        B, T = x_start.shape[:2]
        # Clamp each sequence length to the padded horizon T.
        cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
        t, _ = self.sampler.sample(B, x_start.device)
        output = self.diffusion.training_losses(
            model=self.encoder,
            x_start=x_start,
            t=t,
            model_kwargs={"text": caption, "length": cur_len}
        )

        self.real_noise = output['target']
        self.fake_noise = output['pred']
        try:
            # DataParallel/DDP wraps the model; the real module is .module.
            self.src_mask = self.encoder.module.generate_src_mask(T, cur_len).to(x_start.device)
        except AttributeError:
            self.src_mask = self.encoder.generate_src_mask(T, cur_len).to(x_start.device)

    def generate_batch(self, caption, m_lens, dim_pose):
        """Sample one batch of motions for ``caption`` via the reverse process."""
        xf_proj, xf_out = self.encoder.encode_text(caption, self.device)
        B = len(caption)
        T = min(m_lens.max(), self.encoder.num_frames)
        output = self.diffusion.p_sample_loop(
            self.encoder,
            (B, T, dim_pose),
            clip_denoised=False,
            progress=True,
            model_kwargs={
                'xf_proj': xf_proj,
                'xf_out': xf_out,
                'length': m_lens
            })
        return output

    def generate(self, caption, m_lens, dim_pose, batch_size=1024):
        """Sample motions for all captions, chunked into ``batch_size`` groups.

        Returns a list of per-sample outputs (one tensor per caption).
        """
        N = len(caption)
        cur_idx = 0
        self.encoder.eval()
        all_output = []
        while cur_idx < N:
            if cur_idx + batch_size >= N:
                batch_caption = caption[cur_idx:]
                batch_m_lens = m_lens[cur_idx:]
            else:
                batch_caption = caption[cur_idx: cur_idx + batch_size]
                batch_m_lens = m_lens[cur_idx: cur_idx + batch_size]
            output = self.generate_batch(batch_caption, batch_m_lens, dim_pose)
            B = output.shape[0]
            for i in range(B):
                all_output.append(output[i])
            cur_idx += batch_size
        return all_output

    def backward_G(self):
        """Compute the masked noise-reconstruction loss.

        Uses the state stored by forward(): per-element MSE between predicted
        and target noise, averaged over the feature dim, then averaged over
        valid (unmasked) frames only.

        Returns:
            OrderedDict with the scalar 'loss_mot_rec' for logging.
        """
        loss_mot_rec = self.mse_criterion(self.fake_noise, self.real_noise).mean(dim=-1)
        loss_mot_rec = (loss_mot_rec * self.src_mask).sum() / self.src_mask.sum()
        self.loss_mot_rec = loss_mot_rec
        loss_logs = OrderedDict({})
        loss_logs['loss_mot_rec'] = self.loss_mot_rec.item()
        return loss_logs

    def update(self):
        """One optimizer step: zero grads, compute loss, backprop, clip, step."""
        self.zero_grad([self.opt_encoder])
        loss_logs = self.backward_G()
        self.loss_mot_rec.backward()
        self.clip_norm([self.encoder])
        self.step([self.opt_encoder])
        return loss_logs

    def to(self, device):
        """Move the encoder (and, in training, the criterion) to ``device``."""
        if self.opt.is_train:
            self.mse_criterion.to(device)
        self.encoder = self.encoder.to(device)

    def train_mode(self):
        """Put the encoder in training mode."""
        self.encoder.train()

    def eval_mode(self):
        """Put the encoder in evaluation mode."""
        self.encoder.eval()

    def save(self, file_name, ep, total_it):
        """Checkpoint optimizer state, encoder weights, epoch, and iteration."""
        state = {
            'opt_encoder': self.opt_encoder.state_dict(),
            'ep': ep,
            'total_it': total_it
        }
        try:
            # Unwrap DataParallel/DDP so the checkpoint is load-agnostic.
            state['encoder'] = self.encoder.module.state_dict()
        except AttributeError:
            state['encoder'] = self.encoder.state_dict()
        torch.save(state, file_name)
        return

    def load(self, model_dir):
        """Restore a checkpoint; returns ``(epoch, total_it)``."""
        checkpoint = torch.load(model_dir, map_location=self.device)
        if self.opt.is_train:
            self.opt_encoder.load_state_dict(checkpoint['opt_encoder'])
        self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
        # Older checkpoints may predate 'total_it'.
        return checkpoint['ep'], checkpoint.get('total_it', 0)

    def train(self, train_dataset):
        """Full training loop with periodic logging and checkpointing.

        Reads from self.opt: lr, batch_size, num_epochs, is_continue,
        model_dir, log_every, save_latest, save_every_e. Only rank 0 logs
        and saves under distributed training.
        """
        rank, world_size = get_dist_info()
        self.to(self.device)
        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        it = 0
        cur_epoch = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            cur_epoch, it = self.load(model_dir)

        start_time = time.time()

        train_loader = build_dataloader(
            train_dataset,
            samples_per_gpu=self.opt.batch_size,
            drop_last=True,
            workers_per_gpu=4,
            shuffle=True)

        logs = OrderedDict()
        for epoch in range(cur_epoch, self.opt.num_epochs):
            self.train_mode()
            for i, batch_data in enumerate(train_loader):
                self.forward(batch_data)
                log_dict = self.update()
                # Accumulate per-key losses for window-averaged logging.
                for k, v in log_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v
                it += 1
                if it % self.opt.log_every == 0 and rank == 0:
                    mean_loss = OrderedDict({})
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(start_time, it, mean_loss, epoch, inner_iter=i)

                if it % self.opt.save_latest == 0 and rank == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            if rank == 0:
                self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            if epoch % self.opt.save_every_e == 0 and rank == 0:
                self.save(pjoin(self.opt.model_dir, 'ckpt_e%03d.tar'%(epoch)),
                          epoch, total_it=it)