import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from ..builder import SUBMODULES
from .diffusion_transformer import DiffusionTransformer


# SUBMODULES is imported above but otherwise unused, so the registry
# decorator was almost certainly dropped during extraction; restored here.
@SUBMODULES.register_module()
class MotionDiffuseTransformer(DiffusionTransformer):
    """Text-conditioned diffusion transformer used by MotionDiffuse.

    The heavy lifting (text encoding, temporal decoder blocks, output
    head) is inherited from DiffusionTransformer; this subclass only
    wires those pieces together for training and sampling.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_precompute_condition(self,
                                 text=None,
                                 xf_proj=None,
                                 xf_out=None,
                                 device=None,
                                 clip_feat=None,
                                 **kwargs):
        # Encode the prompt only if cached features were not supplied, so the
        # text encoder runs once per batch rather than once per denoising step.
        if xf_proj is None or xf_out is None:
            xf_proj, xf_out = self.encode_text(text, clip_feat, device)
        return {'xf_proj': xf_proj, 'xf_out': xf_out}
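
    # A minimal usage sketch (illustrative, not from the original file): the
    # dict above is computed once and then fed back into every denoising
    # step, e.g.
    #
    #   cond = model.get_precompute_condition(text=prompts, device=device)
    #   for t in timesteps:  # hypothetical sampling loop
    #       h = model.forward_test(h=h, src_mask=mask, emb=emb, **cond)
    #
    # where model, prompts, mask and emb are assumed to come from the caller;
    # the cached xf_proj/xf_out in cond skip re-encoding on later steps.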

    def forward_train(self, h=None, src_mask=None, emb=None, xf_out=None,
                      **kwargs):
        B, T = h.shape[0], h.shape[1]
        # Pass the noisy motion sequence through every decoder block,
        # cross-attending to the cached text features xf_out each time.
        for module in self.temporal_decoder_blocks:
            h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
        # Project back to the raw motion feature dimension.
        output = self.out(h).view(B, T, -1).contiguous()
        return output

    def forward_test(self, h=None, src_mask=None, emb=None, xf_out=None,
                     **kwargs):
        # Identical to forward_train: the denoising pass is unchanged at
        # test time, only the surrounding sampling loop differs.
        B, T = h.shape[0], h.shape[1]
        for module in self.temporal_decoder_blocks:
            h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
        output = self.out(h).view(B, T, -1).contiguous()
        return output
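

# --- Illustrative sketch (not part of the original module) ---
# A runnable toy showing the tensor contract of forward_train/forward_test:
# each decoder block maps (B, T, D) -> (B, T, D) while attending to text
# features xf of shape (B, N, D), and a final projection restores the motion
# feature size. ToyBlock and all dimensions below are hypothetical stand-ins,
# not the real mogen modules. Copy this block (plus the torch imports) into a
# standalone script to run it, since this module's relative imports prevent
# direct execution.
if __name__ == '__main__':
    B, T, N, D, out_feats = 2, 16, 77, 32, 24

    class ToyBlock(nn.Module):
        """Stand-in decoder block: same call signature, shape-preserving."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(D, D)

        def forward(self, x=None, xf=None, emb=None, src_mask=None):
            # A real block would mix temporal self-attention, cross-attention
            # over xf, and the timestep embedding emb; a linear layer is
            # enough to exercise the shapes.
            return self.proj(x)

    blocks = nn.ModuleList([ToyBlock() for _ in range(2)])
    out = nn.Linear(D, out_feats)

    h = torch.randn(B, T, D)
    xf_out = torch.randn(B, N, D)
    for module in blocks:
        h = module(x=h, xf=xf_out, emb=None, src_mask=None)
    print(out(h).view(B, T, -1).shape)  # torch.Size([2, 16, 24])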