jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
#based on https://github.com/DarkMnDragon/rf-inversion-diffuser/blob/main/inversion_editing_cli.py
import torch
import gc
import os
from .utils import log, print_memory
from diffusers.utils.torch_utils import randn_tensor
import comfy.model_management as mm
from .hyvideo.diffusion.pipelines.pipeline_hunyuan_video import get_rotary_pos_embed
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight
# Absolute path of the directory that contains this file; used below to locate
# the bundled empty-prompt embeddings file.
script_directory = os.path.dirname(os.path.abspath(__file__))
# Factor the samplers in this file multiply incoming latents by, and divide
# outgoing latents by, before returning them.
VAE_SCALING_FACTOR = 0.476986
def generate_eta_values(
    timesteps,
    start_step,
    end_step,
    eta,
    eta_trend,
):
    """Build the per-step blending strength schedule.

    The result has one entry per integration interval, i.e.
    ``len(timesteps) - 1`` entries. Entries outside the window
    ``[start_step, end_step)`` are ``0.0``.

    Args:
        timesteps: Indexable sequence of time values (list or 1-D tensor).
        start_step: First step (inclusive) at which the strength is non-zero.
        end_step: Step (exclusive) at which the strength returns to zero.
            May be at most ``len(timesteps)``; the window is clipped to the
            number of available intervals.
        eta: Peak strength.
        eta_trend: ``'constant'``, ``'linear_increase'`` or
            ``'linear_decrease'``. The linear trends ramp between 0 and
            ``eta`` proportionally to the time elapsed between
            ``timesteps[start_step]`` and ``timesteps[end_step - 1]``.

    Returns:
        list: The strength to use at each step.

    Raises:
        ValueError: If the step window is empty, negative or out of range.
        NotImplementedError: If ``eta_trend`` is not a supported value.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not (0 <= start_step < end_step <= len(timesteps)):
        raise ValueError("Invalid start_step and end_step")

    num_intervals = len(timesteps) - 1
    eta_values = [0.0] * num_intervals
    # `end_step == len(timesteps)` is accepted by the validation above, but
    # only `len(timesteps) - 1` intervals exist; clip so the last index is
    # never written out of range.
    window = range(start_step, min(end_step, num_intervals))

    if eta_trend == 'constant':
        for i in window:
            eta_values[i] = eta
    elif eta_trend in ('linear_increase', 'linear_decrease'):
        t_first = timesteps[start_step]
        t_last = timesteps[end_step - 1]
        total_time = t_first - t_last
        if total_time == 0:
            # Single-step (or zero-duration) window: a ramp is undefined and
            # would divide by zero, so apply the requested strength as-is.
            for i in window:
                eta_values[i] = eta
        elif eta_trend == 'linear_increase':
            for i in window:
                eta_values[i] = eta * (t_first - timesteps[i]) / total_time
        else:
            for i in window:
                eta_values[i] = eta * (timesteps[i] - t_last) / total_time
    else:
        raise NotImplementedError(f"Unsupported eta_trend: {eta_trend}")

    print("eta_values", eta_values)
    return eta_values
class HyVideoEmptyTextEmbeds:
    """ComfyUI node that returns pre-computed empty-prompt text embeddings.

    The embeddings are read from ``hunyuan_empty_prompt_embeds_dict.pt``
    located next to this file, so no text encoder has to be run.
    """

    @classmethod
    def INPUT_TYPES(s):
        # The node takes no inputs.
        return {"required": {
            }
        }

    RETURN_TYPES = ("HYVIDEMBEDS", )
    RETURN_NAMES = ("hyvid_embeds",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"
    DESCRIPTION = "Empty Text Embeds for HunyuanVideoWrapper, to avoid having to encode prompts for inverse sampling"

    def process(self):
        """Load and return the bundled empty-prompt embeddings.

        Returns:
            tuple: One-element tuple holding the object stored in the file.
        """
        embeds_path = os.path.join(script_directory, "hunyuan_empty_prompt_embeds_dict.pt")
        # NOTE(review): torch.load unpickles the file. That is only safe
        # because the file ships with this package; never point it at
        # user-supplied data.
        prompt_embeds_dict = torch.load(embeds_path)
        return (prompt_embeds_dict,)
#region Inverse Sampling
class HyVideoInverseSampler:
    """ComfyUI node that integrates input latents along the flipped timestep grid.

    At every step the transformer prediction is blended with a velocity that
    points from the current latents towards seeded random noise; the blend
    strength per step comes from ``generate_eta_values``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("HYVIDEOMODEL",),
                "hyvid_embeds": ("HYVIDEMBEDS", ),
                "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
                "steps": ("INT", {"default": 30, "min": 1}),
                "embedded_guidance_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "flow_shift": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 30.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "force_offload": ("BOOLEAN", {"default": True}),
                "gamma": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "start_step": ("INT", {"default": 0, "min": 0}),
                "end_step": ("INT", {"default": 18, "min": 0}),
                "gamma_trend": (['constant', 'linear_increase', 'linear_decrease'], {"default": "constant"}),
            },
            "optional": {
                "interpolation_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"

    def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, samples, gamma, start_step, end_step, gamma_trend, force_offload, interpolation_curve=None):
        """Run the inverse sampling loop and return the resulting latents.

        Args:
            model: Model patcher whose ``.model`` is a dict providing
                ``"pipe"``, ``"dtype"``, ``"block_swap_args"``,
                ``"auto_cpu_offload"`` and ``"manual_offloading"``.
            hyvid_embeds: Dict with ``"prompt_embeds"``, ``"attention_mask"``
                and ``"prompt_embeds_2"``.
            flow_shift: Value assigned to ``pipeline.scheduler.flow_shift``.
            steps: Number of steps passed to the scheduler.
            embedded_guidance_scale: Guidance value fed to the transformer.
            seed: Seed for the CPU generator that draws the target noise.
            samples: Dict whose ``"samples"`` entry is a 5-D latent tensor.
            gamma: Peak blend strength towards the noise velocity.
            start_step: First step at which ``gamma`` applies.
            end_step: Step at which ``gamma`` stops applying.
            gamma_trend: Trend name forwarded to ``generate_eta_values``.
            force_offload: When true and the model uses manual offloading,
                move the transformer to the offload device afterwards.
            interpolation_curve: Optional per-latent-frame weights; its
                length must equal dimension 2 of the latents.

        Returns:
            tuple: ``({"samples": latents},)`` with the scaling factor undone.

        Raises:
            ValueError: If the derived size is not positive, or if
                ``interpolation_curve`` does not match the latent frame count.
        """
        # Fail fast on a mismatched curve, before any weights are moved and
        # before a transformer forward pass is spent on bad input.
        if interpolation_curve is not None and samples is not None:
            temporal_len = samples["samples"].shape[2]
            if len(interpolation_curve) != temporal_len:
                raise ValueError(f"Weight list length {len(interpolation_curve)} must match temporal dimension {temporal_len}")

        comfy_model_patcher = model
        model = model.model
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = model["dtype"]
        transformer = model["pipe"].transformer
        pipeline = model["pipe"]

        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

        latents = samples["samples"] * VAE_SCALING_FACTOR if samples is not None else None
        # The 5-way unpack also enforces that the latents are 5-D.
        _, _, latent_num_frames, latent_height, latent_width = latents.shape
        height = latent_height * pipeline.vae_scale_factor
        width = latent_width * pipeline.vae_scale_factor
        num_frames = (latent_num_frames - 1) * 4 + 1

        if width <= 0 or height <= 0 or num_frames <= 0:
            raise ValueError(
                f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={num_frames}"
            )
        if (num_frames - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1 (that's minus one frame)` must be a multiple of 4, got {num_frames}"
            )

        log.info(
            f"Input (height, width, video_length) = ({height}, {width}, {num_frames})"
        )

        freqs_cos, freqs_sin = get_rotary_pos_embed(transformer, latent_num_frames, height, width)
        freqs_cos = freqs_cos.to(device)
        freqs_sin = freqs_sin.to(device)

        pipeline.scheduler.flow_shift = flow_shift

        if model["block_swap_args"] is not None:
            # Move every parameter outside the "single"/"double" blocks to the
            # compute device, then let block_swap handle those blocks.
            for name, param in transformer.named_parameters():
                #print(name, param.data.device)
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
            transformer.block_swap(
                model["block_swap_args"]["double_blocks_to_swap"] - 1 ,
                model["block_swap_args"]["single_blocks_to_swap"] - 1,
                offload_txt_in = model["block_swap_args"]["offload_txt_in"],
                offload_img_in = model["block_swap_args"]["offload_img_in"],
            )
        elif model["auto_cpu_offload"]:
            for name, param in transformer.named_parameters():
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
        elif model["manual_offloading"]:
            transformer.to(device)

        mm.soft_empty_cache()
        gc.collect()

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            # Best-effort statistic reset (e.g. no CUDA device in use).
            pass

        pipeline.scheduler.set_timesteps(steps, device=device)
        timesteps = pipeline.scheduler.timesteps
        # Walk the scheduler's grid in reverse order for inversion.
        timesteps = timesteps.flip(0)
        print("timesteps", timesteps)
        print("pipeline.scheduler.order", pipeline.scheduler.order)
        print("len(timesteps)", len(timesteps))

        latent_video_length = (num_frames - 1) // 4 + 1

        # 5. Prepare latent variables
        num_channels_latents = transformer.config.in_channels
        latents = latents.to(device)

        shape = (
            1,
            num_channels_latents,
            latent_video_length,
            int(height) // pipeline.vae_scale_factor,
            int(width) // pipeline.vae_scale_factor,
        )
        noise = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)

        # NOTE(review): dimension 1 is the channel axis according to the
        # 5-way unpack above, although the names below talk about frames.
        # Left unchanged; confirm the intended axis.
        frames_needed = noise.shape[1]
        current_frames = latents.shape[1]

        if frames_needed > current_frames:
            repeat_factor = frames_needed - current_frames
            additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
            latents = torch.cat((additional_frame, latents), dim=1)
            self.additional_frames = repeat_factor
        elif frames_needed < current_frames:
            latents = latents[:, :frames_needed, :, :, :]

        gamma_values = generate_eta_values(timesteps / 1000, start_step, end_step, gamma, gamma_trend)

        # 7. Denoising loop
        self._num_timesteps = len(timesteps)

        latents = latents.to(dtype)

        # Loop-invariant: build the broadcastable weight tensor once
        # instead of on every step.
        curve_weights = None
        if interpolation_curve is not None:
            curve_weights = torch.tensor(interpolation_curve, device=latents.device).view(1, 1, -1, 1, 1)  # shape [1, 1, 33, 1, 1]

        from latent_preview import prepare_callback
        callback = prepare_callback(comfy_model_patcher, steps)

        from comfy.utils import ProgressBar
        from tqdm import tqdm
        log.info(f"Sampling {num_frames} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
        # NOTE(review): the loop below runs len(timesteps) - 1 iterations, so
        # these bars stop one tick short of their total.
        comfy_pbar = ProgressBar(len(timesteps))
        with tqdm(total=len(timesteps)) as progress_bar:
            for idx, (t, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                latent_model_input = latents

                t_expand = t.repeat(latent_model_input.shape[0])
                guidance_expand = (
                    torch.tensor(
                        [embedded_guidance_scale] * latent_model_input.shape[0],
                        dtype=torch.float32,
                        device=device,
                    ).to(pipeline.base_dtype)
                    * 1000.0
                    if embedded_guidance_scale is not None
                    else None
                )

                # predict the noise residual
                with torch.autocast(
                    device_type="cuda", dtype=pipeline.base_dtype, enabled=True
                ):
                    noise_pred = transformer(  # For an input image (129, 192, 336) (1, 256, 256)
                        latent_model_input,  # [2, 16, 33, 24, 42]
                        t_expand,  # [2]
                        text_states=hyvid_embeds["prompt_embeds"],  # [2, 256, 4096]
                        text_mask=hyvid_embeds["attention_mask"],  # [2, 256]
                        text_states_2=hyvid_embeds["prompt_embeds_2"],  # [2, 768]
                        freqs_cos=freqs_cos,  # [seqlen, head_dim]
                        freqs_sin=freqs_sin,  # [seqlen, head_dim]
                        guidance=guidance_expand,
                        stg_block_idx=-1,
                        stg_mode=None,
                        return_dict=True,
                    )["x"]

                sigma = t / 1000.0
                sigma_prev = t_prev / 1000.0
                # Do the update arithmetic in float32.
                latents = latents.to(torch.float32)
                noise_pred = noise_pred.to(torch.float32)
                # Velocity pointing from the current latents to the noise.
                target_noise_velocity = (noise - latents) / (1.0 - sigma)

                if curve_weights is not None:
                    step_gamma = gamma_values[idx] * curve_weights
                else:
                    step_gamma = gamma_values[idx]

                interpolated_velocity = step_gamma * target_noise_velocity + (1 - step_gamma) * noise_pred
                latents = latents + (sigma_prev - sigma) * interpolated_velocity
                # NOTE(review): hard-coded bfloat16 rather than `dtype`;
                # kept as-is to preserve numerical behavior.
                latents = latents.to(torch.bfloat16)

                # compute the previous noisy sample x_t -> x_t-1
                #latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                progress_bar.update()
                if callback is not None:
                    callback(idx, (latent_model_input - noise_pred * t / 1000).detach()[0].permute(1,0,2,3), None, steps)
                else:
                    comfy_pbar.update(1)

        print_memory(device)
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            # Best-effort statistic reset (e.g. no CUDA device in use).
            pass

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
                mm.soft_empty_cache()
                gc.collect()

        return ({
            "samples": latents / VAE_SCALING_FACTOR
        },)
#region ReSampler
class HyVideoReSampler:
    """ComfyUI node that samples from inversed latents while steering towards target latents.

    At every step the transformer prediction is blended with a velocity
    derived from the difference between the current latents and the target
    latents; the blend strength per step comes from ``generate_eta_values``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("HYVIDEOMODEL",),
                "hyvid_embeds": ("HYVIDEMBEDS", ),
                "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
                "inversed_latents": ("LATENT", {"tooltip": "inversed latents from HyVideoInverseSampler"} ),
                "steps": ("INT", {"default": 30, "min": 1}),
                "embedded_guidance_scale": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "flow_shift": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 30.0, "step": 0.01}),
                "force_offload": ("BOOLEAN", {"default": True}),
                "start_step": ("INT", {"default": 0, "min": 0, "tooltip": "The step to start the effect of the inversed latents"}),
                "end_step": ("INT", {"default": 18, "min": 0, "tooltip": "The step to end the effect of the inversed latents"}),
                "eta_base": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The base value of the eta, overall strength of the effect from the inversed latents"}),
                "eta_trend": (['constant', 'linear_increase', 'linear_decrease'], {"default": "constant", "tooltip": "The trend of the eta value over steps"}),
            },
            "optional": {
                "interpolation_curve": ("FLOAT", {"forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}),
                "feta_args": ("FETAARGS", ),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"

    def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale,
                samples, inversed_latents, force_offload, start_step, end_step, eta_base, eta_trend, interpolation_curve=None, feta_args=None):
        """Run the resampling loop and return the resulting latents.

        Args:
            model: Model patcher whose ``.model`` is a dict providing
                ``"pipe"``, ``"dtype"``, ``"block_swap_args"``,
                ``"auto_cpu_offload"`` and ``"manual_offloading"``.
            hyvid_embeds: Dict with ``"prompt_embeds"``, ``"attention_mask"``
                and ``"prompt_embeds_2"``.
            flow_shift: Value assigned to ``pipeline.scheduler.flow_shift``.
            steps: Number of steps passed to the scheduler.
            embedded_guidance_scale: Guidance value fed to the transformer.
            samples: Dict whose ``"samples"`` entry holds the 5-D target
                latents.
            inversed_latents: Dict whose ``"samples"`` entry holds the
                latents the loop starts from.
            force_offload: When true and the model uses manual offloading,
                move the transformer to the offload device afterwards.
            start_step: The step to start the effect of the target latents.
            end_step: The step to end the effect of the target latents.
            eta_base: Peak blend strength.
            eta_trend: Trend name forwarded to ``generate_eta_values``.
            interpolation_curve: Optional per-latent-frame weights; its
                length must equal dimension 2 of the inversed latents.
            feta_args: Optional dict with ``"weight"``, ``"start_percent"``,
                ``"end_percent"``, ``"single_blocks"`` and ``"double_blocks"``
                controlling the enhance hooks.

        Returns:
            tuple: ``({"samples": latents},)`` with the scaling factor undone.

        Raises:
            ValueError: If the derived size is not positive, or if
                ``interpolation_curve`` does not match the latent frame count.
        """
        # Fail fast on a mismatched curve, before any weights are moved and
        # before a transformer forward pass is spent on bad input.
        if interpolation_curve is not None:
            temporal_len = inversed_latents["samples"].shape[2]
            if len(interpolation_curve) != temporal_len:
                raise ValueError(f"Weight list length {len(interpolation_curve)} must match temporal dimension {temporal_len}")

        comfy_model_patcher = model
        model = model.model
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = model["dtype"]
        transformer = model["pipe"].transformer
        pipeline = model["pipe"]

        target_latents = samples["samples"] * VAE_SCALING_FACTOR

        # The 5-way unpack also enforces that the target latents are 5-D.
        _, _, latent_num_frames, latent_height, latent_width = target_latents.shape
        height = latent_height * pipeline.vae_scale_factor
        width = latent_width * pipeline.vae_scale_factor
        num_frames = (latent_num_frames - 1) * 4 + 1

        if width <= 0 or height <= 0 or num_frames <= 0:
            raise ValueError(
                f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={num_frames}"
            )
        if (num_frames - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1 (that's minus one frame)` must be a multiple of 4, got {num_frames}"
            )

        log.info(
            f"Input (height, width, video_length) = ({height}, {width}, {num_frames})"
        )

        freqs_cos, freqs_sin = get_rotary_pos_embed(transformer, latent_num_frames, height, width)
        freqs_cos = freqs_cos.to(device)
        freqs_sin = freqs_sin.to(device)

        pipeline.scheduler.flow_shift = flow_shift

        if model["block_swap_args"] is not None:
            # Move every parameter outside the "single"/"double" blocks to the
            # compute device, then let block_swap handle those blocks.
            for name, param in transformer.named_parameters():
                #print(name, param.data.device)
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
            transformer.block_swap(
                model["block_swap_args"]["double_blocks_to_swap"] - 1 ,
                model["block_swap_args"]["single_blocks_to_swap"] - 1,
                offload_txt_in = model["block_swap_args"]["offload_txt_in"],
                offload_img_in = model["block_swap_args"]["offload_img_in"],
            )
        elif model["auto_cpu_offload"]:
            for name, param in transformer.named_parameters():
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
        elif model["manual_offloading"]:
            transformer.to(device)

        mm.soft_empty_cache()
        gc.collect()

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            # Best-effort statistic reset (e.g. no CUDA device in use).
            pass

        pipeline.scheduler.set_timesteps(steps, device=device)
        timesteps = pipeline.scheduler.timesteps

        eta_values = generate_eta_values(timesteps / 1000, start_step, end_step, eta_base, eta_trend)

        target_latents = target_latents.to(device=device, dtype=dtype)

        latents = inversed_latents["samples"] * VAE_SCALING_FACTOR
        latents = latents.to(device=device, dtype=dtype)

        # Loop-invariant: build the broadcastable weight tensor once
        # instead of on every step.
        curve_weights = None
        if interpolation_curve is not None:
            curve_weights = torch.tensor(interpolation_curve, device=latents.device).view(1, 1, -1, 1, 1)  # shape [1, 1, 33, 1, 1]

        # 7. Denoising loop
        self._num_timesteps = len(timesteps)

        from latent_preview import prepare_callback
        callback = prepare_callback(comfy_model_patcher, steps)

        if feta_args is not None:
            set_enhance_weight(feta_args["weight"])
            feta_start_percent = feta_args["start_percent"]
            feta_end_percent = feta_args["end_percent"]
            enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
        else:
            disable_enhance()

        from comfy.utils import ProgressBar
        from tqdm import tqdm
        log.info(f"Sampling {num_frames} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
        # NOTE(review): the loop below runs len(timesteps) - 1 iterations, so
        # these bars stop one tick short of their total.
        comfy_pbar = ProgressBar(len(timesteps))
        with tqdm(total=len(timesteps)) as progress_bar:
            for idx, (t, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                current_step_percentage = idx / len(timesteps)

                # Toggle the enhance hooks only inside the requested window.
                if feta_args is not None:
                    if feta_start_percent <= current_step_percentage <= feta_end_percent:
                        enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
                    else:
                        disable_enhance()

                latent_model_input = latents

                t_expand = t.repeat(latent_model_input.shape[0])
                guidance_expand = (
                    torch.tensor(
                        [embedded_guidance_scale] * latent_model_input.shape[0],
                        dtype=torch.float32,
                        device=device,
                    ).to(pipeline.base_dtype)
                    * 1000.0
                    if embedded_guidance_scale is not None
                    else None
                )

                # predict the noise residual
                with torch.autocast(
                    device_type="cuda", dtype=pipeline.base_dtype, enabled=True
                ):
                    noise_pred = transformer(  # For an input image (129, 192, 336) (1, 256, 256)
                        latent_model_input,  # [2, 16, 33, 24, 42]
                        t_expand,  # [2]
                        text_states=hyvid_embeds["prompt_embeds"],  # [2, 256, 4096]
                        text_mask=hyvid_embeds["attention_mask"],  # [2, 256]
                        text_states_2=hyvid_embeds["prompt_embeds_2"],  # [2, 768]
                        freqs_cos=freqs_cos,  # [seqlen, head_dim]
                        freqs_sin=freqs_sin,  # [seqlen, head_dim]
                        guidance=guidance_expand,
                        stg_block_idx=-1,
                        stg_mode=None,
                        return_dict=True,
                    )["x"]

                sigma = t / 1000.0
                sigma_prev = t_prev / 1000.0
                # Do the update arithmetic in float32.
                noise_pred = noise_pred.to(torch.float32)
                latents = latents.to(torch.float32)
                target_latents = target_latents.to(torch.float32)
                target_img_velocity = -(target_latents - latents) / sigma

                # interpolated velocity
                # Add time-varying weights
                if curve_weights is not None:
                    eta = eta_values[idx] * curve_weights
                else:
                    eta = eta_values[idx]

                # Time-varying interpolation
                interpolated_velocity = eta * target_img_velocity + (1 - eta) * noise_pred
                latents = latents + (sigma_prev - sigma) * interpolated_velocity
                #print(f"X_{sigma_prev:.3f} = X_{sigma:.3f} + {sigma_prev - sigma:.3f} * ({eta:.3f} * target_img_velocity + {1 - eta:.3f} * noise_pred)")
                # NOTE(review): hard-coded bfloat16 rather than `dtype`;
                # kept as-is to preserve numerical behavior.
                latents = latents.to(torch.bfloat16)

                progress_bar.update()
                if callback is not None:
                    callback(idx, (latent_model_input - noise_pred * t / 1000).detach()[0].permute(1,0,2,3), None, steps)
                else:
                    comfy_pbar.update(1)

        print_memory(device)
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            # Best-effort statistic reset (e.g. no CUDA device in use).
            pass

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
                mm.soft_empty_cache()
                gc.collect()

        return ({
            "samples": latents / VAE_SCALING_FACTOR
        },)
#region PromptMix
class HyVideoPromptMixSampler:
    """Experimental ComfyUI node that samples two prompts jointly and mixes them over time.

    Two copies of the same seeded noise are denoised in one batch, one per
    prompt. Before each model call the two latents are cross-blended per
    latent frame using ``alpha`` and ``interpolation_curve``; after each step
    the returned latents are a per-frame interpolation of the two.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("HYVIDEOMODEL",),
                "hyvid_embeds": ("HYVIDEMBEDS", ),
                "hyvid_embeds_2": ("HYVIDEMBEDS", ),
                "width": ("INT", {"default": 512, "min": 1}),
                "height": ("INT", {"default": 512, "min": 1}),
                "num_frames": ("INT", {"default": 17, "min": 1}),
                "steps": ("INT", {"default": 30, "min": 1}),
                "embedded_guidance_scale": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "flow_shift": ("FLOAT", {"default": 9.0, "min": 1.0, "max": 30.0, "step": 0.01}),
                "force_offload": ("BOOLEAN", {"default": True}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Adjusts the blending sharpness"}),
            },
            "optional": {
                "interpolation_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}),
                "feta_args": ("FETAARGS", ),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"
    EXPERIMENTAL = True

    def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2, flow_shift, steps, embedded_guidance_scale,
                seed, force_offload, alpha, interpolation_curve=None, feta_args=None):
        """Run the two-prompt mixing loop and return the mixed latents.

        Args:
            model: Model patcher whose ``.model`` is a dict providing
                ``"pipe"``, ``"dtype"``, ``"block_swap_args"``,
                ``"auto_cpu_offload"`` and ``"manual_offloading"``.
            width: Output width in pixels.
            height: Output height in pixels.
            num_frames: Number of video frames; ``num_frames - 1`` must be a
                multiple of 4.
            hyvid_embeds: Embedding dict for the first prompt.
            hyvid_embeds_2: Embedding dict for the second prompt.
            flow_shift: Value assigned to ``pipeline.scheduler.flow_shift``.
            steps: Number of steps passed to the scheduler.
            embedded_guidance_scale: Guidance value fed to the transformer.
            seed: Seed for the CPU generator that draws the initial noise.
            force_offload: When true and the model uses manual offloading,
                move the transformer to the offload device afterwards.
            alpha: Adjusts the blending sharpness of the model inputs.
            interpolation_curve: Per-latent-frame weights selecting between
                the first (0) and second (1) prompt. Although declared as an
                optional input, it is indexed on every step and therefore
                must be supplied; its length must equal
                ``(num_frames - 1) // 4 + 1``.
            feta_args: Optional dict with ``"weight"``, ``"start_percent"``,
                ``"end_percent"``, ``"single_blocks"`` and ``"double_blocks"``
                controlling the enhance hooks.

        Returns:
            tuple: ``({"samples": latents},)`` with the scaling factor undone.

        Raises:
            ValueError: If the size arguments are invalid, or if
                ``interpolation_curve`` is missing or has the wrong length.
        """
        # Validate all plain arguments first so bad input is rejected before
        # any weights are moved between devices.
        if width <= 0 or height <= 0 or num_frames <= 0:
            raise ValueError(
                f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={num_frames}"
            )
        if (num_frames - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1 (that's minus one frame)` must be a multiple of 4, got {num_frames}"
            )

        latent_video_length = (num_frames - 1) // 4 + 1

        # Previously a missing curve surfaced as `TypeError: object of type
        # 'NoneType' has no len()`; report the actual problem instead.
        if interpolation_curve is None:
            raise ValueError("interpolation_curve is required: connect one weight per latent frame")
        if len(interpolation_curve) != latent_video_length:
            raise ValueError(f"Weight list length {len(interpolation_curve)} must match temporal dimension {latent_video_length}")

        comfy_model_patcher = model
        model = model.model
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = model["dtype"]
        transformer = model["pipe"].transformer
        pipeline = model["pipe"]

        log.info(
            f"Input (height, width, video_length) = ({height}, {width}, {num_frames})"
        )

        freqs_cos, freqs_sin = get_rotary_pos_embed(transformer, latent_video_length, height, width)
        freqs_cos = freqs_cos.to(device)
        freqs_sin = freqs_sin.to(device)

        pipeline.scheduler.flow_shift = flow_shift

        if model["block_swap_args"] is not None:
            # Move every parameter outside the "single"/"double" blocks to the
            # compute device, then let block_swap handle those blocks.
            for name, param in transformer.named_parameters():
                #print(name, param.data.device)
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
            transformer.block_swap(
                model["block_swap_args"]["double_blocks_to_swap"] - 1 ,
                model["block_swap_args"]["single_blocks_to_swap"] - 1,
                offload_txt_in = model["block_swap_args"]["offload_txt_in"],
                offload_img_in = model["block_swap_args"]["offload_img_in"],
            )
        elif model["auto_cpu_offload"]:
            for name, param in transformer.named_parameters():
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
        elif model["manual_offloading"]:
            transformer.to(device)

        mm.soft_empty_cache()
        gc.collect()

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            # Best-effort statistic reset (e.g. no CUDA device in use).
            pass

        pipeline.scheduler.set_timesteps(steps, device=device)
        timesteps = pipeline.scheduler.timesteps

        #latents = samples["samples"]
        shape = (
            1,
            16,
            latent_video_length,
            int(height) // pipeline.vae_scale_factor,
            int(width) // pipeline.vae_scale_factor,
        )
        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
        latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)

        llm_embeds_1 = hyvid_embeds["prompt_embeds"].to(dtype).to(device)
        clip_embeds_1 = hyvid_embeds["prompt_embeds_2"].to(dtype).to(device)
        mask_1 = hyvid_embeds["attention_mask"].to(device)

        llm_embeds_2 = hyvid_embeds_2["prompt_embeds"].to(dtype).to(device)
        clip_embeds_2 = hyvid_embeds_2["prompt_embeds_2"].to(dtype).to(device)
        mask_2 = hyvid_embeds_2["attention_mask"].to(device)

        # Stack both prompts along the batch axis so one forward pass
        # serves both latents.
        text_embeds = torch.cat((llm_embeds_1, llm_embeds_2), dim=0)
        text_mask = torch.cat((mask_1, mask_2), dim=0)
        clip_embeds = torch.cat((clip_embeds_1, clip_embeds_2), dim=0)

        # Both branches start from the same noise.
        latents_1 = latents.clone()
        latents_2 = latents.clone()

        if feta_args is not None:
            set_enhance_weight(feta_args["weight"])
            feta_start_percent = feta_args["start_percent"]
            feta_end_percent = feta_args["end_percent"]
            enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
        else:
            disable_enhance()

        # 7. Denoising loop
        self._num_timesteps = len(timesteps)

        from latent_preview import prepare_callback
        callback = prepare_callback(comfy_model_patcher, steps)

        from comfy.utils import ProgressBar
        from tqdm import tqdm
        log.info(f"Sampling {num_frames} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
        comfy_pbar = ProgressBar(len(timesteps))
        with tqdm(total=len(timesteps)) as progress_bar:
            for idx, t in enumerate(timesteps):
                current_step_percentage = idx / len(timesteps)

                # Toggle the enhance hooks only inside the requested window.
                if feta_args is not None:
                    if feta_start_percent <= current_step_percentage <= feta_end_percent:
                        enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
                    else:
                        disable_enhance()

                # Pre-compute weighted latents
                weighted_latents_1 = torch.zeros_like(latents_1)
                weighted_latents_2 = torch.zeros_like(latents_2)

                for t_idx in range(latents_1.shape[2]):
                    weight = interpolation_curve[t_idx]
                    weighted_latents_1[..., t_idx, :, :] = (
                        (1 - alpha * weight) * latents_1[..., t_idx, :, :] +
                        (alpha * weight) * latents_2[..., t_idx, :, :]
                    )
                    weighted_latents_2[..., t_idx, :, :] = (
                        (1 - alpha * (1-weight)) * latents_2[..., t_idx, :, :] +
                        (alpha * (1-weight)) * latents_1[..., t_idx, :, :]
                    )

                # Use weighted inputs for model
                latent_model_input = torch.cat([weighted_latents_1, weighted_latents_2])

                t_expand = t.repeat(latent_model_input.shape[0])
                guidance_expand = (
                    torch.tensor(
                        [embedded_guidance_scale] * latent_model_input.shape[0],
                        dtype=torch.float32,
                        device=device,
                    ).to(pipeline.base_dtype)
                    * 1000.0
                    if embedded_guidance_scale is not None
                    else None
                )

                # predict the noise residual
                with torch.autocast(
                    device_type="cuda", dtype=pipeline.base_dtype, enabled=True
                ):
                    noise_pred = transformer(  # For an input image (129, 192, 336) (1, 256, 256)
                        latent_model_input,  # [2, 16, 33, 24, 42]
                        t_expand,  # [2]
                        text_states=text_embeds,  # [2, 256, 4096]
                        text_mask=text_mask,  # [2, 256]
                        text_states_2=clip_embeds,  # [2, 768]
                        freqs_cos=freqs_cos,  # [seqlen, head_dim]
                        freqs_sin=freqs_sin,  # [seqlen, head_dim]
                        guidance=guidance_expand,
                        stg_block_idx=-1,
                        stg_mode=None,
                        return_dict=True,
                    )["x"]

                noise_pred = noise_pred.to(torch.float32)

                # 1. Get noise predictions for both prompts
                noise_pred_prompt_1, noise_pred_prompt_2 = noise_pred.chunk(2)

                # 2. Update latents separately for each prompt
                dt = pipeline.scheduler.sigmas[idx + 1] - pipeline.scheduler.sigmas[idx]
                latents_1 = latents_1 + noise_pred_prompt_1 * dt
                latents_2 = latents_2 + noise_pred_prompt_2 * dt

                # 3. Interpolate latents based on temporal curve
                interpolated_latents = torch.zeros_like(latents_1)
                for t_idx in range(latents.shape[2]):
                    weight = interpolation_curve[t_idx]
                    interpolated_latents[..., t_idx, :, :] = (
                        (1 - weight) * latents_1[..., t_idx, :, :] +
                        weight * latents_2[..., t_idx, :, :]
                    )

                latents = interpolated_latents

                progress_bar.update()
                if callback is not None:
                    callback(idx, (latent_model_input - noise_pred * t / 1000).detach()[0].permute(1,0,2,3), None, steps)
                else:
                    comfy_pbar.update(1)

        print_memory(device)
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            # Best-effort statistic reset (e.g. no CUDA device in use).
            pass

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
                mm.soft_empty_cache()
                gc.collect()

        return ({
            "samples": latents / VAE_SCALING_FACTOR
        },)
# Maps each node identifier to the class in this file that implements it.
NODE_CLASS_MAPPINGS = {
    "HyVideoInverseSampler": HyVideoInverseSampler,
    "HyVideoReSampler": HyVideoReSampler,
    "HyVideoEmptyTextEmbeds": HyVideoEmptyTextEmbeds,
    "HyVideoPromptMixSampler": HyVideoPromptMixSampler
}
# Maps the same node identifiers to their human-readable display names.
NODE_DISPLAY_NAME_MAPPINGS = {
    "HyVideoInverseSampler": "HunyuanVideo Inverse Sampler",
    "HyVideoReSampler": "HunyuanVideo ReSampler",
    "HyVideoEmptyTextEmbeds": "HunyuanVideo Empty Text Embeds",
    "HyVideoPromptMixSampler": "HunyuanVideo Prompt Mix Sampler"
}