import os
import json
import folder_paths
import comfy.model_management as mm
from typing import Union

def patched_write_atomic(
    path_: str,
    content: Union[str, bytes],
    make_dirs: bool = False,
    encode_utf_8: bool = False,
) -> None:
    # Write into a temporary file first to avoid conflicts between threads.
    # Avoid using a named temporary file, as those have restricted permissions.
    from pathlib import Path
    import os
    import shutil
    import threading

    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"
    path = Path(path_)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
        f.write(content)
    shutil.copy2(src=tmp_path, dst=path)  # changed to allow overwriting cache files
    os.remove(tmp_path)

try:
    import torch._inductor.codecache
    torch._inductor.codecache.write_atomic = patched_write_atomic
except Exception:
    pass

import torch
import torch.nn as nn

from diffusers.models import AutoencoderKLCogVideoX
from diffusers.schedulers import CogVideoXDDIMScheduler
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .pipeline_cogvideox import CogVideoXPipeline

from contextlib import nullcontext
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

from .utils import remove_specific_blocks, log
from comfy.utils import load_torch_file

script_directory = os.path.dirname(os.path.abspath(__file__))

class CogVideoLoraSelect:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (folder_paths.get_filename_list("cogvideox_loras"),
                    {"tooltip": "LoRA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001,
                    "tooltip": "LoRA strength, set to 0.0 to unmerge the LoRA"}),
            },
            "optional": {
                "prev_lora": ("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
                "fuse_lora": ("BOOLEAN", {"default": False, "tooltip": "Fuse the LoRA weights into the transformer"}),
            }
        }

    RETURN_TYPES = ("COGLORA",)
    RETURN_NAMES = ("lora",)
    FUNCTION = "getlorapath"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "Select a LoRA model from ComfyUI/models/CogVideo/loras"

    def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
        cog_loras_list = []

        cog_lora = {
            "path": folder_paths.get_full_path("cogvideox_loras", lora),
            "strength": strength,
            "name": lora.split(".")[0],
            "fuse_lora": fuse_lora,
        }
        if prev_lora is not None:
            cog_loras_list.extend(prev_lora)
        cog_loras_list.append(cog_lora)
        print(cog_loras_list)
        return (cog_loras_list,)
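# The COGLORA value is a plain list of dicts; chaining nodes through `prev_lora`
# simply extends that list. A sketch of the accumulated structure (illustrative
# values only):
#
#   [{"path": ".../my_lora.safetensors", "strength": 1.0,
#     "name": "my_lora", "fuse_lora": False}, ...]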

class CogVideoLoraSelectComfy:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (folder_paths.get_filename_list("loras"),
                    {"tooltip": "LoRA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001,
                    "tooltip": "LoRA strength, set to 0.0 to unmerge the LoRA"}),
            },
            "optional": {
                "prev_lora": ("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
                "fuse_lora": ("BOOLEAN", {"default": False, "tooltip": "Fuse the LoRA weights into the transformer"}),
            }
        }

    RETURN_TYPES = ("COGLORA",)
    RETURN_NAMES = ("lora",)
    FUNCTION = "getlorapath"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"

    def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
        cog_loras_list = []

        cog_lora = {
            "path": folder_paths.get_full_path("loras", lora),
            "strength": strength,
            "name": lora.split(".")[0],
            "fuse_lora": fuse_lora,
        }
        if prev_lora is not None:
            cog_loras_list.extend(prev_lora)
        cog_loras_list.append(cog_lora)
        print(cog_loras_list)
        return (cog_loras_list,)
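# CogVideoLoraSelectComfy mirrors CogVideoLoraSelect; the only difference is
# that it lists LoRAs from the standard ComfyUI "loras" folder rather than the
# wrapper-specific cogvideox_loras folder, so existing LoRA folders can be reused.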

#region DownloadAndLoadCogVideoModel
class DownloadAndLoadCogVideoModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (
                    [
                        "THUDM/CogVideoX-2b",
                        "THUDM/CogVideoX-5b",
                        "THUDM/CogVideoX-5b-I2V",
                        "kijai/CogVideoX-5b-1.5-T2V",
                        "kijai/CogVideoX-5b-1.5-I2V",
                        "bertjiazheng/KoolCogVideoX-5b",
                        "kijai/CogVideoX-Fun-2b",
                        "kijai/CogVideoX-Fun-5b",
                        "kijai/CogVideoX-5b-Tora",
                        "alibaba-pai/CogVideoX-Fun-V1.1-2b-InP",
                        "alibaba-pai/CogVideoX-Fun-V1.1-5b-InP",
                        "alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose",
                        "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose",
                        "alibaba-pai/CogVideoX-Fun-V1.1-5b-Control",
                        "feizhengcong/CogvideoX-Interpolation",
                        "NimVideo/cogvideox-2b-img2vid",
                    ],
                ),
            },
            "optional": {
                "precision": (["fp16", "fp32", "bf16"],
                    {"default": "bf16", "tooltip": "The official recommendation is fp16 for the 2b model and bf16 for the 5b model"}),
                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fastmode', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"],
                    {"default": 'disabled', "tooltip": "When enabled, casts the transformer to torch.float8_e4m3fn; fastmode is only for the latest NVIDIA GPUs and requires at least torch 2.4.0 with cu124"}),
                "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Significantly reduces memory usage, but slows down inference"}),
                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                "lora": ("COGLORA", {"default": None}),
                "compile_args": ("COMPILEARGS",),
                "attention_mode": ([
                    "sdpa",
                    "fused_sdpa",
                    "sageattn",
                    "fused_sageattn",
                    "sageattn_qk_int8_pv_fp8_cuda",
                    "sageattn_qk_int8_pv_fp16_cuda",
                    "sageattn_qk_int8_pv_fp16_triton",
                    "fused_sageattn_qk_int8_pv_fp8_cuda",
                    "fused_sageattn_qk_int8_pv_fp16_cuda",
                    "fused_sageattn_qk_int8_pv_fp16_triton",
                    "comfy"
                ], {"default": "sdpa"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
            }
        }

    RETURN_TYPES = ("COGVIDEOMODEL", "VAE",)
    RETURN_NAMES = ("model", "vae",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "Downloads and loads the selected CogVideo model from Hugging Face to 'ComfyUI/models/CogVideo'"

    def loadmodel(self, model, precision, quantization="disabled", compile="disabled",
                  enable_sequential_cpu_offload=False, block_edit=None, lora=None,
                  compile_args=None, attention_mode="sdpa", load_device="main_device"):
        transformer = None

        if "sage" in attention_mode:
            try:
                from sageattention import sageattn
            except Exception as e:
                raise ValueError(f"Can't import SageAttention: {str(e)}")
            if "qk_int8" in attention_mode:
                try:
                    from sageattention import sageattn_qk_int8_pv_fp16_cuda
                except Exception as e:
                    raise ValueError(f"Can't import SageAttention 2.0.0: {str(e)}")

        if precision == "fp16" and "1.5" in model:
            raise ValueError("1.5 models do not currently work in fp16")

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        manual_offloading = True
        transformer_load_device = device if load_device == "main_device" else offload_device
        mm.soft_empty_cache()

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        download_path = folder_paths.get_folder_paths("CogVideo")[0]

        if "Fun" in model:
            if "1.1" not in model:
                repo_id = "kijai/CogVideoX-Fun-pruned"
                if "2b" in model:
                    base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP")  # location of the official model
                    if not os.path.exists(base_path):
                        base_path = os.path.join(download_path, "CogVideoX-Fun-2b-InP")
                elif "5b" in model:
                    base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP")  # location of the official model
                    if not os.path.exists(base_path):
                        base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
            elif "1.1" in model:
                repo_id = model
                base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", (model.split("/")[-1]))  # location of the official model
                if not os.path.exists(base_path):
                    base_path = os.path.join(download_path, (model.split("/")[-1]))
            download_path = base_path
            subfolder = "transformer"
            allow_patterns = ["*transformer*", "*scheduler*", "*vae*"]
        elif "2b" in model:
            if 'img2vid' in model:
                base_path = os.path.join(download_path, "cogvideox-2b-img2vid")
                download_path = base_path
                repo_id = model
            else:
                base_path = os.path.join(download_path, "CogVideo2B")
                download_path = base_path
                repo_id = model
            subfolder = "transformer"
            allow_patterns = ["*transformer*", "*scheduler*", "*vae*"]
        elif "1.5-T2V" in model or "1.5-I2V" in model:
            base_path = os.path.join(download_path, "CogVideoX-5b-1.5")
            download_path = base_path
            subfolder = "transformer_T2V" if "1.5-T2V" in model else "transformer_I2V"
            allow_patterns = [f"*{subfolder}*", "*vae*", "*scheduler*"]
            repo_id = "kijai/CogVideoX-5b-1.5"
        else:
            base_path = os.path.join(download_path, (model.split("/")[-1]))
            download_path = base_path
            repo_id = model
            subfolder = "transformer"
            allow_patterns = ["*transformer*", "*scheduler*", "*vae*"]

        if "2b" in model:
            scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
        else:
            scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')

        if not os.path.exists(base_path) or not os.path.exists(os.path.join(base_path, subfolder)):
            log.info(f"Downloading model to: {base_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id=repo_id,
                allow_patterns=allow_patterns,
                ignore_patterns=["*text_encoder*", "*tokenizer*"],
                local_dir=download_path,
                local_dir_use_symlinks=False,
            )

        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
        transformer = transformer.to(dtype).to(transformer_load_device)

        if "1.5" in model:
            transformer.config.sample_height = 300
            transformer.config.sample_width = 300

        if block_edit is not None:
            transformer = remove_specific_blocks(transformer, block_edit)

        with open(scheduler_path) as f:
            scheduler_config = json.load(f)

        scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)

        # VAE
        vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)

        # pipeline
        pipe = CogVideoXPipeline(
            transformer,
            scheduler,
            dtype=dtype,
            is_fun_inpaint="fun" in model.lower() and not ("pose" in model.lower() or "control" in model.lower())
        )
        if "cogvideox-2b-img2vid" in model:
            pipe.input_with_padding = False
        # LoRAs
        if lora is not None:
            dimensionx_loras = ["orbit", "dimensionx"]  # for now dimensionx loras need scaling
            dimensionx_lora = False
            adapter_list = []
            adapter_weights = []
            for l in lora:
                if any(item in l["path"].lower() for item in dimensionx_loras):
                    dimensionx_lora = True
                fuse = True if l["fuse_lora"] else False
                lora_sd = load_torch_file(l["path"])
                lora_rank = None
                for key, val in lora_sd.items():
                    if "lora_B" in key:
                        lora_rank = val.shape[1]
                        break
                if lora_rank is not None:
                    log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
                    adapter_name = l['path'].split("/")[-1].split(".")[0]
                    adapter_weight = l['strength']
                    pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
                    adapter_list.append(adapter_name)
                    adapter_weights.append(adapter_weight)
                else:
                    try:  # Fun trainer LoRAs are loaded differently
                        from .lora_utils import merge_lora
                        log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
                        pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"], device=transformer_load_device, state_dict=lora_sd)
                    except Exception:
                        raise ValueError(f"Can't recognize LoRA {l['path']}")
            if adapter_list:
                pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
                if fuse:
                    lora_scale = 1
                    if dimensionx_lora:
                        lora_scale = lora_scale / lora_rank
                    pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])

        if "fused" in attention_mode:
            from diffusers.models.attention import Attention
            pipe.transformer.fuse_qkv_projections = True
            for module in pipe.transformer.modules():
                if isinstance(module, Attention):
                    module.fuse_projections(fuse=True)

        if compile_args is not None:
            pipe.transformer.to(memory_format=torch.channels_last)

        # fp8
        if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fastmode":
            params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
            if "1.5" in model:
                params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"})
            for name, param in pipe.transformer.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
                    param.data = param.data.to(torch.float8_e4m3fn)

            if quantization == "fp8_e4m3fn_fastmode":
                from .fp8_optimization import convert_fp8_linear
                if "1.5" in model:
                    params_to_keep.update({"ff"})  # otherwise NaNs
                convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)

        # compilation
        if compile_args is not None:
            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
            for i, block in enumerate(pipe.transformer.transformer_blocks):
                if "CogVideoXBlock" in str(block):
                    pipe.transformer.transformer_blocks[i] = torch.compile(
                        block,
                        fullgraph=compile_args["fullgraph"],
                        dynamic=compile_args["dynamic"],
                        backend=compile_args["backend"],
                        mode=compile_args["mode"],
                    )

        if "torchao" in quantization:
            try:
                from torchao.quantization import (
                    quantize_,
                    fpx_weight_only,
                    float8_dynamic_activation_float8_weight,
                    int8_dynamic_activation_int8_weight,
                )
            except ImportError:
                raise ImportError("torchao is not installed, please install torchao to use fp8dq")

            def filter_fn(module: nn.Module, fqn: str) -> bool:
                target_submodules = {'attn1', 'ff'}  # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
                if any(sub in fqn for sub in target_submodules):
                    return isinstance(module, nn.Linear)
                return False

            if "fp6" in quantization:  # slower for some reason on 4090
                quant_func = fpx_weight_only(3, 2)
            elif "fp8dq" in quantization:  # very fast on 4090 when compiled
                quant_func = float8_dynamic_activation_float8_weight()
            elif 'fp8dqrow' in quantization:
                from torchao.quantization.quant_api import PerRow
                quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow())
            elif 'int8dq' in quantization:
                quant_func = int8_dynamic_activation_int8_weight()

            for i, block in enumerate(pipe.transformer.transformer_blocks):
                if "CogVideoXBlock" in str(block):
                    quantize_(block, quant_func, filter_fn=filter_fn)
            manual_offloading = False  # to disable manual .to(device) calls

        if enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()
            manual_offloading = False

        # CogVideoXBlock(
        #   (norm1): CogVideoXLayerNormZero(
        #     (silu): SiLU()
        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
        #   )
        #   (attn1): Attention(
        #     (norm_q): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        #     (norm_k): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        #     (to_q): Linear(in_features=3072, out_features=3072, bias=True)
        #     (to_k): Linear(in_features=3072, out_features=3072, bias=True)
        #     (to_v): Linear(in_features=3072, out_features=3072, bias=True)
        #     (to_out): ModuleList(
        #       (0): Linear(in_features=3072, out_features=3072, bias=True)
        #       (1): Dropout(p=0.0, inplace=False)
        #     )
        #   )
        #   (norm2): CogVideoXLayerNormZero(
        #     (silu): SiLU()
        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
        #   )
        #   (ff): FeedForward(
        #     (net): ModuleList(
        #       (0): GELU(
        #         (proj): Linear(in_features=3072, out_features=12288, bias=True)
        #       )
        #       (1): Dropout(p=0.0, inplace=False)
        #       (2): Linear(in_features=12288, out_features=3072, bias=True)
        #       (3): Dropout(p=0.0, inplace=False)
        #     )
        #   )
        # )

        # if compile == "onediff":
        #     from onediffx import compile_pipe
        #     os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
        #     pipe = compile_pipe(
        #         pipe,
        #         backend="nexfort",
        #         options={"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last",
        #                  "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
        #         ignores=["vae"],
        #         fuse_qkv_projections=False,
        #     )

        pipeline = {
            "pipe": pipe,
            "dtype": dtype,
            "quantization": quantization,
            "base_path": base_path,
            "onediff": True if compile == "onediff" else False,
            "cpu_offloading": enable_sequential_cpu_offload,
            "manual_offloading": manual_offloading,
            "scheduler_config": scheduler_config,
            "model_name": model,
        }
        return (pipeline, vae)
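# The COGVIDEOMODEL dict built above is the contract the wrapper's downstream
# nodes consume (as far as this file shows): "pipe" is the live
# CogVideoXPipeline, "dtype"/"quantization" record how the transformer weights
# are stored, and "cpu_offloading"/"manual_offloading" tell the consumer
# whether it may move the model between devices itself.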

#region GGUF
class DownloadAndLoadCogVideoGGUFModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (
                    [
                        "CogVideoX_5b_GGUF_Q4_0.safetensors",
                        "CogVideoX_5b_I2V_GGUF_Q4_0.safetensors",
                        "CogVideoX_5b_1_5_I2V_GGUF_Q4_0.safetensors",
                        "CogVideoX_5b_fun_GGUF_Q4_0.safetensors",
                        "CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors",
                        "CogVideoX_5b_fun_1_1_Pose_GGUF_Q4_0.safetensors",
                        "CogVideoX_5b_Interpolation_GGUF_Q4_0.safetensors",
                        "CogVideoX_5b_Tora_GGUF_Q4_0.safetensors",
                    ],
                ),
                "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
                "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "Only supported on 4090 and later GPUs; also requires at least torch 2.4.0 with cu124"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
                "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Significantly reduces memory usage, but slows down inference"}),
            },
            "optional": {
                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                #"compile_args": ("COMPILEARGS",),
                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
            }
        }

    RETURN_TYPES = ("COGVIDEOMODEL", "VAE",)
    RETURN_NAMES = ("model", "vae",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"

    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
                  block_edit=None, compile_args=None, attention_mode="sdpa"):
        if "sage" in attention_mode:
            try:
                from sageattention import sageattn
            except Exception as e:
                raise ValueError(f"Can't import SageAttention: {str(e)}")

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        mm.soft_empty_cache()

        vae_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[vae_precision]

        download_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'GGUF')
        gguf_path = os.path.join(folder_paths.models_dir, 'diffusion_models', model)  # check MinusZone's model path first
        if not os.path.exists(gguf_path):
            gguf_path = os.path.join(download_path, model)
            if not os.path.exists(gguf_path):
                if "I2V" in model or "1_1" in model or "Interpolation" in model or "Tora" in model:
                    repo_id = "Kijai/CogVideoX_GGUF"
                else:
                    repo_id = "MinusZoneAI/ComfyUI-CogVideoX-MZ"
                log.info(f"Downloading model to: {gguf_path}")
                from huggingface_hub import snapshot_download
                snapshot_download(
                    repo_id=repo_id,
                    allow_patterns=[f"*{model}*"],
                    local_dir=download_path,
                    local_dir_use_symlinks=False,
                )

        if "5b" in model:
            scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
            transformer_path = os.path.join(script_directory, 'configs', 'transformer_config_5b.json')
        elif "2b" in model:
            scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
            transformer_path = os.path.join(script_directory, 'configs', 'transformer_config_2b.json')

        with open(transformer_path) as f:
            transformer_config = json.load(f)

        from . import mz_gguf_loader
        import importlib
        importlib.reload(mz_gguf_loader)

        with mz_gguf_loader.quantize_lazy_load():
            if "fun" in model:
                if "Pose" in model:
                    transformer_config["in_channels"] = 32
                else:
                    transformer_config["in_channels"] = 33
            elif "I2V" in model or "Interpolation" in model:
                transformer_config["in_channels"] = 32
                if "1_5" in model:
                    transformer_config["ofs_embed_dim"] = 512
                    transformer_config["use_learned_positional_embeddings"] = False
                    transformer_config["patch_size_t"] = 2
                    transformer_config["patch_bias"] = False
                    transformer_config["sample_height"] = 300
                    transformer_config["sample_width"] = 300
            else:
                transformer_config["in_channels"] = 16

            transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
            cast_dtype = vae_dtype
            params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
            if "2b" in model:
                cast_dtype = torch.float16
            elif "1_5" in model:
                params_to_keep = {"norm1.linear.weight", "patch_embed", "time_embedding", "ofs_embedding", "norm_final", "norm_out", "proj_out"}
                cast_dtype = torch.bfloat16
            for name, param in transformer.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
                    param.data = param.data.to(torch.float8_e4m3fn)
                else:
                    param.data = param.data.to(cast_dtype)
            #for name, param in transformer.named_parameters():
            #    print(name, param.data.dtype)

        if block_edit is not None:
            transformer = remove_specific_blocks(transformer, block_edit)

        transformer.attention_mode = attention_mode

        if fp8_fastmode:
            params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
            if "1_5" in model:  # note: GGUF filenames use '1_5', the original '1.5' check could never match
                params_to_keep.update({"ff", "norm1.linear.weight", "norm_k", "norm_q", "ofs_embedding", "norm_final", "norm_out", "proj_out"})
            from .fp8_optimization import convert_fp8_linear
            convert_fp8_linear(transformer, vae_dtype, params_to_keep=params_to_keep)

        with open(scheduler_path) as f:
            scheduler_config = json.load(f)

        scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config, subfolder="scheduler")

        # VAE
        vae_dl_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'VAE')
        vae_path = os.path.join(vae_dl_path, "cogvideox_vae.safetensors")
        if not os.path.exists(vae_path):
            log.info(f"Downloading VAE model to: {vae_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id="Kijai/CogVideoX-Fun-pruned",
                allow_patterns=["*cogvideox_vae.safetensors*"],
                local_dir=vae_dl_path,
                local_dir_use_symlinks=False,
            )
        with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
            vae_config = json.load(f)

        vae_sd = load_torch_file(vae_path)
        vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
        vae.load_state_dict(vae_sd)
        del vae_sd

        pipe = CogVideoXPipeline(
            transformer,
            scheduler,
            dtype=vae_dtype,
            is_fun_inpaint="fun" in model.lower() and not ("pose" in model.lower() or "control" in model.lower())
        )
        if enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()

        sd = load_torch_file(gguf_path)
        pipe.transformer = mz_gguf_loader.quantize_load_state_dict(pipe.transformer, sd, device="cpu")
        del sd

        if load_device == "offload_device":
            pipe.transformer.to(offload_device)
        else:
            pipe.transformer.to(device)

        pipeline = {
            "pipe": pipe,
            "dtype": vae_dtype,
            "quantization": "GGUF",
            "base_path": model,
            "onediff": False,
            "cpu_offloading": enable_sequential_cpu_offload,
            "scheduler_config": scheduler_config,
            "model_name": model,
            "manual_offloading": True,
        }
        return (pipeline, vae)

#region ModelLoader
class CogVideoXModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (folder_paths.get_filename_list("diffusion_models"),
                    {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' folder"}),
                "base_precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"],
                    {"default": 'disabled', "tooltip": "Optional quantization method"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
                "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Significantly reduces memory usage, but slows down inference"}),
            },
            "optional": {
                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                "lora": ("COGLORA", {"default": None}),
                "compile_args": ("COMPILEARGS",),
                "attention_mode": ([
                    "sdpa",
                    "fused_sdpa",
                    "sageattn",
                    "fused_sageattn",
                    "sageattn_qk_int8_pv_fp8_cuda",
                    "sageattn_qk_int8_pv_fp16_cuda",
                    "sageattn_qk_int8_pv_fp16_triton",
                    "fused_sageattn_qk_int8_pv_fp8_cuda",
                    "fused_sageattn_qk_int8_pv_fp16_cuda",
                    "fused_sageattn_qk_int8_pv_fp16_triton",
                    "comfy"
                ], {"default": "sdpa"}),
            }
        }

    RETURN_TYPES = ("COGVIDEOMODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"

    def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
                  block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
        transformer = None
        if "sage" in attention_mode:
            try:
                from sageattention import sageattn
            except Exception as e:
                raise ValueError(f"Can't import SageAttention: {str(e)}")

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        manual_offloading = True
        transformer_load_device = device if load_device == "main_device" else offload_device
        mm.soft_empty_cache()

        base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[base_precision]

        model_path = folder_paths.get_full_path_or_raise("diffusion_models", model)
        sd = load_torch_file(model_path, device=transformer_load_device)
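        # The model variant is detected from the patch embedding weight shape:
        #   (3072, 33, 2, 2) -> fun_5b         (1920, 33, 2, 2) -> fun_2b
        #   (3072, 16, 2, 2) -> 5b             (1920, 16, 2, 2) -> 2b
        #   (3072, 128)      -> 5b_1_5         (3072, 256)      -> 5b_I2V_1_5
        #   (3072, 32, 2, 2) -> fun_5b_pose if "pos_embedding" is present, else I2V_5b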
        model_type = ""
        if sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2):
            model_type = "fun_5b"
        elif sd["patch_embed.proj.weight"].shape == (3072, 16, 2, 2):
            model_type = "5b"
        elif sd["patch_embed.proj.weight"].shape == (3072, 128):
            model_type = "5b_1_5"
        elif sd["patch_embed.proj.weight"].shape == (3072, 256):
            model_type = "5b_I2V_1_5"
        elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2):
            model_type = "fun_2b"
        elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2):
            model_type = "2b"
        elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2):
            if "pos_embedding" in sd:
                model_type = "fun_5b_pose"
            else:
                model_type = "I2V_5b"
        else:
            raise Exception("Selected model is not recognized")
        log.info(f"Detected CogVideoX model type: {model_type}")

        if "5b" in model_type:
            scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
            transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_5b.json')
        elif "2b" in model_type:
            scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
            transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_2b.json')

        with open(transformer_config_path) as f:
            transformer_config = json.load(f)

        if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
            transformer_config["in_channels"] = 32
            if "1_5" in model_type:
                transformer_config["ofs_embed_dim"] = 512
        elif "fun" in model_type:
            transformer_config["in_channels"] = 33
        else:
            transformer_config["in_channels"] = 16
        if "1_5" in model_type:
            transformer_config["use_learned_positional_embeddings"] = False
            transformer_config["patch_size_t"] = 2
            transformer_config["patch_bias"] = False
            transformer_config["sample_height"] = 300
            transformer_config["sample_width"] = 300

        with init_empty_weights():
            transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)

        # load weights
        #params_to_keep = {}
        log.info("Using accelerate to load and assign model weights to device...")
        for name, param in transformer.named_parameters():
            #dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
            set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name])
        del sd

        # scheduler
        with open(scheduler_config_path) as f:
            scheduler_config = json.load(f)
        scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config, subfolder="scheduler")

        if block_edit is not None:
            transformer = remove_specific_blocks(transformer, block_edit)

        if "fused" in attention_mode:
            from diffusers.models.attention import Attention
            transformer.fuse_qkv_projections = True
            for module in transformer.modules():
                if isinstance(module, Attention):
                    module.fuse_projections(fuse=True)

        transformer.attention_mode = attention_mode

        pipe = CogVideoXPipeline(
            transformer,
            scheduler,
            dtype=base_dtype,
            is_fun_inpaint="fun" in model.lower() and not ("pose" in model.lower() or "control" in model.lower())
        )
        if enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()
        # LoRAs
        if lora is not None:
            dimensionx_loras = ["orbit", "dimensionx"]  # for now dimensionx loras need scaling
            dimensionx_lora = False
            adapter_list = []
            adapter_weights = []
            for l in lora:
                if any(item in l["path"].lower() for item in dimensionx_loras):
                    dimensionx_lora = True
                fuse = True if l["fuse_lora"] else False
                lora_sd = load_torch_file(l["path"])
                lora_rank = None
                for key, val in lora_sd.items():
                    if "lora_B" in key:
                        lora_rank = val.shape[1]
                        break
                if lora_rank is not None:
                    log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
                    adapter_name = l['path'].split("/")[-1].split(".")[0]
                    adapter_weight = l['strength']
                    pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
                    adapter_list.append(adapter_name)
                    adapter_weights.append(adapter_weight)
                else:
                    try:  # Fun trainer LoRAs are loaded differently
                        from .lora_utils import merge_lora
                        log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
                        pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"], device=transformer_load_device, state_dict=lora_sd)
                    except Exception:
                        raise ValueError(f"Can't recognize LoRA {l['path']}")
            if adapter_list:
                pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
                if fuse:
                    lora_scale = 1
                    if dimensionx_lora:
                        lora_scale = lora_scale / lora_rank
                    pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])

        if compile_args is not None:
            pipe.transformer.to(memory_format=torch.channels_last)

        # quantization
        if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast":
            params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
            if "1.5" in model:
                params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"})
            for name, param in pipe.transformer.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
                    param.data = param.data.to(torch.float8_e4m3fn)

            if quantization == "fp8_e4m3fn_fast":
                from .fp8_optimization import convert_fp8_linear
                if "1.5" in model:
                    params_to_keep.update({"ff"})  # otherwise NaNs
                convert_fp8_linear(pipe.transformer, base_dtype, params_to_keep=params_to_keep)

        # compile
        if compile_args is not None:
            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
            for i, block in enumerate(pipe.transformer.transformer_blocks):
                if "CogVideoXBlock" in str(block):
                    pipe.transformer.transformer_blocks[i] = torch.compile(
                        block,
                        fullgraph=compile_args["fullgraph"],
                        dynamic=compile_args["dynamic"],
                        backend=compile_args["backend"],
                        mode=compile_args["mode"],
                    )

        if "torchao" in quantization:
            try:
                from torchao.quantization import (
                    quantize_,
                    fpx_weight_only,
                    float8_dynamic_activation_float8_weight,
                    int8_dynamic_activation_int8_weight,
                )
            except ImportError:
                raise ImportError("torchao is not installed, please install torchao to use fp8dq")

            def filter_fn(module: nn.Module, fqn: str) -> bool:
                target_submodules = {'attn1', 'ff'}  # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
                if any(sub in fqn for sub in target_submodules):
                    return isinstance(module, nn.Linear)
                return False

            if "fp6" in quantization:  # slower for some reason on 4090
                quant_func = fpx_weight_only(3, 2)
            elif "fp8dq" in quantization:  # very fast on 4090 when compiled
                quant_func = float8_dynamic_activation_float8_weight()
            elif 'fp8dqrow' in quantization:
                from torchao.quantization.quant_api import PerRow
                quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow())
            elif 'int8dq' in quantization:
                quant_func = int8_dynamic_activation_int8_weight()

            for i, block in enumerate(pipe.transformer.transformer_blocks):
                if "CogVideoXBlock" in str(block):
                    quantize_(block, quant_func, filter_fn=filter_fn)
            manual_offloading = False  # to disable manual .to(device) calls
            log.info(f"Quantized transformer blocks to {quantization}")
"quantization": quantization, "base_path": model, "onediff": False, "cpu_offloading": enable_sequential_cpu_offload, "scheduler_config": scheduler_config, "model_name": model, "manual_offloading": manual_offloading, } return (pipeline,) #region VAE class CogVideoXVAELoader: @classmethod def INPUT_TYPES(s): return { "required": { "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}), }, "optional": { "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"} ), "compile_args":("COMPILEARGS", ), } } RETURN_TYPES = ("VAE",) RETURN_NAMES = ("vae", ) FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'" def loadmodel(self, model_name, precision, compile_args=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f: vae_config = json.load(f) model_path = folder_paths.get_full_path("vae", model_name) vae_sd = load_torch_file(model_path) vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device) vae.load_state_dict(vae_sd) #compile if compile_args is not None: torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] vae = torch.compile(vae, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) return (vae,) #region Tora class DownloadAndLoadToraModel: @classmethod def INPUT_TYPES(s): return { "required": { "model": ( [ "kijai/CogVideoX-5b-Tora", ], ), }, } RETURN_TYPES = ("TORAMODEL",) RETURN_NAMES = ("tora_model", ) FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" DESCRIPTION = "Downloads and loads the the Tora model from Huggingface to 'ComfyUI/models/CogVideo/CogVideoX-5b-Tora'" def loadmodel(self, model): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() download_path = folder_paths.get_folder_paths("CogVideo")[0] from .tora.traj_module import MGF try: from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device is_accelerate_available = True except: is_accelerate_available = False pass download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora") fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors") if not os.path.exists(fuser_path): log.info(f"Downloading Fuser model to: {fuser_path}") from huggingface_hub import snapshot_download snapshot_download( repo_id=model, allow_patterns=["*fuser.safetensors*"], local_dir=download_path, local_dir_use_symlinks=False, ) hidden_size = 3072 num_layers = 42 with (init_empty_weights() if is_accelerate_available else nullcontext()): fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)]) fuser_sd = load_torch_file(fuser_path) if is_accelerate_available: for key in fuser_sd: set_module_tensor_to_device(fuser_list, key, dtype=torch.float16, device=device, value=fuser_sd[key]) else: fuser_list.load_state_dict(fuser_sd) for module in fuser_list: for param in module.parameters(): param.data = param.data.to(torch.bfloat16).to(device) del fuser_sd traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors") if not os.path.exists(traj_extractor_path): log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}") from 

        traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
        if not os.path.exists(traj_extractor_path):
            log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id="kijai/CogVideoX-5b-Tora",
                allow_patterns=["*traj_extractor.safetensors*"],
                local_dir=download_path,
                local_dir_use_symlinks=False,
            )

        from .tora.traj_module import TrajExtractor
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            traj_extractor = TrajExtractor(
                vae_downsize=(4, 8, 8),
                patch_size=2,
                nums_rb=2,
                cin=16,
                channels=[128] * 42,
                sk=True,
                use_conv=False,
            )

        traj_sd = load_torch_file(traj_extractor_path)
        if is_accelerate_available:
            for key in traj_sd:
                set_module_tensor_to_device(traj_extractor, key, dtype=torch.float32, device=device, value=traj_sd[key])
        else:
            traj_extractor.load_state_dict(traj_sd)
            traj_extractor.to(torch.float32).to(device)

        toramodel = {
            "fuser_list": fuser_list,
            "traj_extractor": traj_extractor,
        }

        return (toramodel,)

#region controlnet
class DownloadAndLoadCogVideoControlNet:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (
                    [
                        "TheDenk/cogvideox-2b-controlnet-hed-v1",
                        "TheDenk/cogvideox-2b-controlnet-canny-v1",
                        "TheDenk/cogvideox-5b-controlnet-hed-v1",
                        "TheDenk/cogvideox-5b-controlnet-canny-v1",
                    ],
                ),
            },
        }

    RETURN_TYPES = ("COGVIDECONTROLNETMODEL",)
    RETURN_NAMES = ("cogvideo_controlnet",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"

    def loadmodel(self, model):
        from .cogvideo_controlnet import CogVideoXControlnet

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        mm.soft_empty_cache()

        download_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'ControlNet')
        base_path = os.path.join(download_path, (model.split("/")[-1]))
        if not os.path.exists(base_path):
            log.info(f"Downloading model to: {base_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id=model,
                ignore_patterns=["*text_encoder*", "*tokenizer*"],
                local_dir=base_path,
                local_dir_use_symlinks=False,
            )

        controlnet = CogVideoXControlnet.from_pretrained(base_path)

        return (controlnet,)

NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": DownloadAndLoadCogVideoModel,
    "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
    "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
    "DownloadAndLoadToraModel": DownloadAndLoadToraModel,
    "CogVideoLoraSelect": CogVideoLoraSelect,
    "CogVideoXVAELoader": CogVideoXVAELoader,
    "CogVideoXModelLoader": CogVideoXModelLoader,
    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
    "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
    "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
    "DownloadAndLoadToraModel": "(Down)load Tora Model",
    "CogVideoLoraSelect": "CogVideo LoraSelect",
    "CogVideoXVAELoader": "CogVideoX VAE Loader",
    "CogVideoXModelLoader": "CogVideoX Model Loader",
    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy",
}