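"""TeaCache acceleration nodes for ComfyUI.

Patches the forward passes of the Flux, HunyuanVideo and LTXV diffusion transformers so
that, when the timestep-modulated input changes little between denoising steps, the
transformer blocks are skipped and a previously cached residual is reused instead.
Also provides a small torch.compile wrapper node.
"""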
import math
import torch
from torch import Tensor
from unittest.mock import patch
from comfy.ldm.flux.layers import timestep_embedding
from comfy.ldm.lightricks.model import precompute_freqs_cis
from comfy.ldm.common_dit import rms_norm
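# Evaluate a polynomial whose coefficients are given in descending order of power
# (the same convention as numpy.poly1d); used below to rescale relative L1 distances.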
def poly1d(coefficients, x):
result = torch.zeros_like(x)
for i, coeff in enumerate(coefficients):
result += coeff * (x ** (len(coefficients) - 1 - i))
return result
def teacache_flux_forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
    rel_l1_thresh = transformer_options.get("rel_l1_thresh", 0.0)  # float threshold set by the TeaCache node
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
# enable teacache
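    # TeaCache: compute the timestep-modulated input of the first double block and use it
    # as an indicator of how much this denoising step differs from the previous one.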
inp = img.clone()
vec_ = vec.clone()
img_mod1, _ = self.double_blocks[0].img_mod(vec_)
modulated_inp = self.double_blocks[0].img_norm1(inp)
modulated_inp = (1 + img_mod1.scale) * modulated_inp + img_mod1.shift
ca_idx = 0
if not hasattr(self, 'accumulated_rel_l1_distance'):
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
try:
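            # Rescale the relative L1 change with the Flux polynomial coefficients used by TeaCache
            # and accumulate it; once the accumulated value crosses the threshold, recompute fully.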
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
if self.accumulated_rel_l1_distance < rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
        except Exception:  # fall back to a full forward pass if the distance update fails
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
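    # Either reuse the cached residual from the last full pass, or run all blocks and cache a new residual.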
if not should_calc:
img += self.previous_residual
else:
ori_img = img.clone()
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
# PuLID attention
if getattr(self, "pulid_data", {}):
if i % self.pulid_double_interval == 0:
# Will calculate influence of all pulid nodes at once
for _, node_data in self.pulid_data.items():
if torch.any((node_data['sigma_start'] >= timesteps)
& (timesteps >= node_data['sigma_end'])):
img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img)
ca_idx += 1
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
# PuLID attention
if getattr(self, "pulid_data", {}):
real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...]
if i % self.pulid_single_interval == 0:
# Will calculate influence of all nodes at once
for _, node_data in self.pulid_data.items():
if torch.any((node_data['sigma_start'] >= timesteps)
& (timesteps >= node_data['sigma_end'])):
real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img)
ca_idx += 1
img = torch.cat((txt, real_img), 1)
img = img[:, txt.shape[1] :, ...]
self.previous_residual = img - ori_img
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def teacache_hunyuanvideo_forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
    rel_l1_thresh = transformer_options.get("rel_l1_thresh", 0.0)  # float threshold set by the TeaCache node
initial_shape = list(img.shape)
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
# enable teacache
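    # TeaCache indicator: timestep-modulated input of the first double block, as in the Flux path above.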
inp = img.clone()
vec_ = vec.clone()
img_mod1, _ = self.double_blocks[0].img_mod(vec_)
modulated_inp = self.double_blocks[0].img_norm1(inp)
modulated_inp = (1 + img_mod1.scale) * modulated_inp + img_mod1.shift
if not hasattr(self, 'accumulated_rel_l1_distance'):
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
try:
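            # HunyuanVideo-specific polynomial rescaling of the accumulated relative L1 change.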
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
if self.accumulated_rel_l1_distance < rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
        except Exception:  # fall back to a full forward pass if the distance update fails
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
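    # Skip the transformer blocks and reuse the cached residual when the change is below the threshold.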
if not should_calc:
img += self.previous_residual
else:
ori_img = img.clone()
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((img, txt), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img = img[:, : img_len]
self.previous_residual = img - ori_img
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
return img
def teacache_ltxvmodel_forward(
self,
x,
timestep,
context,
attention_mask,
frame_rate=25,
guiding_latent=None,
guiding_latent_noise_scale=0,
transformer_options={},
**kwargs
):
patches_replace = transformer_options.get("patches_replace", {})
    rel_l1_thresh = transformer_options.get("rel_l1_thresh", 0.0)  # float threshold set by the TeaCache node
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
blocks_replace = patches_replace.get("dit", {})
# enable teacache
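    # TeaCache indicator for LTXV: the adaLN-modulated, rms-normalized input of the first transformer block.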
inp = x.clone()
timestep_ = timestep.clone()
num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + timestep_.reshape(batch_size, timestep_.size(1), num_ada_params, -1)
shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2)
modulated_inp = rms_norm(inp)
modulated_inp = modulated_inp * (1 + scale_msa) + shift_msa
if not hasattr(self, 'accumulated_rel_l1_distance'):
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
try:
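            # LTXV-specific polynomial rescaling of the accumulated relative L1 change.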
coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
if self.accumulated_rel_l1_distance < rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
        except Exception:  # fall back to a full forward pass if the distance update fails
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
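    # Reuse the cached residual when possible; otherwise run the transformer blocks and cache a new one.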
if not should_calc:
x += self.previous_residual
else:
ori_x = x.clone()
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
self.previous_residual = x - ori_x
x = self.proj_out(x)
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],
output_width=orig_shape[4],
output_num_frames=orig_shape[2],
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x
class TeaCacheForImgGen:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The image diffusion model the TeaCache will be applied to."}),
"model_type": (["flux"],),
"rel_l1_thresh": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."})
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "apply_teacache"
CATEGORY = "TeaCache"
TITLE = "TeaCache For Img Gen"
def apply_teacache(self, model, model_type: str, rel_l1_thresh: float):
if rel_l1_thresh == 0:
return (model,)
new_model = model.clone()
if 'transformer_options' not in new_model.model_options:
new_model.model_options['transformer_options'] = {}
new_model.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
diffusion_model = new_model.get_model_object("diffusion_model")
if model_type == "flux":
forward_name = "forward_orig"
replaced_forward_fn = teacache_flux_forward.__get__(
diffusion_model,
diffusion_model.__class__
)
else:
raise ValueError(f"Unknown type {model_type}")
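        # Swap in the TeaCache forward only while the wrapped model function runs;
        # patch.object restores the original forward as soon as the call returns.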
        def unet_wrapper_function(model_function, kwargs):
            model_input = kwargs["input"]
            timestep = kwargs["timestep"]
            c = kwargs["c"]
            with patch.object(diffusion_model, forward_name, replaced_forward_fn):
                return model_function(model_input, timestep, **c)
new_model.set_model_unet_function_wrapper(unet_wrapper_function)
return (new_model,)
class TeaCacheForVidGen:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The video diffusion model the TeaCache will be applied to."}),
"model_type": (["hunyuan_video", "ltxv"],),
"rel_l1_thresh": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."})
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "apply_teacache"
CATEGORY = "TeaCache"
TITLE = "TeaCache For Vid Gen"
def apply_teacache(self, model, model_type: str, rel_l1_thresh: float):
if rel_l1_thresh == 0:
return (model,)
new_model = model.clone()
if 'transformer_options' not in new_model.model_options:
new_model.model_options['transformer_options'] = {}
new_model.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
diffusion_model = new_model.get_model_object("diffusion_model")
if model_type == "hunyuan_video":
forward_name = "forward_orig"
replaced_forward_fn = teacache_hunyuanvideo_forward.__get__(
diffusion_model,
diffusion_model.__class__
)
elif model_type == "ltxv":
forward_name = "forward"
replaced_forward_fn = teacache_ltxvmodel_forward.__get__(
diffusion_model,
diffusion_model.__class__
)
else:
raise ValueError(f"Unknown type {model_type}")
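        # As above: the TeaCache forward is only active while the wrapped model function runs.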
        def unet_wrapper_function(model_function, kwargs):
            model_input = kwargs["input"]
            timestep = kwargs["timestep"]
            c = kwargs["c"]
            with patch.object(diffusion_model, forward_name, replaced_forward_fn):
                return model_function(model_input, timestep, **c)
new_model.set_model_unet_function_wrapper(unet_wrapper_function)
return (new_model,)
class CompileModel:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the torch.compile will be applied to."}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"backend": (["inductor","cudagraphs", "eager", "aot_eager"], {"default": "inductor"}),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "apply_compile"
CATEGORY = "TeaCache"
TITLE = "Compile Model"
def apply_compile(self, model, mode: str, backend: str, fullgraph: bool, dynamic: bool):
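        # Clone the patcher and replace its diffusion_model object with a torch.compile-wrapped
        # version, leaving the original model object untouched.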
new_model = model.clone()
new_model.add_object_patch(
"diffusion_model",
torch.compile(
new_model.get_model_object("diffusion_model"),
mode=mode,
backend=backend,
fullgraph=fullgraph,
dynamic=dynamic
)
)
return (new_model,)
NODE_CLASS_MAPPINGS = {
"TeaCacheForImgGen": TeaCacheForImgGen,
"TeaCacheForVidGen": TeaCacheForVidGen,
"CompileModel": CompileModel
}
NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()}