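"""Diffusion-based motion generation architecture.

Wraps a denoiser submodule with separate train/test Gaussian diffusions and
supports DDPM or DDIM sampling at inference time.
"""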
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_architecture import BaseArchitecture
from ..builder import (
    ARCHITECTURES,
    build_architecture,
    build_submodule,
    build_loss
)
from ..utils.gaussian_diffusion import (
    GaussianDiffusion, get_named_beta_schedule, create_named_schedule_sampler,
    ModelMeanType, ModelVarType, LossType, space_timesteps, SpacedDiffusion
)

def build_diffusion(cfg):
    """Build a (possibly respaced) Gaussian diffusion from a config dict."""
    beta_scheduler = cfg['beta_scheduler']
    diffusion_steps = cfg['diffusion_steps']
    betas = get_named_beta_schedule(beta_scheduler, diffusion_steps)
    # What the denoiser network is trained to predict.
    model_mean_type = {
        'start_x': ModelMeanType.START_X,
        'previous_x': ModelMeanType.PREVIOUS_X,
        'epsilon': ModelMeanType.EPSILON
    }[cfg['model_mean_type']]
    # How the reverse-process variance is obtained.
    model_var_type = {
        'learned': ModelVarType.LEARNED,
        'fixed_small': ModelVarType.FIXED_SMALL,
        'fixed_large': ModelVarType.FIXED_LARGE,
        'learned_range': ModelVarType.LEARNED_RANGE
    }[cfg['model_var_type']]
    if cfg.get('respace', None) is not None:
        # Skip timesteps so sampling can use fewer steps than the model
        # was trained on (e.g. for fast DDIM sampling).
        diffusion = SpacedDiffusion(
            use_timesteps=space_timesteps(diffusion_steps, cfg['respace']),
            betas=betas,
            model_mean_type=model_mean_type,
            model_var_type=model_var_type,
            loss_type=LossType.MSE
        )
    else:
        diffusion = GaussianDiffusion(
            betas=betas,
            model_mean_type=model_mean_type,
            model_var_type=model_var_type,
            loss_type=LossType.MSE
        )
    return diffusion
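
# Example (illustrative values, not taken from a shipped config): build a
# test-time diffusion that respaces a 1000-step linear schedule down to 50
# DDIM steps. The 'ddim50' respacing string follows the guided-diffusion
# convention accepted by space_timesteps; omit 'respace' (or set it to None)
# to keep the full schedule.
#
#     diffusion_test = build_diffusion(dict(
#         beta_scheduler='linear',       # or 'cosine'
#         diffusion_steps=1000,
#         model_mean_type='epsilon',     # the network predicts the added noise
#         model_var_type='fixed_small',
#         respace='ddim50'
#     ))
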
@ARCHITECTURES.register_module()
class MotionDiffusion(BaseArchitecture):

    def __init__(self,
                 model=None,
                 loss_recon=None,
                 diffusion_train=None,
                 diffusion_test=None,
                 init_cfg=None,
                 inference_type='ddpm',
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        # Separate schedules so training can use the full step count while
        # testing can use a shorter, respaced one.
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test = build_diffusion(diffusion_test)
        self.sampler = create_named_schedule_sampler('uniform', self.diffusion_train)
        self.inference_type = inference_type

    def forward(self, **kwargs):
        motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'].float()
        sample_idx = kwargs.get('sample_idx', None)
        clip_feat = kwargs.get('clip_feat', None)
        B, T = motion.shape[:2]
        text = [kwargs['motion_metas'][i]['text'] for i in range(B)]
        if self.training:
            # Sample one diffusion timestep per sequence and compute the
            # denoising loss at that step.
            t, _ = self.sampler.sample(B, motion.device)
            output = self.diffusion_train.training_losses(
                model=self.model,
                x_start=motion,
                t=t,
                model_kwargs={
                    'motion_mask': motion_mask,
                    'motion_length': kwargs['motion_length'],
                    'text': text,
                    'clip_feat': clip_feat,
                    'sample_idx': sample_idx
                }
            )
            pred, target = output['pred'], output['target']
            recon_loss = self.loss_recon(pred, target, reduction_override='none')
            # Average over pose dims, then over valid (unmasked) frames only.
            recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum()
            loss = {'recon_loss': recon_loss}
            return loss
        else:
            dim_pose = kwargs['motion'].shape[-1]
            model_kwargs = self.model.get_precompute_condition(
                device=motion.device, text=text, **kwargs)
            model_kwargs['motion_mask'] = motion_mask
            model_kwargs['sample_idx'] = sample_idx
            inference_kwargs = kwargs.get('inference_kwargs', {})
            if self.inference_type == 'ddpm':
                # Full ancestral sampling over every (respaced) timestep.
                output = self.diffusion_test.p_sample_loop(
                    self.model,
                    (B, T, dim_pose),
                    clip_denoised=False,
                    progress=False,
                    model_kwargs=model_kwargs,
                    **inference_kwargs
                )
            else:
                # Deterministic DDIM sampling (eta=0 adds no extra noise).
                output = self.diffusion_test.ddim_sample_loop(
                    self.model,
                    (B, T, dim_pose),
                    clip_denoised=False,
                    progress=False,
                    model_kwargs=model_kwargs,
                    eta=0,
                    **inference_kwargs
                )
            # getattr needs a default here; otherwise a model without a
            # post_process attribute would raise AttributeError.
            if getattr(self.model, 'post_process', None) is not None:
                output = self.model.post_process(output)
            results = kwargs
            results['pred_motion'] = output
            results = self.split_results(results)
            return results
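
# Minimal usage sketch. The registered module names ('...' placeholder,
# 'MSELoss') and tensor shapes are assumptions for illustration, not values
# confirmed by this repo's configs:
#
#     model = MotionDiffusion(
#         model=dict(type=...),               # registered denoiser submodule
#         loss_recon=dict(type='MSELoss'),    # assumed registered loss name
#         diffusion_train=dict(beta_scheduler='linear', diffusion_steps=1000,
#                              model_mean_type='epsilon',
#                              model_var_type='fixed_small'),
#         diffusion_test=dict(beta_scheduler='linear', diffusion_steps=1000,
#                             model_mean_type='epsilon',
#                             model_var_type='fixed_small', respace='ddim50'),
#         inference_type='ddim')
#
#     # Training: returns {'recon_loss': ...} for the sampled timesteps.
#     model.train()
#     losses = model(motion=motion,           # (B, T, D) float tensor
#                    motion_mask=mask,        # (B, T), 1 for valid frames
#                    motion_length=lengths,   # (B,) valid frame counts
#                    motion_metas=[{'text': 'a person walks'}] * B)
#
#     # Inference: samples motions and returns per-sample results.
#     model.eval()
#     results = model(motion=motion, motion_mask=mask, motion_length=lengths,
#                     motion_metas=[{'text': 'a person walks'}] * B)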