import os
import torch
import json
import gc

from .utils import log, print_memory
from diffusers.video_processor import VideoProcessor
from typing import List, Dict, Any, Tuple

from .hyvideo.constants import PROMPT_TEMPLATE
from .hyvideo.text_encoder import TextEncoder
from .hyvideo.utils.data_utils import align_to
from .hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from .hyvideo.diffusion.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .hyvideo.diffusion.schedulers.scheduling_sasolver import SASolverScheduler
from .hyvideo.diffusion.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

# from diffusers.schedulers import (
#     DDIMScheduler,
#     PNDMScheduler,
#     DPMSolverMultistepScheduler,
#     EulerDiscreteScheduler,
#     EulerAncestralDiscreteScheduler,
#     UniPCMultistepScheduler,
#     HeunDiscreteScheduler,
#     SASolverScheduler,
#     DEISMultistepScheduler,
#     LCMScheduler
# )

# Maps the scheduler names exposed in the sampler node to their classes
scheduler_mapping = {
    "FlowMatchDiscreteScheduler": FlowMatchDiscreteScheduler,
    "SDE-DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
    "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
    "SASolverScheduler": SASolverScheduler,
    "UniPCMultistepScheduler": UniPCMultistepScheduler,
}
available_schedulers = list(scheduler_mapping.keys())

from .hyvideo.diffusion.pipelines import HunyuanVideoPipeline
from .hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from .hyvideo.modules.models import HYVideoDiffusionTransformer
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

import folder_paths
folder_paths.add_model_folder_path("hyvid_embeds", os.path.join(folder_paths.get_output_directory(), "hyvid_embeds"))

import comfy.model_management as mm
from comfy.utils import load_torch_file, save_torch_file
import comfy.model_base
import comfy.latent_formats

script_directory = os.path.dirname(os.path.abspath(__file__))

VAE_SCALING_FACTOR = 0.476986

# Noise augmentation for reference video frames (used by HyVideoEncode's noise_aug_strength)
def add_noise_to_reference_video(image, ratio=None):
    if ratio is None:
        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
        sigma = torch.exp(sigma).to(image.dtype)
    else:
        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio

    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
    image = image + image_noise
    return image

def filter_state_dict_by_blocks(state_dict, blocks_mapping):
    filtered_dict = {}
    for key in state_dict:
        if 'double_blocks.' in key or 'single_blocks.' in key:
            block_pattern = key.split('diffusion_model.')[1].split('.', 2)[0:2]
            block_key = f'{block_pattern[0]}.{block_pattern[1]}.'
            if block_key in blocks_mapping:
                filtered_dict[key] = state_dict[key]
    return filtered_dict

def standardize_lora_key_format(lora_sd):
    new_sd = {}
    for k, v in lora_sd.items():
        # Diffusers format
        if k.startswith('transformer.'):
            k = k.replace('transformer.', 'diffusion_model.')
        if "img_attn.proj" in k:
            k = k.replace("img_attn.proj", "img_attn_proj")
        if "img_attn.qkv" in k:
            k = k.replace("img_attn.qkv", "img_attn_qkv")
        if "txt_attn.proj" in k:
            k = k.replace("txt_attn.proj", "txt_attn_proj")
        if "txt_attn.qkv" in k:
            k = k.replace("txt_attn.qkv", "txt_attn_qkv")
        new_sd[k] = v
    return new_sd

class HyVideoLoraBlockEdit:
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = {}
        argument = ("BOOLEAN", {"default": True})
        for i in range(20):
            arg_dict["double_blocks.{}.".format(i)] = argument
        for i in range(40):
            arg_dict["single_blocks.{}.".format(i)] = argument
        return {"required": arg_dict}

    RETURN_TYPES = ("SELECTEDBLOCKS", )
    RETURN_NAMES = ("blocks", )
    OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
    FUNCTION = "select"
    CATEGORY = "HunyuanVideoWrapper"

    def select(self, **kwargs):
        selected_blocks = {k: v for k, v in kwargs.items() if v is True}
        print("Selected blocks: ", selected_blocks)
        return (selected_blocks,)

class HyVideoLoraSelect:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (folder_paths.get_filename_list("loras"), {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
            },
            "optional": {
                "prev_lora": ("HYVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
                "blocks": ("SELECTEDBLOCKS", ),
            }
        }

    RETURN_TYPES = ("HYVIDLORA",)
    RETURN_NAMES = ("lora", )
    FUNCTION = "getlorapath"
    CATEGORY = "HunyuanVideoWrapper"
    DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"

    def getlorapath(self, lora, strength, blocks=None, prev_lora=None, fuse_lora=False):
        loras_list = []
        lora = {
            "path": folder_paths.get_full_path("loras", lora),
            "strength": strength,
            "name": lora.split(".")[0],
            "fuse_lora": fuse_lora,
            "blocks": blocks
        }
        if prev_lora is not None:
            loras_list.extend(prev_lora)
        loras_list.append(lora)
        return (loras_list,)

class HyVideoBlockSwap:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "double_blocks_to_swap": ("INT", {"default": 20, "min": 0, "max": 20, "step": 1, "tooltip": "Number of double blocks to swap"}),
                "single_blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 40, "step": 1, "tooltip": "Number of single blocks to swap"}),
                "offload_txt_in": ("BOOLEAN", {"default": False, "tooltip": "Offload txt_in layer"}),
                "offload_img_in": ("BOOLEAN", {"default": False, "tooltip": "Offload img_in layer"}),
            },
        }

    RETURN_TYPES = ("BLOCKSWAPARGS",)
    RETURN_NAMES = ("block_swap_args",)
    FUNCTION = "setargs"
    CATEGORY = "HunyuanVideoWrapper"
    DESCRIPTION = "Settings for block swapping, reduces VRAM use by swapping blocks to CPU memory"

    def setargs(self, **kwargs):
        return (kwargs, )

class HyVideoEnhanceAVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}),
                "single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Enable Enhance-A-Video for single blocks"}),
                "double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Enable Enhance-A-Video for double blocks"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min":
0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}), }, } RETURN_TYPES = ("FETAARGS",) RETURN_NAMES = ("feta_args",) FUNCTION = "setargs" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video" def setargs(self, **kwargs): return (kwargs, ) class HyVideoSTG: @classmethod def INPUT_TYPES(s): return { "required": { "stg_mode": (["STG-A", "STG-R"],), "stg_block_idx": ("INT", {"default": 0, "min": -1, "max": 39, "step": 1, "tooltip": "Block index to apply STG"}), "stg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Recommended values are ≤2.0"}), "stg_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply STG"}), "stg_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply STG"}), }, } RETURN_TYPES = ("STGARGS",) RETURN_NAMES = ("stg_args",) FUNCTION = "setargs" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "Spatio Temporal Guidance, https://github.com/junhahyung/STGuidance" def setargs(self, **kwargs): return (kwargs, ) class HyVideoTeaCache: @classmethod def INPUT_TYPES(s): return { "required": { "rel_l1_thresh": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Higher values will make TeaCache more aggressive, faster, but may cause artifacts"}), }, } RETURN_TYPES = ("TEACACHEARGS",) RETURN_NAMES = ("teacache_args",) FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "TeaCache settings for HunyuanVideo to speed up inference" def process(self, rel_l1_thresh): teacache_args = { "rel_l1_thresh": rel_l1_thresh, } return (teacache_args,) class HyVideoModel(comfy.model_base.BaseModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pipeline = {} def __getitem__(self, k): return self.pipeline[k] def __setitem__(self, k, v): self.pipeline[k] = v class HyVideoModelConfig: def __init__(self, dtype): self.unet_config = {} self.unet_extra_config = {} self.latent_format = comfy.latent_formats.HunyuanVideo self.latent_format.latent_channels = 16 self.manual_cast_dtype = dtype self.sampling_settings = {"multiplier": 1.0} # Don't know what this is. Value taken from ComfyUI Mochi model. 
self.memory_usage_factor = 2.0 # denoiser is handled by extension self.unet_config["disable_unet_model_creation"] = True #region Model loading class HyVideoModelLoader: @classmethod def INPUT_TYPES(s): return { "required": { "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), "base_precision": (["fp32", "bf16"], {"default": "bf16"}), "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_scaled', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6", "torchao_int4", "torchao_int8"], {"default": 'disabled', "tooltip": "optional quantization method"}), "load_device": (["main_device", "offload_device"], {"default": "main_device"}), }, "optional": { "attention_mode": ([ "sdpa", "flash_attn_varlen", "sageattn_varlen", "sageattn", "comfy", ], {"default": "flash_attn"}), "compile_args": ("COMPILEARGS", ), "block_swap_args": ("BLOCKSWAPARGS", ), "lora": ("HYVIDLORA", {"default": None}), "auto_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Enable auto offloading for reduced VRAM usage, implementation from DiffSynth-Studio, slightly different from block swapping and uses even less VRAM, but can be slower as you can't define how much VRAM to use"}), "upcast_rope": ("BOOLEAN", {"default": True, "tooltip": "Upcast RoPE to fp32 for better accuracy, this is the default behaviour, disabling can improve speed and reduce memory use slightly"}), } } RETURN_TYPES = ("HYVIDEOMODEL",) RETURN_NAMES = ("model", ) FUNCTION = "loadmodel" CATEGORY = "HunyuanVideoWrapper" def loadmodel(self, model, base_precision, load_device, quantization, compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, auto_cpu_offload=False, upcast_rope=True): transformer = None #mm.unload_all_models() mm.soft_empty_cache() manual_offloading = True if "sage" in attention_mode: try: from sageattention import sageattn_varlen except Exception as e: raise ValueError(f"Can't import SageAttention: {str(e)}") device = mm.get_torch_device() offload_device = mm.unet_offload_device() manual_offloading = True transformer_load_device = device if load_device == "main_device" else offload_device base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[base_precision] model_path = folder_paths.get_full_path_or_raise("diffusion_models", model) sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True) in_channels = sd["img_in.proj.weight"].shape[1] out_channels = 16 factor_kwargs = {"device": transformer_load_device, "dtype": base_dtype} HUNYUAN_VIDEO_CONFIG = { "mm_double_blocks_depth": 20, "mm_single_blocks_depth": 40, "rope_dim_list": [16, 56, 56], "hidden_size": 3072, "heads_num": 24, "mlp_width_ratio": 4, "guidance_embed": True, } with init_empty_weights(): transformer = HYVideoDiffusionTransformer( in_channels=in_channels, out_channels=out_channels, attention_mode=attention_mode, main_device=device, offload_device=offload_device, **HUNYUAN_VIDEO_CONFIG, **factor_kwargs ) transformer.eval() transformer.upcast_rope = upcast_rope comfy_model = HyVideoModel( HyVideoModelConfig(base_dtype), model_type=comfy.model_base.ModelType.FLOW, device=device, ) scheduler_config = { "flow_shift": 9.0, "reverse": True, "solver": "euler", "use_flow_sigmas": True, "prediction_type": 'flow_prediction' } scheduler = FlowMatchDiscreteScheduler.from_config(scheduler_config) pipe = 
HunyuanVideoPipeline( transformer=transformer, scheduler=scheduler, progress_bar_config=None, base_dtype=base_dtype, comfy_model=comfy_model, ) if not "torchao" in quantization: log.info("Using accelerate to load and assign model weights to device...") if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast" or quantization == "fp8_scaled": dtype = torch.float8_e4m3fn elif quantization == "fp8_e5m2": dtype = torch.float8_e5m2 else: dtype = base_dtype params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} for name, param in transformer.named_parameters(): #print("Assigning Parameter name: ", name) dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name]) comfy_model.diffusion_model = transformer patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, offload_device) pipe.comfy_model = patcher del sd gc.collect() mm.soft_empty_cache() if lora is not None: from comfy.sd import load_lora_for_models for l in lora: log.info(f"Loading LoRA: {l['name']} with strength: {l['strength']}") lora_path = l["path"] lora_strength = l["strength"] lora_sd = load_torch_file(lora_path, safe_load=True) lora_sd = standardize_lora_key_format(lora_sd) if l["blocks"]: lora_sd = filter_state_dict_by_blocks(lora_sd, l["blocks"]) #for k in lora_sd.keys(): # print(k) patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0) comfy.model_management.load_models_gpu([patcher]) if load_device == "offload_device": patcher.model.diffusion_model.to(offload_device) if quantization == "fp8_e4m3fn_fast": from .fp8_optimization import convert_fp8_linear convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep) elif quantization == "fp8_scaled": from .hyvideo.modules.fp8_optimization import convert_fp8_linear convert_fp8_linear(patcher.model.diffusion_model, base_dtype) if auto_cpu_offload: transformer.enable_auto_offload(dtype=dtype, device=device) #compile if compile_args is not None: torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] if compile_args["compile_single_blocks"]: for i, block in enumerate(patcher.model.diffusion_model.single_blocks): patcher.model.diffusion_model.single_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if compile_args["compile_double_blocks"]: for i, block in enumerate(patcher.model.diffusion_model.double_blocks): patcher.model.diffusion_model.double_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if compile_args["compile_txt_in"]: patcher.model.diffusion_model.txt_in = torch.compile(patcher.model.diffusion_model.txt_in, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if compile_args["compile_vector_in"]: patcher.model.diffusion_model.vector_in = torch.compile(patcher.model.diffusion_model.vector_in, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if compile_args["compile_final_layer"]: patcher.model.diffusion_model.final_layer = torch.compile(patcher.model.diffusion_model.final_layer, fullgraph=compile_args["fullgraph"], 
dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) elif "torchao" in quantization: try: from torchao.quantization import ( quantize_, fpx_weight_only, float8_dynamic_activation_float8_weight, int8_dynamic_activation_int8_weight, int8_weight_only, int4_weight_only ) except: raise ImportError("torchao is not installed") # def filter_fn(module: nn.Module, fqn: str) -> bool: # target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models # if any(sub in fqn for sub in target_submodules): # return isinstance(module, nn.Linear) # return False if "fp6" in quantization: quant_func = fpx_weight_only(3, 2) elif "int4" in quantization: quant_func = int4_weight_only() elif "int8" in quantization: quant_func = int8_weight_only() elif "fp8dq" in quantization: quant_func = float8_dynamic_activation_float8_weight() elif 'fp8dqrow' in quantization: from torchao.quantization.quant_api import PerRow quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow()) elif 'int8dq' in quantization: quant_func = int8_dynamic_activation_int8_weight() log.info(f"Quantizing model with {quant_func}") comfy_model.diffusion_model = transformer patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, offload_device) if lora is not None: from comfy.sd import load_lora_for_models for l in lora: lora_path = l["path"] lora_strength = l["strength"] lora_sd = load_torch_file(lora_path, safe_load=True) lora_sd = standardize_lora_key_format(lora_sd) patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0) comfy.model_management.load_models_gpu([patcher]) for i, block in enumerate(patcher.model.diffusion_model.single_blocks): log.info(f"Quantizing single_block {i}") for name, _ in block.named_parameters(prefix=f"single_blocks.{i}"): #print(f"Parameter name: {name}") set_module_tensor_to_device(patcher.model.diffusion_model, name, device=patcher.model.diffusion_model_load_device, dtype=base_dtype, value=sd[name]) if compile_args is not None: patcher.model.diffusion_model.single_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) quantize_(block, quant_func) print(block) block.to(offload_device) for i, block in enumerate(patcher.model.diffusion_model.double_blocks): log.info(f"Quantizing double_block {i}") for name, _ in block.named_parameters(prefix=f"double_blocks.{i}"): #print(f"Parameter name: {name}") set_module_tensor_to_device(patcher.model.diffusion_model, name, device=patcher.model.diffusion_model_load_device, dtype=base_dtype, value=sd[name]) if compile_args is not None: patcher.model.diffusion_model.double_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) quantize_(block, quant_func) for name, param in patcher.model.diffusion_model.named_parameters(): if "single_blocks" not in name and "double_blocks" not in name: set_module_tensor_to_device(patcher.model.diffusion_model, name, device=patcher.model.diffusion_model_load_device, dtype=base_dtype, value=sd[name]) manual_offloading = False # to disable manual .to(device) calls log.info(f"Quantized transformer blocks to {quantization}") for name, param in patcher.model.diffusion_model.named_parameters(): print(name, param.dtype) #param.data = param.data.to(self.vae_dtype).to(device) del sd mm.soft_empty_cache() 
patcher.model["pipe"] = pipe patcher.model["dtype"] = base_dtype patcher.model["base_path"] = model_path patcher.model["model_name"] = model patcher.model["manual_offloading"] = manual_offloading patcher.model["quantization"] = "disabled" patcher.model["block_swap_args"] = block_swap_args patcher.model["auto_cpu_offload"] = auto_cpu_offload patcher.model["scheduler_config"] = scheduler_config for model in mm.current_loaded_models: if model._model() == patcher: mm.current_loaded_models.remove(model) return (patcher,) #region load VAE class HyVideoVAELoader: @classmethod def INPUT_TYPES(s): return { "required": { "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}), }, "optional": { "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"} ), "compile_args":("COMPILEARGS", ), } } RETURN_TYPES = ("VAE",) RETURN_NAMES = ("vae", ) FUNCTION = "loadmodel" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "Loads Hunyuan VAE model from 'ComfyUI/models/vae'" def loadmodel(self, model_name, precision, compile_args=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] with open(os.path.join(script_directory, 'configs', 'hy_vae_config.json')) as f: vae_config = json.load(f) model_path = folder_paths.get_full_path("vae", model_name) vae_sd = load_torch_file(model_path, safe_load=True) vae = AutoencoderKLCausal3D.from_config(vae_config) vae.load_state_dict(vae_sd) del vae_sd vae.requires_grad_(False) vae.eval() vae.to(device = device, dtype = dtype) #compile if compile_args is not None: torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] vae = torch.compile(vae, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) return (vae,) class HyVideoTorchCompileSettings: @classmethod def INPUT_TYPES(s): return { "required": { "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), "compile_single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}), "compile_double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}), "compile_txt_in": ("BOOLEAN", {"default": False, "tooltip": "Compile txt_in layers"}), "compile_vector_in": ("BOOLEAN", {"default": False, "tooltip": "Compile vector_in layers"}), "compile_final_layer": ("BOOLEAN", {"default": False, "tooltip": "Compile final layer"}), }, } RETURN_TYPES = ("COMPILEARGS",) RETURN_NAMES = ("torch_compile_args",) FUNCTION = "loadmodel" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. 
Requires Triton and torch 2.5.0 is recommended" def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks, compile_txt_in, compile_vector_in, compile_final_layer): compile_args = { "backend": backend, "fullgraph": fullgraph, "mode": mode, "dynamic": dynamic, "dynamo_cache_size_limit": dynamo_cache_size_limit, "compile_single_blocks": compile_single_blocks, "compile_double_blocks": compile_double_blocks, "compile_txt_in": compile_txt_in, "compile_vector_in": compile_vector_in, "compile_final_layer": compile_final_layer } return (compile_args, ) #region TextEncode class DownloadAndLoadHyVideoTextEncoder: @classmethod def INPUT_TYPES(s): return { "required": { "llm_model": (["Kijai/llava-llama-3-8b-text-encoder-tokenizer","xtuner/llava-llama-3-8b-v1_1-transformers"],), "clip_model": (["disabled","openai/clip-vit-large-patch14",],), "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"} ), }, "optional": { "apply_final_norm": ("BOOLEAN", {"default": False}), "hidden_state_skip_layer": ("INT", {"default": 2}), "quantization": (['disabled', 'bnb_nf4', "fp8_e4m3fn"], {"default": 'disabled'}), "load_device": (["main_device", "offload_device"], {"default": "offload_device"}), } } RETURN_TYPES = ("HYVIDTEXTENCODER",) RETURN_NAMES = ("hyvid_text_encoder", ) FUNCTION = "loadmodel" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "Loads Hunyuan text_encoder model from 'ComfyUI/models/LLM'" def loadmodel(self, llm_model, clip_model, precision, apply_final_norm=False, hidden_state_skip_layer=2, quantization="disabled", load_device="offload_device"): lm_type_mapping = { "Kijai/llava-llama-3-8b-text-encoder-tokenizer": "llm", "xtuner/llava-llama-3-8b-v1_1-transformers": "vlm", } lm_type = lm_type_mapping[llm_model] device = mm.get_torch_device() offload_device = mm.unet_offload_device() text_encoder_load_device = device if load_device == "main_device" else offload_device dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] quantization_config = None if quantization == "bnb_nf4": from transformers import BitsAndBytesConfig quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16 ) if clip_model != "disabled": clip_model_path = os.path.join(folder_paths.models_dir, "clip", "clip-vit-large-patch14") if not os.path.exists(clip_model_path): log.info(f"Downloading clip model to: {clip_model_path}") from huggingface_hub import snapshot_download snapshot_download( repo_id=clip_model, ignore_patterns=["*.msgpack", "*.bin", "*.h5"], local_dir=clip_model_path, local_dir_use_symlinks=False, ) text_encoder_2 = TextEncoder( text_encoder_path=clip_model_path, text_encoder_type="clipL", max_length=77, text_encoder_precision=precision, tokenizer_type="clipL", logger=log, device=text_encoder_load_device, ) else: text_encoder_2 = None download_path = os.path.join(folder_paths.models_dir,"LLM") base_path = os.path.join(download_path, (llm_model.split("/")[-1])) if not os.path.exists(base_path): log.info(f"Downloading model to: {base_path}") from huggingface_hub import snapshot_download snapshot_download( repo_id=llm_model, local_dir=base_path, local_dir_use_symlinks=False, ) text_encoder = TextEncoder( text_encoder_path=base_path, text_encoder_type=lm_type, max_length=256, text_encoder_precision=precision, tokenizer_type=lm_type, hidden_state_skip_layer=hidden_state_skip_layer, apply_final_norm=apply_final_norm, 
logger=log, device=text_encoder_load_device, dtype=dtype, quantization_config=quantization_config ) if quantization == "fp8_e4m3fn": text_encoder.is_fp8 = True text_encoder.to(torch.float8_e4m3fn) def forward_hook(module): def forward(hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon) return module.weight.to(input_dtype) * hidden_states.to(input_dtype) return forward for module in text_encoder.model.modules(): if module.__class__.__name__ in ["Embedding"]: module.to(dtype) if module.__class__.__name__ in ["LlamaRMSNorm"]: module.forward = forward_hook(module) else: text_encoder.is_fp8 = False hyvid_text_encoders = { "text_encoder": text_encoder, "text_encoder_2": text_encoder_2, } return (hyvid_text_encoders,) class HyVideoCustomPromptTemplate: @classmethod def INPUT_TYPES(s): return {"required": { "custom_prompt_template": ("STRING", {"default": f"{PROMPT_TEMPLATE['dit-llm-encode-video']['template']}", "multiline": True}), "crop_start": ("INT", {"default": PROMPT_TEMPLATE['dit-llm-encode-video']["crop_start"], "tooltip": "To cropt the system prompt"}), }, } RETURN_TYPES = ("PROMPT_TEMPLATE", ) RETURN_NAMES = ("hyvid_prompt_template",) FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" def process(self, custom_prompt_template, crop_start): prompt_template_dict = { "template": custom_prompt_template, "crop_start": crop_start, } return (prompt_template_dict,) class HyVideoTextEncode: @classmethod def INPUT_TYPES(s): return {"required": { "text_encoders": ("HYVIDTEXTENCODER",), "prompt": ("STRING", {"default": "", "multiline": True} ), }, "optional": { "force_offload": ("BOOLEAN", {"default": True}), "prompt_template": (["video", "image", "custom", "disabled"], {"default": "video", "tooltip": "Use the default prompt templates for the llm text encoder"}), "custom_prompt_template": ("PROMPT_TEMPLATE", {"default": PROMPT_TEMPLATE["dit-llm-encode-video"], "multiline": True}), "clip_l": ("CLIP", {"tooltip": "Use comfy clip model instead, in this case the text encoder loader's clip_l should be disabled"}), "hyvid_cfg": ("HYVID_CFG", ), } } RETURN_TYPES = ("HYVIDEMBEDS", ) RETURN_NAMES = ("hyvid_embeds",) FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" def process(self, text_encoders, prompt, force_offload=True, prompt_template="video", custom_prompt_template=None, clip_l=None, image_token_selection_expr="::4", hyvid_cfg=None, image1=None, image2=None, clip_text_override=None): if clip_text_override is not None and len(clip_text_override) == 0: clip_text_override = None device = mm.text_encoder_device() offload_device = mm.text_encoder_offload_device() text_encoder_1 = text_encoders["text_encoder"] if clip_l is None: text_encoder_2 = text_encoders["text_encoder_2"] else: text_encoder_2 = None if hyvid_cfg is not None: negative_prompt = hyvid_cfg["negative_prompt"] do_classifier_free_guidance = True else: do_classifier_free_guidance = False negative_prompt = None if prompt_template != "disabled": if prompt_template == "custom": prompt_template_dict = custom_prompt_template elif prompt_template == "video": prompt_template_dict = PROMPT_TEMPLATE["dit-llm-encode-video"] elif prompt_template == "image": prompt_template_dict = PROMPT_TEMPLATE["dit-llm-encode"] else: raise ValueError(f"Invalid prompt_template: {prompt_template_dict}") assert ( isinstance(prompt_template_dict, dict) and "template" in 
prompt_template_dict ), f"`prompt_template` must be a dictionary with a key 'template', got {prompt_template_dict}" assert "{}" in str(prompt_template_dict["template"]), ( "`prompt_template['template']` must contain a placeholder `{}` for the input text, " f"got {prompt_template_dict['template']}" ) else: prompt_template_dict = None def encode_prompt(self, prompt, negative_prompt, text_encoder, image_token_selection_expr="::4", image1=None, image2=None, clip_text_override=None): batch_size = 1 num_videos_per_prompt = 1 text_inputs = text_encoder.text2tokens(prompt, prompt_template=prompt_template_dict, image1=image1, image2=image2, clip_text_override=clip_text_override) prompt_outputs = text_encoder.encode(text_inputs, prompt_template=prompt_template_dict, image_token_selection_expr=image_token_selection_expr, device=device ) prompt_embeds = prompt_outputs.hidden_state attention_mask = prompt_outputs.attention_mask log.info(f"{text_encoder.text_encoder_type} prompt attention_mask shape: {attention_mask.shape}, masked tokens: {attention_mask[0].sum().item()}") if attention_mask is not None: attention_mask = attention_mask.to(device) bs_embed, seq_len = attention_mask.shape attention_mask = attention_mask.repeat(1, num_videos_per_prompt) attention_mask = attention_mask.view( bs_embed * num_videos_per_prompt, seq_len ) prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt # max_length = prompt_embeds.shape[1] uncond_input = text_encoder.text2tokens(uncond_tokens, prompt_template=prompt_template_dict) negative_prompt_outputs = text_encoder.encode( uncond_input, prompt_template=prompt_template_dict, device=device ) negative_prompt_embeds = negative_prompt_outputs.hidden_state negative_attention_mask = negative_prompt_outputs.attention_mask if negative_attention_mask is not None: negative_attention_mask = negative_attention_mask.to(device) _, seq_len = negative_attention_mask.shape negative_attention_mask = negative_attention_mask.repeat( 1, num_videos_per_prompt ) negative_attention_mask = negative_attention_mask.view( batch_size * num_videos_per_prompt, seq_len ) else: negative_prompt_embeds = None negative_attention_mask = None return ( prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask, ) text_encoder_1.to(device) with torch.autocast(device_type=mm.get_autocast_device(device), dtype=text_encoder_1.dtype, enabled=text_encoder_1.is_fp8): prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = encode_prompt(self, prompt, negative_prompt, text_encoder_1, image_token_selection_expr=image_token_selection_expr, image1=image1, image2=image2) if force_offload: text_encoder_1.to(offload_device) mm.soft_empty_cache() if text_encoder_2 is not None: text_encoder_2.to(device) prompt_embeds_2, negative_prompt_embeds_2, attention_mask_2, negative_attention_mask_2 = encode_prompt(self, prompt, negative_prompt, text_encoder_2, clip_text_override=clip_text_override) if force_offload: text_encoder_2.to(offload_device) mm.soft_empty_cache() elif clip_l is not None: clip_l.cond_stage_model.to(device) tokens = clip_l.tokenize(prompt if clip_text_override is None else clip_text_override, return_word_ids=True) prompt_embeds_2 = clip_l.encode_from_tokens(tokens, return_pooled=True, return_dict=False)[1] prompt_embeds_2 = prompt_embeds_2.to(device=device) if negative_prompt is not None: tokens = clip_l.tokenize(negative_prompt, return_word_ids=True) negative_prompt_embeds_2 = clip_l.encode_from_tokens(tokens, return_pooled=True, return_dict=False)[1] negative_prompt_embeds_2 = negative_prompt_embeds_2.to(device=device) else: negative_prompt_embeds_2 = None attention_mask_2, negative_attention_mask_2 = None, None if force_offload: clip_l.cond_stage_model.to(offload_device) mm.soft_empty_cache() else: prompt_embeds_2 = None negative_prompt_embeds_2 = None attention_mask_2 = None negative_attention_mask_2 = None prompt_embeds_dict = { "prompt_embeds": prompt_embeds, "negative_prompt_embeds": negative_prompt_embeds, "attention_mask": attention_mask, "negative_attention_mask": negative_attention_mask, "prompt_embeds_2": prompt_embeds_2, "negative_prompt_embeds_2": negative_prompt_embeds_2, "attention_mask_2": attention_mask_2, "negative_attention_mask_2": negative_attention_mask_2, "cfg": torch.tensor(hyvid_cfg["cfg"]) if hyvid_cfg is not None else None, "start_percent": torch.tensor(hyvid_cfg["start_percent"]) if hyvid_cfg is not None else None, "end_percent": torch.tensor(hyvid_cfg["end_percent"]) if hyvid_cfg is not None else None, "batched_cfg": torch.tensor(hyvid_cfg["batched_cfg"]) if hyvid_cfg is not None else None, } return (prompt_embeds_dict,) class HyVideoTextImageEncode(HyVideoTextEncode): # Experimental Image Prompt to Video (IP2V) via VLM implementation by @Dango233 @classmethod def INPUT_TYPES(s): return {"required": { "text_encoders": ("HYVIDTEXTENCODER",), "prompt": ("STRING", 
{"default": "", "multiline": True} ), "image_token_selection_expr": ("STRING", {"default": "::4", "multiline": False} ), }, "optional": { "force_offload": ("BOOLEAN", {"default": True}), "prompt_template": (["video", "image", "custom", "disabled"], {"default": "video", "tooltip": "Use the default prompt templates for the llm text encoder"}), "custom_prompt_template": ("PROMPT_TEMPLATE", {"default": PROMPT_TEMPLATE["dit-llm-encode-video"], "multiline": True}), "clip_l": ("CLIP", {"tooltip": "Use comfy clip model instead, in this case the text encoder loader's clip_l should be disabled"}), "image1": ("IMAGE", {"default": None}), "image2": ("IMAGE", {"default": None}), "clip_text_override": ("STRING", {"default": "", "multiline": True} ), "hyvid_cfg": ("HYVID_CFG", ), } } RETURN_TYPES = ("HYVIDEMBEDS", ) RETURN_NAMES = ("hyvid_embeds",) FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" # region CFG class HyVideoCFG: @classmethod def INPUT_TYPES(s): return {"required": { "negative_prompt": ("STRING", {"default": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", "multiline": True} ), "cfg": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "guidance scale"} ), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply CFG, rest of the steps use guidance_embeds"} ), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply CFG, rest of the steps use guidance_embeds"} ), "batched_cfg": ("BOOLEAN", {"default": True, "tooltip": "Calculate cond and uncond as a batch, increases memory usage but can be faster"}), }, } RETURN_TYPES = ("HYVID_CFG", ) RETURN_NAMES = ("hyvid_cfg",) FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "To use CFG with HunyuanVideo" def process(self, negative_prompt, cfg, start_percent, end_percent, batched_cfg): cfg_dict = { "negative_prompt": negative_prompt, "cfg": cfg, "start_percent": start_percent, "end_percent": end_percent, "batched_cfg": batched_cfg } return (cfg_dict,) #region embeds class HyVideoTextEmbedsSave: def __init__(self): self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "hyvid_embeds": ("HYVIDEMBEDS",), "filename_prefix": ("STRING", {"default": "hyvid_embeds/hyvid_embed"}), }, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } RETURN_TYPES = ("STRING", ) RETURN_NAMES = ("output_path",) FUNCTION = "save" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "Save the text embeds" def save(self, hyvid_embeds, prompt, filename_prefix, extra_pnginfo=None): from comfy.cli_args import args full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) file = f"{filename}_{counter:05}_.safetensors" file = os.path.join(full_output_folder, file) tensors_to_save = {} for key, value in hyvid_embeds.items(): if value is not None: tensors_to_save[key] = value prompt_info = "" if prompt is not None: prompt_info = json.dumps(prompt) metadata = None if not args.disable_metadata: metadata = {"prompt": prompt_info} if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) save_torch_file(tensors_to_save, file, metadata=metadata) return (file,) class HyVideoTextEmbedsLoad: @classmethod def INPUT_TYPES(s): return 
{"required": {"embeds": (folder_paths.get_filename_list("hyvid_embeds"), {"tooltip": "The saved embeds to load from output/hyvid_embeds."})}} RETURN_TYPES = ("HYVIDEMBEDS", ) RETURN_NAMES = ("hyvid_embeds",) FUNCTION = "load" CATEGORY = "HunyuanVideoWrapper" DESCTIPTION = "Load the saved text embeds" def load(self, embeds): embed_path = folder_paths.get_full_path_or_raise("hyvid_embeds", embeds) loaded_tensors = load_torch_file(embed_path, safe_load=True) # Reconstruct original dictionary with None for missing keys prompt_embeds_dict = { "prompt_embeds": loaded_tensors.get("prompt_embeds", None), "negative_prompt_embeds": loaded_tensors.get("negative_prompt_embeds", None), "attention_mask": loaded_tensors.get("attention_mask", None), "negative_attention_mask": loaded_tensors.get("negative_attention_mask", None), "prompt_embeds_2": loaded_tensors.get("prompt_embeds_2", None), "negative_prompt_embeds_2": loaded_tensors.get("negative_prompt_embeds_2", None), "attention_mask_2": loaded_tensors.get("attention_mask_2", None), "negative_attention_mask_2": loaded_tensors.get("negative_attention_mask_2", None), "cfg": loaded_tensors.get("cfg", None), "start_percent": loaded_tensors.get("start_percent", None), "end_percent": loaded_tensors.get("end_percent", None), "batched_cfg": loaded_tensors.get("batched_cfg", None), } return (prompt_embeds_dict,) class HyVideoContextOptions: @classmethod def INPUT_TYPES(s): return {"required": { "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],), "context_frames": ("INT", {"default": 65, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ), "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ), "context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}), } } RETURN_TYPES = ("HYVIDCONTEXT", ) RETURN_NAMES = ("context_options",) FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "Context options for HunyuanVideo, allows splitting the video into context windows and attemps blending them for longer generations than the model and memory otherwise would allow." 
def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise): context_options = { "context_schedule":context_schedule, "context_frames":context_frames, "context_stride":context_stride, "context_overlap":context_overlap, "freenoise":freenoise } return (context_options,) #region Sampler class HyVideoSampler: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("HYVIDEOMODEL",), "hyvid_embeds": ("HYVIDEMBEDS", ), "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 16}), "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 16}), "num_frames": ("INT", {"default": 49, "min": 1, "max": 1024, "step": 4}), "steps": ("INT", {"default": 30, "min": 1}), "embedded_guidance_scale": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), "flow_shift": ("FLOAT", {"default": 9.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "force_offload": ("BOOLEAN", {"default": True}), }, "optional": { "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ), "image_cond_latents": ("LATENT", {"tooltip": "init Latents to use for image2video process"} ), "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "stg_args": ("STGARGS", ), "context_options": ("HYVIDCONTEXT", ), "feta_args": ("FETAARGS", ), "teacache_args": ("TEACACHEARGS", ), "scheduler": (available_schedulers, { "default": 'FlowMatchDiscreteScheduler' }), } } RETURN_TYPES = ("LATENT",) RETURN_NAMES = ("samples",) FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, width, height, num_frames, samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None, feta_args=None, teacache_args=None, scheduler=None, image_cond_latents=None): model = model.model device = mm.get_torch_device() offload_device = mm.unet_offload_device() dtype = model["dtype"] transformer = model["pipe"].transformer #handle STG if stg_args is not None: if stg_args["stg_mode"] == "STG-A" and transformer.attention_mode != "sdpa": raise ValueError( f"STG-A requires attention_mode to be 'sdpa', but got {transformer.attention_mode}." 
) #handle CFG if hyvid_embeds.get("cfg") is not None: cfg = float(hyvid_embeds.get("cfg", 1.0)) cfg_start_percent = float(hyvid_embeds.get("start_percent", 0.0)) cfg_end_percent = float(hyvid_embeds.get("end_percent", 1.0)) batched_cfg = hyvid_embeds.get("batched_cfg", True) else: cfg = 1.0 cfg_start_percent = 0.0 cfg_end_percent = 1.0 batched_cfg = False if embedded_guidance_scale == 0.0: embedded_guidance_scale = None generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed) if width <= 0 or height <= 0 or num_frames <= 0: raise ValueError( f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={num_frames}" ) if (num_frames - 1) % 4 != 0: raise ValueError( f"`video_length - 1 (that's minus one frame)` must be a multiple of 4, got {num_frames}" ) log.info( f"Input (height, width, video_length) = ({height}, {width}, {num_frames})" ) target_height = align_to(height, 16) target_width = align_to(width, 16) scheduler_config = model["scheduler_config"] scheduler_config["flow_shift"] = flow_shift if scheduler == "SDE-DPMSolverMultistepScheduler": scheduler_config["algorithm_type"] = "sde-dpmsolver++" elif scheduler == "SASolverScheduler": scheduler_config["algorithm_type"] = "data_prediction" else: scheduler_config.pop("algorithm_type", None) #model["scheduler_config"]["use_beta_flow_sigmas"] = True noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) model["pipe"].scheduler = noise_scheduler if model["block_swap_args"] is not None: for name, param in transformer.named_parameters(): if "single" not in name and "double" not in name: param.data = param.data.to(device) transformer.block_swap( model["block_swap_args"]["double_blocks_to_swap"] - 1 , model["block_swap_args"]["single_blocks_to_swap"] - 1, offload_txt_in = model["block_swap_args"]["offload_txt_in"], offload_img_in = model["block_swap_args"]["offload_img_in"], ) elif model["auto_cpu_offload"]: for name, param in transformer.named_parameters(): if "single" not in name and "double" not in name: param.data = param.data.to(device) elif model["manual_offloading"]: transformer.to(device) # Initialize TeaCache if enabled if teacache_args is not None: # Check if dimensions have changed since last run if (not hasattr(transformer, 'last_dimensions') or transformer.last_dimensions != (height, width, num_frames) or not hasattr(transformer, 'last_frame_count') or transformer.last_frame_count != num_frames): # Reset TeaCache state on dimension change transformer.cnt = 0 transformer.accumulated_rel_l1_distance = 0 transformer.previous_modulated_input = None transformer.previous_residual = None transformer.last_dimensions = (height, width, num_frames) transformer.last_frame_count = num_frames transformer.enable_teacache = True transformer.num_steps = steps transformer.rel_l1_thresh = teacache_args["rel_l1_thresh"] else: transformer.enable_teacache = False mm.soft_empty_cache() gc.collect() try: torch.cuda.reset_peak_memory_stats(device) except: pass #for name, param in transformer.named_parameters(): # print(name, param.data.device) leapfusion_img2vid = False input_latents = samples["samples"].clone() if samples is not None else None if input_latents is not None: if input_latents.shape[2] == 1: leapfusion_img2vid = True if denoise_strength < 1.0: input_latents *= VAE_SCALING_FACTOR out_latents = model["pipe"]( num_inference_steps=steps, height = target_height, width = target_width, video_length = num_frames, guidance_scale=cfg, 
cfg_start_percent=cfg_start_percent, cfg_end_percent=cfg_end_percent, batched_cfg=batched_cfg, embedded_guidance_scale=embedded_guidance_scale, latents=input_latents, denoise_strength=denoise_strength, prompt_embed_dict=hyvid_embeds, generator=generator, stg_mode=stg_args["stg_mode"] if stg_args is not None else None, stg_block_idx=stg_args["stg_block_idx"] if stg_args is not None else -1, stg_scale=stg_args["stg_scale"] if stg_args is not None else 0.0, stg_start_percent=stg_args["stg_start_percent"] if stg_args is not None else 0.0, stg_end_percent=stg_args["stg_end_percent"] if stg_args is not None else 1.0, context_options=context_options, feta_args=feta_args, leapfusion_img2vid = leapfusion_img2vid, image_cond_latents = image_cond_latents["samples"] * VAE_SCALING_FACTOR if image_cond_latents is not None else None, ) print_memory(device) try: torch.cuda.reset_peak_memory_stats(device) except: pass if force_offload: if model["manual_offloading"]: transformer.to(offload_device) mm.soft_empty_cache() gc.collect() return ({ "samples": out_latents.cpu() / VAE_SCALING_FACTOR },) #region VideoDecode class HyVideoDecode: @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("VAE",), "samples": ("LATENT",), "enable_vae_tiling": ("BOOLEAN", {"default": True, "tooltip": "Drastically reduces memory use but may introduce seams"}), "temporal_tiling_sample_size": ("INT", {"default": 64, "min": 4, "max": 256, "tooltip": "Smaller values use less VRAM, model default is 64, any other value will cause stutter"}), "spatial_tile_sample_min_size": ("INT", {"default": 256, "min": 32, "max": 2048, "step": 32, "tooltip": "Spatial tile minimum size in pixels, smaller values use less VRAM, may introduce more seams"}), "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Automatically set tile size based on defaults, above settings are ignored"}), }, } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("images",) FUNCTION = "decode" CATEGORY = "HunyuanVideoWrapper" def decode(self, vae, samples, enable_vae_tiling, temporal_tiling_sample_size, spatial_tile_sample_min_size, auto_tile_size): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() latents = samples["samples"] generator = torch.Generator(device=torch.device("cpu"))#.manual_seed(seed) vae.to(device) if not auto_tile_size: vae.tile_latent_min_tsize = temporal_tiling_sample_size // 4 vae.tile_sample_min_size = spatial_tile_sample_min_size vae.tile_latent_min_size = spatial_tile_sample_min_size // 8 if temporal_tiling_sample_size != 64: vae.t_tile_overlap_factor = 0.0 else: vae.t_tile_overlap_factor = 0.25 else: #defaults vae.tile_latent_min_tsize = 16 vae.tile_sample_min_size = 256 vae.tile_latent_min_size = 32 expand_temporal_dim = False if len(latents.shape) == 4: if isinstance(vae, AutoencoderKLCausal3D): latents = latents.unsqueeze(2) expand_temporal_dim = True elif len(latents.shape) == 5: pass else: raise ValueError( f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}." 
) #latents = latents / vae.config.scaling_factor latents = latents.to(vae.dtype).to(device) if enable_vae_tiling: vae.enable_tiling() video = vae.decode( latents, return_dict=False, generator=generator )[0] else: video = vae.decode( latents, return_dict=False, generator=generator )[0] if expand_temporal_dim or video.shape[2] == 1: video = video.squeeze(2) vae.to(offload_device) mm.soft_empty_cache() if len(video.shape) == 5: video_processor = VideoProcessor(vae_scale_factor=8) video_processor.config.do_resize = False video = video_processor.postprocess_video(video=video, output_type="pt") out = video[0].permute(0, 2, 3, 1).cpu().float() else: out = (video / 2 + 0.5).clamp(0, 1) out = out.permute(0, 2, 3, 1).cpu().float() return (out,) #region VideoEncode class HyVideoEncode: @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("VAE",), "image": ("IMAGE",), "enable_vae_tiling": ("BOOLEAN", {"default": True, "tooltip": "Drastically reduces memory use but may introduce seams"}), "temporal_tiling_sample_size": ("INT", {"default": 64, "min": 4, "max": 256, "tooltip": "Smaller values use less VRAM, model default is 64, any other value will cause stutter"}), "spatial_tile_sample_min_size": ("INT", {"default": 256, "min": 32, "max": 2048, "step": 32, "tooltip": "Spatial tile minimum size in pixels, smaller values use less VRAM, may introduce more seams"}), "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Automatically set tile size based on defaults, above settings are ignored"}), }, "optional": { "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for leapfusion I2V where some noise can add motion and give sharper results"}), "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for leapfusion I2V where lower values allow for more motion"}), } } RETURN_TYPES = ("LATENT",) RETURN_NAMES = ("samples",) FUNCTION = "encode" CATEGORY = "HunyuanVideoWrapper" def encode(self, vae, image, enable_vae_tiling, temporal_tiling_sample_size, auto_tile_size, spatial_tile_sample_min_size, noise_aug_strength=0.0, latent_strength=1.0): device = mm.get_torch_device() offload_device = mm.unet_offload_device() generator = torch.Generator(device=torch.device("cpu"))#.manual_seed(seed) vae.to(device) if not auto_tile_size: vae.tile_latent_min_tsize = temporal_tiling_sample_size // 4 vae.tile_sample_min_size = spatial_tile_sample_min_size vae.tile_latent_min_size = spatial_tile_sample_min_size // 8 if temporal_tiling_sample_size != 64: vae.t_tile_overlap_factor = 0.0 else: vae.t_tile_overlap_factor = 0.25 else: #defaults vae.tile_latent_min_tsize = 16 vae.tile_sample_min_size = 256 vae.tile_latent_min_size = 32 image = (image.clone() * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W if noise_aug_strength > 0.0: image = add_noise_to_reference_video(image, ratio=noise_aug_strength) if enable_vae_tiling: vae.enable_tiling() latents = vae.encode(image).latent_dist.sample(generator) if latent_strength != 1.0: latents *= latent_strength #latents = latents * vae.config.scaling_factor vae.to(offload_device) print("encoded latents shape",latents.shape) return ({"samples": latents},) class HyVideoLatentPreview: @classmethod def INPUT_TYPES(s): return { "required": { "samples": ("LATENT",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "min_val": ("FLOAT", {"default": -0.15, 
"min": -1.0, "max": 0.0, "step": 0.001}), "max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), "r_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}), "g_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}), "b_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}), }, } RETURN_TYPES = ("IMAGE", "STRING", ) RETURN_NAMES = ("images", "latent_rgb_factors",) FUNCTION = "sample" CATEGORY = "HunyuanVideoWrapper" def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias): mm.soft_empty_cache() latents = samples["samples"].clone() print("in sample", latents.shape) #latent_rgb_factors =[[-0.02531045419704009, -0.00504800612542497, 0.13293717293982546], [-0.03421835830845858, 0.13996708548892614, -0.07081038680118075], [0.011091819063647063, -0.03372949685846012, -0.0698232210116172], [-0.06276524604742019, -0.09322986677909442, 0.01826383612148913], [0.021290659938126788, -0.07719530444034409, -0.08247812477766273], [0.04401102991215147, -0.0026401932105894754, -0.01410913586718443], [0.08979717602613707, 0.05361221258740831, 0.11501425309699129], [0.04695121980405198, -0.13053491609675175, 0.05025986885867986], [-0.09704684176098193, 0.03397687417738002, -0.1105886644677771], [0.14694697234804935, -0.12316902186157716, 0.04210404546699645], [0.14432470831243552, -0.002580008133591355, -0.08490676947390643], [0.051502750076553944, -0.10071695490292451, -0.01786223610178095], [-0.12503276881774464, 0.08877830923879379, 0.1076584501927316], [-0.020191205513213406, -0.1493425056303128, -0.14289740371758308], [-0.06470138952271293, -0.07410426095060325, 0.00980804676890873], [0.11747671720735695, 0.10916082743849789, -0.12235599365235904]] latent_rgb_factors = [[-0.41, -0.25, -0.26], [-0.26, -0.49, -0.24], [-0.37, -0.54, -0.3], [-0.04, -0.29, -0.29], [-0.52, -0.59, -0.39], [-0.56, -0.6, -0.02], [-0.53, -0.06, -0.48], [-0.51, -0.28, -0.18], [-0.59, -0.1, -0.33], [-0.56, -0.54, -0.41], [-0.61, -0.19, -0.5], [-0.05, -0.25, -0.17], [-0.23, -0.04, -0.22], [-0.51, -0.56, -0.43], [-0.13, -0.4, -0.05], [-0.01, -0.01, -0.48]] import random random.seed(seed) #latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(16)] out_factors = latent_rgb_factors print(latent_rgb_factors) #latent_rgb_factors_bias = [0.138, 0.025, -0.299] latent_rgb_factors_bias = [r_bias, g_bias, b_bias] latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1) latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype) print("latent_rgb_factors", latent_rgb_factors.shape) latent_images = [] for t in range(latents.shape[2]): latent = latents[:, :, t, :, :] latent = latent[0].permute(1, 2, 0) latent_image = torch.nn.functional.linear( latent, latent_rgb_factors, bias=latent_rgb_factors_bias ) latent_images.append(latent_image) latent_images = torch.stack(latent_images, dim=0) print("latent_images", latent_images.shape) latent_images_min = latent_images.min() latent_images_max = latent_images.max() latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min) return (latent_images.float().cpu(), out_factors) NODE_CLASS_MAPPINGS = { "HyVideoSampler": HyVideoSampler, "HyVideoDecode": HyVideoDecode, "HyVideoTextEncode": HyVideoTextEncode, "HyVideoTextImageEncode": HyVideoTextImageEncode, "HyVideoModelLoader": HyVideoModelLoader, "HyVideoVAELoader": HyVideoVAELoader, 
"DownloadAndLoadHyVideoTextEncoder": DownloadAndLoadHyVideoTextEncoder, "HyVideoEncode": HyVideoEncode, "HyVideoBlockSwap": HyVideoBlockSwap, "HyVideoTorchCompileSettings": HyVideoTorchCompileSettings, "HyVideoSTG": HyVideoSTG, "HyVideoCFG": HyVideoCFG, "HyVideoCustomPromptTemplate": HyVideoCustomPromptTemplate, "HyVideoLatentPreview": HyVideoLatentPreview, "HyVideoLoraSelect": HyVideoLoraSelect, "HyVideoLoraBlockEdit": HyVideoLoraBlockEdit, "HyVideoTextEmbedsSave": HyVideoTextEmbedsSave, "HyVideoTextEmbedsLoad": HyVideoTextEmbedsLoad, "HyVideoContextOptions": HyVideoContextOptions, "HyVideoEnhanceAVideo": HyVideoEnhanceAVideo, "HyVideoTeaCache": HyVideoTeaCache, } NODE_DISPLAY_NAME_MAPPINGS = { "HyVideoSampler": "HunyuanVideo Sampler", "HyVideoDecode": "HunyuanVideo Decode", "HyVideoTextEncode": "HunyuanVideo TextEncode", "HyVideoTextImageEncode": "HunyuanVideo TextImageEncode (IP2V)", "HyVideoModelLoader": "HunyuanVideo Model Loader", "HyVideoVAELoader": "HunyuanVideo VAE Loader", "DownloadAndLoadHyVideoTextEncoder": "(Down)Load HunyuanVideo TextEncoder", "HyVideoEncode": "HunyuanVideo Encode", "HyVideoBlockSwap": "HunyuanVideo BlockSwap", "HyVideoTorchCompileSettings": "HunyuanVideo Torch Compile Settings", "HyVideoSTG": "HunyuanVideo STG", "HyVideoCFG": "HunyuanVideo CFG", "HyVideoCustomPromptTemplate": "HunyuanVideo Custom Prompt Template", "HyVideoLatentPreview": "HunyuanVideo Latent Preview", "HyVideoLoraSelect": "HunyuanVideo Lora Select", "HyVideoLoraBlockEdit": "HunyuanVideo Lora Block Edit", "HyVideoTextEmbedsSave": "HunyuanVideo TextEmbeds Save", "HyVideoTextEmbedsLoad": "HunyuanVideo TextEmbeds Load", "HyVideoContextOptions": "HunyuanVideo Context Options", "HyVideoEnhanceAVideo": "HunyuanVideo Enhance A Video", "HyVideoTeaCache": "HunyuanVideo TeaCache", }