# ReMoDiffuse: mogen/models/architectures/diffusion_architecture.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_architecture import BaseArchitecture
from ..builder import (
ARCHITECTURES,
build_architecture,
build_submodule,
build_loss
)
from ..utils.gaussian_diffusion import (
GaussianDiffusion, get_named_beta_schedule, create_named_schedule_sampler,
ModelMeanType, ModelVarType, LossType, space_timesteps, SpacedDiffusion
)
def build_diffusion(cfg):
    """Create a Gaussian diffusion process from a config dict.

    Expects ``cfg`` to contain ``beta_scheduler``, ``diffusion_steps``,
    ``model_mean_type`` and ``model_var_type``; an optional ``respace``
    entry selects a timestep-respaced (``SpacedDiffusion``) variant.
    Returns either a ``GaussianDiffusion`` or a ``SpacedDiffusion``
    instance with an MSE loss.
    """
    num_steps = cfg['diffusion_steps']
    betas = get_named_beta_schedule(cfg['beta_scheduler'], num_steps)
    mean_type_map = {
        'start_x': ModelMeanType.START_X,
        'previous_x': ModelMeanType.PREVIOUS_X,
        'epsilon': ModelMeanType.EPSILON,
    }
    var_type_map = {
        'learned': ModelVarType.LEARNED,
        'fixed_small': ModelVarType.FIXED_SMALL,
        'fixed_large': ModelVarType.FIXED_LARGE,
        'learned_range': ModelVarType.LEARNED_RANGE,
    }
    # Keyword arguments shared by both diffusion variants.
    shared_kwargs = dict(
        betas=betas,
        model_mean_type=mean_type_map[cfg['model_mean_type']],
        model_var_type=var_type_map[cfg['model_var_type']],
        loss_type=LossType.MSE,
    )
    respace = cfg.get('respace', None)
    if respace is not None:
        # Respaced sampling: run fewer denoising steps at inference time.
        return SpacedDiffusion(
            use_timesteps=space_timesteps(num_steps, respace),
            **shared_kwargs)
    return GaussianDiffusion(**shared_kwargs)
@ARCHITECTURES.register_module()
class MotionDiffusion(BaseArchitecture):
    """Diffusion-based motion generation architecture.

    Wraps a denoising model with separate training and testing diffusion
    processes (built via ``build_diffusion``) and a uniform timestep
    sampler. At train time it computes a masked reconstruction loss; at
    test time it runs DDPM or DDIM sampling to produce motions.

    Args:
        model (dict | None): Config for the denoising submodule.
        loss_recon (dict | None): Config for the reconstruction loss.
        diffusion_train (dict | None): Diffusion config used for training.
        diffusion_test (dict | None): Diffusion config used for inference.
        init_cfg (dict | None): Initialization config forwarded to the base.
        inference_type (str): ``'ddpm'`` for ancestral sampling; any other
            value selects DDIM sampling with ``eta=0``.
    """

    def __init__(self,
                 model=None,
                 loss_recon=None,
                 diffusion_train=None,
                 diffusion_test=None,
                 init_cfg=None,
                 inference_type='ddpm',
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test = build_diffusion(diffusion_test)
        # Uniform timestep sampler over the training diffusion schedule.
        self.sampler = create_named_schedule_sampler('uniform',
                                                     self.diffusion_train)
        self.inference_type = inference_type

    def forward(self, **kwargs):
        """Run a training step or sample motions, depending on mode.

        Expects ``kwargs`` to contain ``motion`` and ``motion_mask``
        tensors (assumed (B, T, ...) and (B, T) respectively — inferred
        from the masking below), ``motion_length``, and per-sample
        ``motion_metas`` with a ``'text'`` entry.

        Returns:
            dict: ``{'recon_loss': ...}`` in training mode, otherwise the
            split per-sample results including ``'pred_motion'``.
        """
        motion = kwargs['motion'].float()
        motion_mask = kwargs['motion_mask'].float()
        sample_idx = kwargs.get('sample_idx', None)
        clip_feat = kwargs.get('clip_feat', None)
        B, T = motion.shape[:2]
        text = [kwargs['motion_metas'][i]['text'] for i in range(B)]
        if self.training:
            # Sample one diffusion timestep per batch element.
            t, _ = self.sampler.sample(B, motion.device)
            output = self.diffusion_train.training_losses(
                model=self.model,
                x_start=motion,
                t=t,
                model_kwargs={
                    'motion_mask': motion_mask,
                    'motion_length': kwargs['motion_length'],
                    'text': text,
                    'clip_feat': clip_feat,
                    'sample_idx': sample_idx}
            )
            pred, target = output['pred'], output['target']
            recon_loss = self.loss_recon(pred, target,
                                         reduction_override='none')
            # Average only over valid (unmasked) frames.
            recon_loss = (recon_loss.mean(dim=-1) *
                          motion_mask).sum() / motion_mask.sum()
            return {'recon_loss': recon_loss}
        else:
            dim_pose = kwargs['motion'].shape[-1]
            model_kwargs = self.model.get_precompute_condition(
                device=motion.device, text=text, **kwargs)
            model_kwargs['motion_mask'] = motion_mask
            model_kwargs['sample_idx'] = sample_idx
            inference_kwargs = kwargs.get('inference_kwargs', {})
            if self.inference_type == 'ddpm':
                output = self.diffusion_test.p_sample_loop(
                    self.model,
                    (B, T, dim_pose),
                    clip_denoised=False,
                    progress=False,
                    model_kwargs=model_kwargs,
                    **inference_kwargs
                )
            else:
                output = self.diffusion_test.ddim_sample_loop(
                    self.model,
                    (B, T, dim_pose),
                    clip_denoised=False,
                    progress=False,
                    model_kwargs=model_kwargs,
                    eta=0,
                    **inference_kwargs
                )
            # Bug fix: the original used two-argument
            # getattr(self.model, "post_process"), which raises
            # AttributeError when the model lacks the attribute — the
            # subsequent `is not None` check implies a default of None
            # was intended.
            post_process = getattr(self.model, 'post_process', None)
            if post_process is not None:
                output = post_process(output)
            results = kwargs
            results['pred_motion'] = output
            results = self.split_results(results)
            return results