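# Patched HunyuanVideo that threads transformer_options into every transformer
# block and records extra metadata (original latent shape, text token count)
# for block-level patches such as regional conditioning.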
import torch
from einops import repeat
from torch import Tensor
from typing import Optional

from comfy.ldm.flux.layers import timestep_embedding
from comfy.ldm.hunyuan_video.model import HunyuanVideo


class ModifiedHunyuanVideo(HunyuanVideo):
    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Optional[Tensor] = None,
        control=None,
        transformer_options={},
    ) -> Tensor:
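        # "dit" patches supplied via transformer_options can replace whole
        # blocks; the original latent shape is stashed so patches can recover
        # the (T, H, W) layout of the flattened image tokens.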
        patches_replace = transformer_options.get("patches_replace", {})

        initial_shape = list(img.shape)
        transformer_options['original_shape'] = initial_shape
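
        # Patch-embed the latent, then build the global modulation vector from
        # the timestep, the pooled text embedding y, and (for guidance-distilled
        # checkpoints) the guidance strength.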
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])

        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
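
        # Convert a boolean/int text mask into an additive bias (kept tokens -> 0,
        # masked tokens -> -finfo.max), and record the text token count so
        # downstream patches can locate the text segment.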
        if txt_mask is not None and not torch.is_floating_point(txt_mask):
            txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max

        transformer_options['txt_size'] = txt.shape[1]
        txt = self.txt_in(txt, timesteps, txt_mask)
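
        # Rotary position embeddings over the concatenated image+text ids, plus
        # an additive attention mask that only masks text tokens (image tokens
        # come first in the sequence).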
        ids = torch.cat((img_ids, txt_ids), dim=1)
        pe = self.pe_embedder(ids)

        img_len = img.shape[1]
        if txt_mask is not None:
            attn_mask_len = img_len + txt.shape[1]
            attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
            attn_mask[:, 0, img_len:] = txt_mask
        else:
            attn_mask = None
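
        # Double-stream blocks keep image and text as separate sequences; each
        # call receives transformer_options so patched blocks can read the
        # metadata stashed above. ControlNet "input" residuals are added to the
        # image stream after each block.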
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)

            if control is not None:
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

        img = torch.cat((img, txt), 1)
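
        # Single-stream blocks run on the fused [image; text] sequence;
        # ControlNet "output" residuals touch only the image tokens, and the
        # text tokens are dropped again afterwards.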
        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
                    return out

                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)

            if control is not None:
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, : img_len] += add

        img = img[:, : img_len]
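
        # Project tokens back to patch pixels, then un-patchify:
        # (B, t*h*w, C*pt*ph*pw) -> (B, C, T, H, W) == initial_shape.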
        img = self.final_layer(img, vec)

        shape = initial_shape[-3:]
        for i in range(len(shape)):
            shape[i] = shape[i] // self.patch_size[i]
        img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
        img = img.reshape(initial_shape)
        return img
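
    # forward() builds per-patch 3D (t, h, w) position ids, lets a registered
    # regional-conditioning patch rewrite the text context, then defers to
    # forward_orig() above.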
    def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        regional_conditioning = transformer_options.get('patches', {}).get('regional_conditioning', None)
        if regional_conditioning is not None:
            context = regional_conditioning[0](context, transformer_options)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
        return out


def inject_model(diffusion_model):
    diffusion_model.__class__ = ModifiedHunyuanVideo
    return diffusion_model
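

# Minimal usage sketch (an assumption, not part of this module): given a
# ComfyUI ModelPatcher `m` wrapping a HunyuanVideo diffusion model:
#
#   m = m.clone()
#   inject_model(m.model.diffusion_model)
#
# Reassigning __class__ rebinds only the methods; all parameters and buffers
# stay untouched, so no weights are copied or reloaded.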