import torch
from einops import rearrange, repeat

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.dit import DiT
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint


@MODELS.register_module()
class Latte(DiT):
    def forward(self, x, t, y):
        """
        Forward pass of Latte.
        x: (B, C, T, H, W) tensor of latent video inputs
        t: (B,) tensor of diffusion timesteps
        y: conditioning input consumed by y_embedder (class labels, or text
           embeddings when use_text_encoder is set)
        """
        x = x.to(self.dtype)

        # Patchify and add the spatial positional embedding to every frame.
        x = self.x_embedder(x)  # (B, N, D) with N = num_temporal * num_spatial
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed_spatial
        x = rearrange(x, "b t s d -> b (t s) d")

        # Embed the timestep and the condition, then broadcast their sum so it
        # matches the flattened batch used by the spatial and temporal blocks.
        t = self.t_embedder(t, dtype=x.dtype)  # (B, D)
        y = self.y_embedder(y, self.training)
        if self.use_text_encoder:
            y = y.squeeze(1).squeeze(1)
        condition = t + y
        condition_spatial = repeat(condition, "b d -> (b t) d", t=self.num_temporal)
        condition_temporal = repeat(condition, "b d -> (b s) d", s=self.num_spatial)

        # Latte interleaves attention: even-indexed blocks attend over the
        # spatial tokens of each frame, odd-indexed blocks attend over time at
        # each spatial location.
        for i, block in enumerate(self.blocks):
            if i % 2 == 0:
                # spatial block: fold time into the batch dimension
                x = rearrange(x, "b (t s) d -> (b t) s d", t=self.num_temporal, s=self.num_spatial)
                c = condition_spatial
            else:
                # temporal block: fold space into the batch dimension
                x = rearrange(x, "b (t s) d -> (b s) t d", t=self.num_temporal, s=self.num_spatial)
                c = condition_temporal
                if i == 1:
                    # add the temporal positional embedding once, before the first temporal block
                    x = x + self.pos_embed_temporal

            x = auto_grad_checkpoint(block, x, c)

            # restore the (B, T*S, D) token layout
            if i % 2 == 0:
                x = rearrange(x, "(b t) s d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
            else:
                x = rearrange(x, "(b s) t d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)

        # Project tokens back to patches and reassemble the (B, C, T, H, W) output.
        x = self.final_layer(x, condition)
        x = self.unpatchify(x)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x


@MODELS.register_module("Latte-XL/2") |
|
def Latte_XL_2(from_pretrained=None, **kwargs): |
|
model = Latte( |
|
depth=28, |
|
hidden_size=1152, |
|
patch_size=(1, 2, 2), |
|
num_heads=16, |
|
**kwargs, |
|
) |
|
if from_pretrained is not None: |
|
load_checkpoint(model, from_pretrained) |
|
return model |
|
|
|
|
|
@MODELS.register_module("Latte-XL/2x2") |
|
def Latte_XL_2x2(from_pretrained=None, **kwargs): |
|
model = Latte( |
|
depth=28, |
|
hidden_size=1152, |
|
patch_size=(2, 2, 2), |
|
num_heads=16, |
|
**kwargs, |
|
) |
|
if from_pretrained is not None: |
|
load_checkpoint(model, from_pretrained) |
|
return model |
|
|
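
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): one way the
# factory functions above might be exercised. The constructor kwargs
# (input_size, in_channels) and the dummy tensor shapes are assumptions about
# the DiT base class rather than values defined in this file, and the form of
# `y` depends on which y_embedder the model is configured with. Adjust them to
# your config before running.
#
# if __name__ == "__main__":
#     model = Latte_XL_2(input_size=(16, 32, 32), in_channels=4)  # assumed kwargs
#     x = torch.randn(2, 4, 16, 32, 32)   # (B, C, T, H, W) latent video
#     t = torch.randint(0, 1000, (2,))    # (B,) diffusion timesteps
#     y = torch.randint(0, 10, (2,))      # e.g. class labels for a label embedder
#     out = model(x, t, y)                # (B, C_out, T, H, W) prediction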