import torch
import gc
import os
from .utils import log, print_memory

from diffusers.utils.torch_utils import randn_tensor
import comfy.model_management as mm
from .hyvideo.diffusion.pipelines.pipeline_hunyuan_video import get_rotary_pos_embed
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight

script_directory = os.path.dirname(os.path.abspath(__file__))

# Scaling factor of the HunyuanVideo VAE: incoming latents are multiplied by
# this before sampling, and results are divided by it again before returning.
VAE_SCALING_FACTOR = 0.476986
def generate_eta_values(
    timesteps,
    start_step,
    end_step,
    eta,
    eta_trend,
):
    """Build a per-step eta schedule: zero outside [start_step, end_step), shaped by eta_trend inside it."""
    # eta_values has len(timesteps) - 1 entries (one per integration interval),
    # so end_step may be at most len(timesteps) - 1.
    assert 0 <= start_step < end_step <= len(timesteps) - 1, "Invalid start_step and end_step"

    eta_values = [0.0] * (len(timesteps) - 1)

    if eta_trend == 'constant':
        for i in range(start_step, end_step):
            eta_values[i] = eta
    elif eta_trend == 'linear_increase':
        total_time = timesteps[start_step] - timesteps[end_step - 1]
        for i in range(start_step, end_step):
            eta_values[i] = eta * (timesteps[start_step] - timesteps[i]) / total_time
    elif eta_trend == 'linear_decrease':
        total_time = timesteps[start_step] - timesteps[end_step - 1]
        for i in range(start_step, end_step):
            eta_values[i] = eta * (timesteps[i] - timesteps[end_step - 1]) / total_time
    else:
        raise NotImplementedError(f"Unsupported eta_trend: {eta_trend}")
    log.info(f"eta_values: {eta_values}")
    return eta_values
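
# Example (illustrative): with five normalized timesteps [1.0, 0.8, 0.6, 0.4, 0.2],
# generate_eta_values(ts, start_step=1, end_step=4, eta=0.5, eta_trend='constant')
# returns [0.0, 0.5, 0.5, 0.5]; 'linear_increase' instead ramps 0.0 -> 0.25 -> 0.5
# over the active window, and 'linear_decrease' mirrors it.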

class HyVideoEmptyTextEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    RETURN_TYPES = ("HYVIDEMBEDS", )
    RETURN_NAMES = ("hyvid_embeds",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"
    DESCRIPTION = "Empty Text Embeds for HunyuanVideoWrapper, to avoid having to encode prompts for inverse sampling"

    def process(self):
        # Load the precomputed empty-prompt embeddings shipped with the node pack;
        # no text encoder needs to be loaded for this.
        prompt_embeds_dict = torch.load(os.path.join(script_directory, "hunyuan_empty_prompt_embeds_dict.pt"))
        return (prompt_embeds_dict,)

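# RF-inversion-style sampler: the flow-matching ODE is integrated in reverse
# (timesteps flipped) so an input video's latents are carried back to a noise
# latent that approximately regenerates them.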
class HyVideoInverseSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("HYVIDEOMODEL",),
                "hyvid_embeds": ("HYVIDEMBEDS", ),
                "samples": ("LATENT", {"tooltip": "Initial latents to invert for the video2video process"} ),
                "steps": ("INT", {"default": 30, "min": 1}),
                "embedded_guidance_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "flow_shift": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 30.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "force_offload": ("BOOLEAN", {"default": True}),
                "gamma": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "start_step": ("INT", {"default": 0, "min": 0}),
                "end_step": ("INT", {"default": 18, "min": 0}),
                "gamma_trend": (['constant', 'linear_increase', 'linear_decrease'], {"default": "constant"}),
            },
            "optional": {
                "interpolation_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"

    def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, samples, gamma, start_step, end_step, gamma_trend, force_offload, interpolation_curve=None):
        comfy_model_patcher = model
        model = model.model
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = model["dtype"]
        transformer = model["pipe"].transformer
        pipeline = model["pipe"]

        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

        # `samples` is a required input, so it can be scaled unconditionally.
        latents = samples["samples"] * VAE_SCALING_FACTOR
        batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width = latents.shape
        height = latent_height * pipeline.vae_scale_factor
        width = latent_width * pipeline.vae_scale_factor
        # The VAE compresses time 4x; the first frame is encoded on its own.
        num_frames = (latent_num_frames - 1) * 4 + 1

        if width <= 0 or height <= 0 or num_frames <= 0:
            raise ValueError(
                f"`height`, `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={num_frames}"
            )
        if (num_frames - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1` must be a multiple of 4, got video_length={num_frames}"
            )

        log.info(
            f"Input (height, width, video_length) = ({height}, {width}, {num_frames})"
        )

        freqs_cos, freqs_sin = get_rotary_pos_embed(transformer, latent_num_frames, height, width)
        freqs_cos = freqs_cos.to(device)
        freqs_sin = freqs_sin.to(device)

        pipeline.scheduler.flow_shift = flow_shift

        if model["block_swap_args"] is not None:
            for name, param in transformer.named_parameters():
                # Keep everything except the single/double transformer blocks on
                # the GPU; the blocks themselves are managed by block_swap below.
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)

            transformer.block_swap(
                model["block_swap_args"]["double_blocks_to_swap"] - 1,
                model["block_swap_args"]["single_blocks_to_swap"] - 1,
                offload_txt_in = model["block_swap_args"]["offload_txt_in"],
                offload_img_in = model["block_swap_args"]["offload_img_in"],
            )
        elif model["auto_cpu_offload"]:
            for name, param in transformer.named_parameters():
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
        elif model["manual_offloading"]:
            transformer.to(device)

        mm.soft_empty_cache()
        gc.collect()

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass

        pipeline.scheduler.set_timesteps(steps, device=device)
        # Flip the schedule so integration runs from clean latents toward noise.
        timesteps = pipeline.scheduler.timesteps.flip(0)
        log.info(f"timesteps: {timesteps}")

        latent_video_length = (num_frames - 1) // 4 + 1

        num_channels_latents = transformer.config.in_channels

        latents = latents.to(device)

        shape = (
            1,
            num_channels_latents,
            latent_video_length,
            int(height) // pipeline.vae_scale_factor,
            int(width) // pipeline.vae_scale_factor,
        )
        # Noise latent that the inversion integrates toward.
        noise = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)

        # Match the temporal length of the input latents to the noise
        # (dim 2 is the frame axis of B, C, T, H, W tensors).
        frames_needed = noise.shape[2]
        current_frames = latents.shape[2]

        if frames_needed > current_frames:
            repeat_factor = frames_needed - current_frames
            additional_frame = torch.randn((latents.size(0), latents.size(1), repeat_factor, latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
            latents = torch.cat((additional_frame, latents), dim=2)
            self.additional_frames = repeat_factor
        elif frames_needed < current_frames:
            latents = latents[:, :, :frames_needed, :, :]

        gamma_values = generate_eta_values(timesteps / 1000, start_step, end_step, gamma, gamma_trend)

        self._num_timesteps = len(timesteps)

        latents = latents.to(dtype)

        from latent_preview import prepare_callback
        callback = prepare_callback(comfy_model_patcher, steps)

        from comfy.utils import ProgressBar
        from tqdm import tqdm
        log.info(f"Sampling {num_frames} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
        comfy_pbar = ProgressBar(len(timesteps))
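        # Controlled reverse ODE (a hedged reading of the loop below, in the
        # spirit of RF-inversion): v_target = (noise - x) / (1 - sigma) points
        # from the current latent toward the sampled noise, and each Euler step
        # applies x += (sigma_prev - sigma) * (gamma * v_target + (1 - gamma) * v_model),
        # so gamma = 1 follows the straight path to noise while gamma = 0
        # follows the model's own velocity field.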
        with tqdm(total=len(timesteps)) as progress_bar:
            for idx, (t, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                latent_model_input = latents

                t_expand = t.repeat(latent_model_input.shape[0])
                # Embedded guidance is fed to the model scaled by 1000, matching
                # the timestep convention (t / 1000 = sigma below).
                guidance_expand = (
                    torch.tensor(
                        [embedded_guidance_scale] * latent_model_input.shape[0],
                        dtype=torch.float32,
                        device=device,
                    ).to(pipeline.base_dtype)
                    * 1000.0
                    if embedded_guidance_scale is not None
                    else None
                )

                with torch.autocast(
                    device_type="cuda", dtype=pipeline.base_dtype, enabled=True
                ):
                    noise_pred = transformer(
                        latent_model_input,
                        t_expand,
                        text_states=hyvid_embeds["prompt_embeds"],
                        text_mask=hyvid_embeds["attention_mask"],
                        text_states_2=hyvid_embeds["prompt_embeds_2"],
                        freqs_cos=freqs_cos,
                        freqs_sin=freqs_sin,
                        guidance=guidance_expand,
                        stg_block_idx=-1,
                        stg_mode=None,
                        return_dict=True,
                    )["x"]
                sigma = t / 1000.0
                sigma_prev = t_prev / 1000.0
                latents = latents.to(torch.float32)
                noise_pred = noise_pred.to(torch.float32)
                target_noise_velocity = (noise - latents) / (1.0 - sigma)

                if interpolation_curve is not None:
                    # Per-frame weights along the temporal axis.
                    time_weights = torch.tensor(interpolation_curve, device=latents.device)
                    assert time_weights.shape[0] == latents.shape[2], f"Weight list length {len(interpolation_curve)} must match temporal dimension {latents.shape[2]}"
                    step_gamma = gamma_values[idx] * time_weights.view(1, 1, -1, 1, 1)
                else:
                    step_gamma = gamma_values[idx]
                interpolated_velocity = step_gamma * target_noise_velocity + (1 - step_gamma) * noise_pred

                latents = latents + (sigma_prev - sigma) * interpolated_velocity
                latents = latents.to(torch.bfloat16)

                progress_bar.update()
                if callback is not None:
                    callback(idx, (latent_model_input - noise_pred * t / 1000).detach()[0].permute(1, 0, 2, 3), None, steps)
                else:
                    comfy_pbar.update(1)

        print_memory(device)
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
                mm.soft_empty_cache()
                gc.collect()

        return ({
            "samples": latents / VAE_SCALING_FACTOR
        },)

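# Re-samples from latents produced by HyVideoInverseSampler, steering the
# denoising trajectory back toward the original (target) latents with an eta
# schedule, so edits from a new prompt keep the source video's structure.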
class HyVideoReSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("HYVIDEOMODEL",),
                "hyvid_embeds": ("HYVIDEMBEDS", ),
                "samples": ("LATENT", {"tooltip": "Initial latents to use for the video2video process"} ),
                "inversed_latents": ("LATENT", {"tooltip": "Inversed latents from HyVideoInverseSampler"} ),
                "steps": ("INT", {"default": 30, "min": 1}),
                "embedded_guidance_scale": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "flow_shift": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 30.0, "step": 0.01}),
                "force_offload": ("BOOLEAN", {"default": True}),
                "start_step": ("INT", {"default": 0, "min": 0, "tooltip": "The step to start the effect of the inversed latents"}),
                "end_step": ("INT", {"default": 18, "min": 0, "tooltip": "The step to end the effect of the inversed latents"}),
                "eta_base": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The base eta value: overall strength of the effect of the inversed latents"}),
                "eta_trend": (['constant', 'linear_increase', 'linear_decrease'], {"default": "constant", "tooltip": "The trend of the eta value over the steps"}),
            },
            "optional": {
                "interpolation_curve": ("FLOAT", {"forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}),
                "feta_args": ("FETAARGS", ),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"

    def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale,
                samples, inversed_latents, force_offload, start_step, end_step, eta_base, eta_trend, interpolation_curve=None, feta_args=None):
        comfy_model_patcher = model
        model = model.model
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = model["dtype"]
        transformer = model["pipe"].transformer
        pipeline = model["pipe"]

        target_latents = samples["samples"] * VAE_SCALING_FACTOR

        batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width = target_latents.shape
        height = latent_height * pipeline.vae_scale_factor
        width = latent_width * pipeline.vae_scale_factor
        num_frames = (latent_num_frames - 1) * 4 + 1

        if width <= 0 or height <= 0 or num_frames <= 0:
            raise ValueError(
                f"`height`, `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={num_frames}"
            )
        if (num_frames - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1` must be a multiple of 4, got video_length={num_frames}"
            )

        log.info(
            f"Input (height, width, video_length) = ({height}, {width}, {num_frames})"
        )

        freqs_cos, freqs_sin = get_rotary_pos_embed(transformer, latent_num_frames, height, width)
        freqs_cos = freqs_cos.to(device)
        freqs_sin = freqs_sin.to(device)

        pipeline.scheduler.flow_shift = flow_shift

        if model["block_swap_args"] is not None:
            for name, param in transformer.named_parameters():
                # Keep everything except the single/double transformer blocks on
                # the GPU; the blocks themselves are managed by block_swap below.
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)

            transformer.block_swap(
                model["block_swap_args"]["double_blocks_to_swap"] - 1,
                model["block_swap_args"]["single_blocks_to_swap"] - 1,
                offload_txt_in = model["block_swap_args"]["offload_txt_in"],
                offload_img_in = model["block_swap_args"]["offload_img_in"],
            )
        elif model["auto_cpu_offload"]:
            for name, param in transformer.named_parameters():
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
        elif model["manual_offloading"]:
            transformer.to(device)

        mm.soft_empty_cache()
        gc.collect()

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass

        pipeline.scheduler.set_timesteps(steps, device=device)
        timesteps = pipeline.scheduler.timesteps

        eta_values = generate_eta_values(timesteps / 1000, start_step, end_step, eta_base, eta_trend)

        target_latents = target_latents.to(device=device, dtype=dtype)
        latents = inversed_latents["samples"] * VAE_SCALING_FACTOR
        latents = latents.to(device=device, dtype=dtype)
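        # Sampling starts from the inverted noise latents; within the
        # [start_step, end_step) window, eta pulls the trajectory back toward
        # target_latents, preserving the source video's structure.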

        self._num_timesteps = len(timesteps)

        from latent_preview import prepare_callback
        callback = prepare_callback(comfy_model_patcher, steps)

        # Enhance-A-Video (FETA) is toggled per step inside the loop based on
        # the configured start/end percentages.
        if feta_args is not None:
            set_enhance_weight(feta_args["weight"])
            feta_start_percent = feta_args["start_percent"]
            feta_end_percent = feta_args["end_percent"]
            enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
        else:
            disable_enhance()

        from comfy.utils import ProgressBar
        from tqdm import tqdm
        log.info(f"Sampling {num_frames} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
        comfy_pbar = ProgressBar(len(timesteps))

        with tqdm(total=len(timesteps)) as progress_bar:
            for idx, (t, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):

                current_step_percentage = idx / len(timesteps)

                if feta_args is not None:
                    if feta_start_percent <= current_step_percentage <= feta_end_percent:
                        enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
                    else:
                        disable_enhance()

                latent_model_input = latents

                t_expand = t.repeat(latent_model_input.shape[0])
                guidance_expand = (
                    torch.tensor(
                        [embedded_guidance_scale] * latent_model_input.shape[0],
                        dtype=torch.float32,
                        device=device,
                    ).to(pipeline.base_dtype)
                    * 1000.0
                    if embedded_guidance_scale is not None
                    else None
                )

                with torch.autocast(
                    device_type="cuda", dtype=pipeline.base_dtype, enabled=True
                ):
                    noise_pred = transformer(
                        latent_model_input,
                        t_expand,
                        text_states=hyvid_embeds["prompt_embeds"],
                        text_mask=hyvid_embeds["attention_mask"],
                        text_states_2=hyvid_embeds["prompt_embeds_2"],
                        freqs_cos=freqs_cos,
                        freqs_sin=freqs_sin,
                        guidance=guidance_expand,
                        stg_block_idx=-1,
                        stg_mode=None,
                        return_dict=True,
                    )["x"]
                sigma = t / 1000.0
                sigma_prev = t_prev / 1000.0
                noise_pred = noise_pred.to(torch.float32)
                latents = latents.to(torch.float32)
                target_latents = target_latents.to(torch.float32)
                target_img_velocity = -(target_latents - latents) / sigma
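
                # Guided forward step (hedged reading): the sign of
                # target_img_velocity combines with the negative step
                # (sigma_prev - sigma) so the update below moves latents toward
                # target_latents, while eta blends that pull with the model's
                # own prediction.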
                if interpolation_curve is not None:
                    time_weights = torch.tensor(interpolation_curve, device=latents.device)
                    assert time_weights.shape[0] == latents.shape[2], f"Weight list length {len(interpolation_curve)} must match temporal dimension {latents.shape[2]}"
                    eta = eta_values[idx] * time_weights.view(1, 1, -1, 1, 1)
                else:
                    eta = eta_values[idx]

                interpolated_velocity = eta * target_img_velocity + (1 - eta) * noise_pred
                latents = latents + (sigma_prev - sigma) * interpolated_velocity

                latents = latents.to(torch.bfloat16)

                progress_bar.update()
                if callback is not None:
                    callback(idx, (latent_model_input - noise_pred * t / 1000).detach()[0].permute(1, 0, 2, 3), None, steps)
                else:
                    comfy_pbar.update(1)

        print_memory(device)
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
                mm.soft_empty_cache()
                gc.collect()

        return ({
            "samples": latents / VAE_SCALING_FACTOR
        },)

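# Samples a single video whose latents are blended per frame between two
# prompts along `interpolation_curve`, producing a temporal transition from
# prompt 1 to prompt 2.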
class HyVideoPromptMixSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("HYVIDEOMODEL",),
                "hyvid_embeds": ("HYVIDEMBEDS", ),
                "hyvid_embeds_2": ("HYVIDEMBEDS", ),
                "width": ("INT", {"default": 512, "min": 1}),
                "height": ("INT", {"default": 512, "min": 1}),
                "num_frames": ("INT", {"default": 17, "min": 1}),
                "steps": ("INT", {"default": 30, "min": 1}),
                "embedded_guidance_scale": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "flow_shift": ("FLOAT", {"default": 9.0, "min": 1.0, "max": 30.0, "step": 0.01}),
                "force_offload": ("BOOLEAN", {"default": True}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Adjusts the blending sharpness"}),
            },
            "optional": {
                "interpolation_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "forceInput": True, "tooltip": "Per-frame blend weights between the two prompts, in latent space"}),
                "feta_args": ("FETAARGS", ),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "HunyuanVideoWrapper"
    EXPERIMENTAL = True

    def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2, flow_shift, steps, embedded_guidance_scale,
                seed, force_offload, alpha, interpolation_curve=None, feta_args=None):
        comfy_model_patcher = model
        model = model.model
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = model["dtype"]
        transformer = model["pipe"].transformer
        pipeline = model["pipe"]

        if width <= 0 or height <= 0 or num_frames <= 0:
            raise ValueError(
                f"`height`, `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={num_frames}"
            )
        if (num_frames - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1` must be a multiple of 4, got video_length={num_frames}"
            )

        log.info(
            f"Input (height, width, video_length) = ({height}, {width}, {num_frames})"
        )
        latent_video_length = (num_frames - 1) // 4 + 1
        freqs_cos, freqs_sin = get_rotary_pos_embed(transformer, latent_video_length, height, width)
        freqs_cos = freqs_cos.to(device)
        freqs_sin = freqs_sin.to(device)

        pipeline.scheduler.flow_shift = flow_shift

        if model["block_swap_args"] is not None:
            for name, param in transformer.named_parameters():
                # Keep everything except the single/double transformer blocks on
                # the GPU; the blocks themselves are managed by block_swap below.
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)

            transformer.block_swap(
                model["block_swap_args"]["double_blocks_to_swap"] - 1,
                model["block_swap_args"]["single_blocks_to_swap"] - 1,
                offload_txt_in = model["block_swap_args"]["offload_txt_in"],
                offload_img_in = model["block_swap_args"]["offload_img_in"],
            )
        elif model["auto_cpu_offload"]:
            for name, param in transformer.named_parameters():
                if "single" not in name and "double" not in name:
                    param.data = param.data.to(device)
        elif model["manual_offloading"]:
            transformer.to(device)

        mm.soft_empty_cache()
        gc.collect()

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass

        pipeline.scheduler.set_timesteps(steps, device=device)
        timesteps = pipeline.scheduler.timesteps

        # 16 latent channels for the HunyuanVideo VAE.
        shape = (
            1,
            16,
            latent_video_length,
            int(height) // pipeline.vae_scale_factor,
            int(width) // pipeline.vae_scale_factor,
        )
        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

        latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)

        llm_embeds_1 = hyvid_embeds["prompt_embeds"].to(dtype).to(device)
        clip_embeds_1 = hyvid_embeds["prompt_embeds_2"].to(dtype).to(device)
        mask_1 = hyvid_embeds["attention_mask"].to(device)
        llm_embeds_2 = hyvid_embeds_2["prompt_embeds"].to(dtype).to(device)
        clip_embeds_2 = hyvid_embeds_2["prompt_embeds_2"].to(dtype).to(device)
        mask_2 = hyvid_embeds_2["attention_mask"].to(device)
        text_embeds = torch.cat((llm_embeds_1, llm_embeds_2), dim=0)
        text_mask = torch.cat((mask_1, mask_2), dim=0)
        clip_embeds = torch.cat((clip_embeds_1, clip_embeds_2), dim=0)
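        # Both prompts are stacked along the batch dimension so a single
        # transformer call predicts velocities for prompt 1 and prompt 2.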

        assert interpolation_curve is not None and len(interpolation_curve) == latents.shape[2], f"interpolation_curve is required and its length must match the temporal dimension {latents.shape[2]}"
        latents_1 = latents.clone()
        latents_2 = latents.clone()

        if feta_args is not None:
            set_enhance_weight(feta_args["weight"])
            feta_start_percent = feta_args["start_percent"]
            feta_end_percent = feta_args["end_percent"]
            enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
        else:
            disable_enhance()

        self._num_timesteps = len(timesteps)

        from latent_preview import prepare_callback
        callback = prepare_callback(comfy_model_patcher, steps)

        from comfy.utils import ProgressBar
        from tqdm import tqdm
        log.info(f"Sampling {num_frames} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
        comfy_pbar = ProgressBar(len(timesteps))

        with tqdm(total=len(timesteps)) as progress_bar:
            for idx, t in enumerate(timesteps):
                current_step_percentage = idx / len(timesteps)

                if feta_args is not None:
                    if feta_start_percent <= current_step_percentage <= feta_end_percent:
                        enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
                    else:
                        disable_enhance()
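
                # Cross-pollinate the two trajectories before each model call:
                # alpha scales how strongly each stream is mixed with the other,
                # with per-frame weights taken from interpolation_curve.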
                weighted_latents_1 = torch.zeros_like(latents_1)
                weighted_latents_2 = torch.zeros_like(latents_2)

                for t_idx in range(latents_1.shape[2]):
                    weight = interpolation_curve[t_idx]
                    weighted_latents_1[..., t_idx, :, :] = (
                        (1 - alpha * weight) * latents_1[..., t_idx, :, :] +
                        (alpha * weight) * latents_2[..., t_idx, :, :]
                    )
                    weighted_latents_2[..., t_idx, :, :] = (
                        (1 - alpha * (1 - weight)) * latents_2[..., t_idx, :, :] +
                        (alpha * (1 - weight)) * latents_1[..., t_idx, :, :]
                    )

                latent_model_input = torch.cat([weighted_latents_1, weighted_latents_2])

                t_expand = t.repeat(latent_model_input.shape[0])
                guidance_expand = (
                    torch.tensor(
                        [embedded_guidance_scale] * latent_model_input.shape[0],
                        dtype=torch.float32,
                        device=device,
                    ).to(pipeline.base_dtype)
                    * 1000.0
                    if embedded_guidance_scale is not None
                    else None
                )

                with torch.autocast(
                    device_type="cuda", dtype=pipeline.base_dtype, enabled=True
                ):
                    noise_pred = transformer(
                        latent_model_input,
                        t_expand,
                        text_states=text_embeds,
                        text_mask=text_mask,
                        text_states_2=clip_embeds,
                        freqs_cos=freqs_cos,
                        freqs_sin=freqs_sin,
                        guidance=guidance_expand,
                        stg_block_idx=-1,
                        stg_mode=None,
                        return_dict=True,
                    )["x"]

                noise_pred = noise_pred.to(torch.float32)

                # Split the batched prediction back into one velocity per prompt.
                noise_pred_prompt_1, noise_pred_prompt_2 = noise_pred.chunk(2)

                # Euler step with the scheduler's sigma spacing (dt is negative).
                dt = pipeline.scheduler.sigmas[idx + 1] - pipeline.scheduler.sigmas[idx]

                latents_1 = latents_1 + noise_pred_prompt_1 * dt
                latents_2 = latents_2 + noise_pred_prompt_2 * dt

                # Blend the two denoised streams per frame for the output.
                interpolated_latents = torch.zeros_like(latents_1)
                for t_idx in range(latents.shape[2]):
                    weight = interpolation_curve[t_idx]
                    interpolated_latents[..., t_idx, :, :] = (
                        (1 - weight) * latents_1[..., t_idx, :, :] +
                        weight * latents_2[..., t_idx, :, :]
                    )

                latents = interpolated_latents

                progress_bar.update()
                if callback is not None:
                    callback(idx, (latent_model_input - noise_pred * t / 1000).detach()[0].permute(1, 0, 2, 3), None, steps)
                else:
                    comfy_pbar.update(1)

        print_memory(device)
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
                mm.soft_empty_cache()
                gc.collect()

        return ({
            "samples": latents / VAE_SCALING_FACTOR
        },)


NODE_CLASS_MAPPINGS = {
    "HyVideoInverseSampler": HyVideoInverseSampler,
    "HyVideoReSampler": HyVideoReSampler,
    "HyVideoEmptyTextEmbeds": HyVideoEmptyTextEmbeds,
    "HyVideoPromptMixSampler": HyVideoPromptMixSampler,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "HyVideoInverseSampler": "HunyuanVideo Inverse Sampler",
    "HyVideoReSampler": "HunyuanVideo ReSampler",
    "HyVideoEmptyTextEmbeds": "HunyuanVideo Empty Text Embeds",
    "HyVideoPromptMixSampler": "HunyuanVideo Prompt Mix Sampler",
}