Spaces:
Running
Running
File size: 4,823 Bytes
a0d91d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_architecture import BaseArchitecture
from ..builder import (
ARCHITECTURES,
build_architecture,
build_submodule,
build_loss
)
from ..utils.gaussian_diffusion import (
GaussianDiffusion, get_named_beta_schedule, create_named_schedule_sampler,
ModelMeanType, ModelVarType, LossType, space_timesteps, SpacedDiffusion
)
def build_diffusion(cfg):
    """Construct a Gaussian diffusion process from a config dict.

    Reads ``beta_scheduler``, ``diffusion_steps``, ``model_mean_type`` and
    ``model_var_type`` from *cfg*. When ``respace`` is present (and not
    None), a :class:`SpacedDiffusion` with a reduced timestep set is built
    instead of a plain :class:`GaussianDiffusion`. Loss type is always MSE.
    """
    num_steps = cfg['diffusion_steps']
    beta_schedule = get_named_beta_schedule(cfg['beta_scheduler'], num_steps)
    # Translate string config values into the enum members the diffusion
    # implementation expects; a bad key raises KeyError early.
    mean_type = {
        'start_x': ModelMeanType.START_X,
        'previous_x': ModelMeanType.PREVIOUS_X,
        'epsilon': ModelMeanType.EPSILON,
    }[cfg['model_mean_type']]
    var_type = {
        'learned': ModelVarType.LEARNED,
        'fixed_small': ModelVarType.FIXED_SMALL,
        'fixed_large': ModelVarType.FIXED_LARGE,
        'learned_range': ModelVarType.LEARNED_RANGE,
    }[cfg['model_var_type']]
    respace = cfg.get('respace')
    if respace is None:
        return GaussianDiffusion(
            betas=beta_schedule,
            model_mean_type=mean_type,
            model_var_type=var_type,
            loss_type=LossType.MSE,
        )
    # Respaced (fewer-step) diffusion, typically used to speed up sampling.
    return SpacedDiffusion(
        use_timesteps=space_timesteps(num_steps, respace),
        betas=beta_schedule,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=LossType.MSE,
    )
@ARCHITECTURES.register_module()
class MotionDiffusion(BaseArchitecture):
    """Diffusion-based motion generation architecture.

    Wraps a denoising submodule together with two diffusion processes:
    one for training and one for inference (the latter is typically
    respaced for faster sampling).

    Args:
        model (dict, optional): Config for the denoising submodule.
        loss_recon (dict, optional): Config for the reconstruction loss.
        diffusion_train (dict, optional): Diffusion config used in training.
        diffusion_test (dict, optional): Diffusion config used at inference.
        init_cfg (dict, optional): Init config forwarded to the base class.
        inference_type (str): ``'ddpm'`` selects ancestral sampling
            (``p_sample_loop``); any other value selects DDIM sampling.
    """

    def __init__(self,
                 model=None,
                 loss_recon=None,
                 diffusion_train=None,
                 diffusion_test=None,
                 init_cfg=None,
                 inference_type='ddpm',
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test = build_diffusion(diffusion_test)
        # Uniform timestep sampler over the *training* diffusion process.
        self.sampler = create_named_schedule_sampler('uniform', self.diffusion_train)
        self.inference_type = inference_type

    def forward(self, **kwargs):
        """Compute training losses or sample motions, depending on mode.

        Expects in ``kwargs``: ``motion`` (B, T, D), ``motion_mask`` (B, T),
        ``motion_metas`` (per-sample dicts with a ``'text'`` entry), and
        optionally ``sample_idx``, ``clip_feat``, ``motion_length``,
        ``inference_kwargs``.

        Returns:
            dict: ``{'recon_loss': ...}`` in training mode, otherwise the
            split results containing ``'pred_motion'``.
        """
        motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'].float()
        sample_idx = kwargs.get('sample_idx', None)
        clip_feat = kwargs.get('clip_feat', None)
        B, T = motion.shape[:2]
        text = [kwargs['motion_metas'][i]['text'] for i in range(B)]
        if self.training:
            # Sample one diffusion timestep per batch element.
            t, _ = self.sampler.sample(B, motion.device)
            output = self.diffusion_train.training_losses(
                model=self.model,
                x_start=motion,
                t=t,
                model_kwargs={
                    'motion_mask': motion_mask,
                    'motion_length': kwargs['motion_length'],
                    'text': text,
                    'clip_feat': clip_feat,
                    'sample_idx': sample_idx}
            )
            pred, target = output['pred'], output['target']
            recon_loss = self.loss_recon(pred, target, reduction_override='none')
            # Masked mean: average over valid (unpadded) frames only.
            recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum()
            return {'recon_loss': recon_loss}
        else:
            dim_pose = kwargs['motion'].shape[-1]
            model_kwargs = self.model.get_precompute_condition(device=motion.device, text=text, **kwargs)
            model_kwargs['motion_mask'] = motion_mask
            model_kwargs['sample_idx'] = sample_idx
            inference_kwargs = kwargs.get('inference_kwargs', {})
            if self.inference_type == 'ddpm':
                output = self.diffusion_test.p_sample_loop(
                    self.model,
                    (B, T, dim_pose),
                    clip_denoised=False,
                    progress=False,
                    model_kwargs=model_kwargs,
                    **inference_kwargs
                )
            else:
                # DDIM with eta=0 gives deterministic sampling.
                output = self.diffusion_test.ddim_sample_loop(
                    self.model,
                    (B, T, dim_pose),
                    clip_denoised=False,
                    progress=False,
                    model_kwargs=model_kwargs,
                    eta=0,
                    **inference_kwargs
                )
            # BUGFIX: 2-arg getattr raised AttributeError when the model had
            # no `post_process`; the default makes the hook truly optional.
            if getattr(self.model, 'post_process', None) is not None:
                output = self.model.post_process(output)
            results = kwargs
            results['pred_motion'] = output
            results = self.split_results(results)
            return results
|