# Source: all_models/custom_nodes/ComfyUI-KJNodes/nodes/model_optimization_nodes.py
# Uploaded by jaxmetaverse using huggingface_hub (commit 82ea528, verified).
from comfy.ldm.modules import attention as comfy_attention
import comfy.model_patcher
import comfy.utils
import comfy.sd
import torch
import folder_paths
import comfy.model_management as mm
from comfy.cli_args import args
# Capture comfy's untouched implementations once, at import time, so the nodes
# below can swap them out and later restore them.
(
    orig_attention,
    original_patch_model,
    original_load_lora_for_models,
) = (
    comfy_attention.optimized_attention,
    comfy.model_patcher.ModelPatcher.patch_model,
    comfy.sd.load_lora_for_models,
)
class BaseLoaderKJ:
    """Mixin that applies process-wide patches to ComfyUI attention and Linear layers.

    The patches rebind module-level names inside ``comfy`` rather than adding
    per-model object patches. They therefore affect every model until
    ``_patch_modules`` is called again with the feature turned off.
    """

    # comfy.ops.disable_weight_init.Linear as it was before cublas patching.
    original_linear = None
    # True while disable_weight_init.Linear points at the cublas-backed class.
    cublas_patched = False

    @staticmethod
    def _build_sage_attention(sage_attention):
        """Return a replacement for comfy's ``optimized_attention`` backed by SageAttention.

        Args:
            sage_attention: one of "auto", "sageattn_qk_int8_pv_fp16_cuda",
                "sageattn_qk_int8_pv_fp16_triton" or "sageattn_qk_int8_pv_fp8_cuda".

        Returns:
            A function with comfy's ``optimized_attention`` signature
            ``(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False,
            skip_output_reshape=False)``.

        Raises:
            ImportError: if the ``sageattention`` package or the kernel is missing.
            ValueError: for an unsupported mode. Previously an unknown mode
                produced a ``None`` kernel that only failed at sampling time.
        """
        from sageattention import sageattn

        # Pick the kernel plus any kernel-specific keyword arguments.
        if sage_attention == "auto":
            kernel, extra_kwargs = sageattn, {}
        elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
            from sageattention import sageattn_qk_int8_pv_fp16_cuda
            kernel, extra_kwargs = sageattn_qk_int8_pv_fp16_cuda, {"pv_accum_dtype": "fp32"}
        elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
            from sageattention import sageattn_qk_int8_pv_fp16_triton
            kernel, extra_kwargs = sageattn_qk_int8_pv_fp16_triton, {}
        elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
            from sageattention import sageattn_qk_int8_pv_fp8_cuda
            kernel, extra_kwargs = sageattn_qk_int8_pv_fp8_cuda, {"pv_accum_dtype": "fp32+fp32"}
        else:
            raise ValueError(f"Unsupported sage_attention mode: {sage_attention!r}")

        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return kernel(q, k, v, is_causal=is_causal, attn_mask=attn_mask,
                          tensor_layout=tensor_layout, **extra_kwargs)

        # Kept out of torch.compile graphs.
        @torch.compiler.disable()
        def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
            if skip_reshape:
                # Inputs are already split per head; the last dim is dim_head.
                b, _, _, dim_head = q.shape
                tensor_layout = "HND"
            else:
                # Inputs are (b, seq, heads * dim_head); split the heads out.
                b, _, dim_head = q.shape
                dim_head //= heads
                q, k, v = map(
                    lambda t: t.view(b, -1, heads, dim_head),
                    (q, k, v),
                )
                tensor_layout = "NHD"
            if mask is not None:
                # add a batch dimension if there isn't already one
                if mask.ndim == 2:
                    mask = mask.unsqueeze(0)
                # add a heads dimension if there isn't already one
                if mask.ndim == 3:
                    mask = mask.unsqueeze(1)
            out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
            if tensor_layout == "HND":
                if not skip_output_reshape:
                    out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
            else:
                if skip_output_reshape:
                    out = out.transpose(1, 2)
                else:
                    out = out.reshape(b, -1, heads * dim_head)
            return out

        return attention_sage

    def _patch_modules(self, patch_cublaslinear, sage_attention):
        """Apply or revert the global attention and cublas Linear patches.

        Args:
            patch_cublaslinear: True swaps comfy's ``disable_weight_init.Linear``
                for a CublasLinear-backed class. False restores the saved
                original if it had been swapped.
            sage_attention: "disabled" restores comfy's original attention.
                Any supported SageAttention mode installs that kernel.

        Raises:
            Exception: cublas patching was requested but ``cublas_ops`` is not installed.
            ValueError: ``sage_attention`` is not a supported mode.
        """
        from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight

        if sage_attention != "disabled":
            print("Patching comfy attention to use sageattn")
            attention_func = self._build_sage_attention(sage_attention)
        else:
            attention_func = orig_attention

        # Each of these modules has its own module-level `optimized_attention`
        # name (presumably imported by name), so every one must be rebound.
        for module in (
            comfy_attention,
            comfy.ldm.hunyuan_video.model,
            comfy.ldm.flux.math,
            comfy.ldm.genmo.joint_model.asymm_models_joint,
            comfy.ldm.cosmos.blocks,
        ):
            module.optimized_attention = attention_func

        if patch_cublaslinear:
            if not BaseLoaderKJ.cublas_patched:
                try:
                    from cublas_ops import CublasLinear
                except ImportError as e:
                    raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm") from e
                BaseLoaderKJ.original_linear = disable_weight_init.Linear

                class PatchedLinear(CublasLinear, CastWeightBiasOp):
                    def reset_parameters(self):
                        # Weights come from the checkpoint; skip random init.
                        pass

                    def forward_comfy_cast_weights(self, input):
                        weight, bias = cast_bias_weight(self, input)
                        return torch.nn.functional.linear(input, weight, bias)

                    def forward(self, *args, **kwargs):
                        # Fall back to comfy's cast path when weights need casting;
                        # otherwise use the cublas kernel.
                        if self.comfy_cast_weights:
                            return self.forward_comfy_cast_weights(*args, **kwargs)
                        return super().forward(*args, **kwargs)

                disable_weight_init.Linear = PatchedLinear
                BaseLoaderKJ.cublas_patched = True
        elif BaseLoaderKJ.cublas_patched:
            disable_weight_init.Linear = BaseLoaderKJ.original_linear
            BaseLoaderKJ.cublas_patched = False
class PathchSageAttentionKJ(BaseLoaderKJ):
    """Node that globally switches comfy attention to a SageAttention kernel (or back)."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            # The default must be one of the combo options; the previous
            # `False` was not a valid choice.
            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": "disabled", "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
        }}

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"
    DESCRIPTION = "Experimental node for patching attention mode. This doesn't use the model patching system and thus can't be disabled without running the node again with 'disabled' option."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, sage_attention):
        """Apply the global attention patch and pass ``model`` through unchanged."""
        self._patch_modules(False, sage_attention)
        return model,
class CheckpointLoaderKJ(BaseLoaderKJ):
    """Checkpoint loader that applies the global attention/cublas patches before loading."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
            # The default must be one of the combo options; `False` was not.
            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": "disabled", "tooltip": "Patch comfy attention to use sageattn."}),
        }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "patch"
    OUTPUT_NODE = True
    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
        """Patch globally, then load ``ckpt_name`` via comfy's CheckpointLoaderSimple.

        Returns:
            ``(model, clip, vae)`` as produced by ``CheckpointLoaderSimple``.
        """
        # Patch first: per the tooltip, the cublas patch has no effect on
        # models that are already loaded.
        self._patch_modules(patch_cublaslinear, sage_attention)
        from nodes import CheckpointLoaderSimple
        model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
        return model, clip, vae
class DiffusionModelLoaderKJ(BaseLoaderKJ):
    """Diffusion-model (UNET) loader that applies the global attention/cublas patches."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "ckpt_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
            # The default must be one of the combo options; `False` was not.
            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": "disabled", "tooltip": "Patch comfy attention to use sageattn."}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch_and_load"
    OUTPUT_NODE = True
    DESCRIPTION = "Node for patching torch.nn.Linear with CublasLinear."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
        """Patch globally, then load ``ckpt_name`` via comfy's UNETLoader.

        Returns:
            ``(model,)`` as produced by ``UNETLoader.load_unet``.
        """
        # Patch BEFORE loading, like CheckpointLoaderKJ. Per the tooltip, the
        # cublas patch does not affect already-loaded models. Previously the
        # patch ran after loading and so never applied to the model returned here.
        self._patch_modules(patch_cublaslinear, sage_attention)
        from nodes import UNETLoader
        model, = UNETLoader.load_unet(self, ckpt_name, weight_dtype)
        return (model,)
def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
    """Replacement for ``ModelPatcher.patch_model`` that applies weight patches first.

    ``PatchModelPatcherOrder`` installs this for "weight_patch_first". Weights,
    including LoRA weight patches, are loaded before object patches are set.
    Object patches such as torch.compile'd blocks are therefore applied last.

    Note:
        ``device_to`` and ``load_weights`` are accepted only for signature
        compatibility. Weights are always loaded, onto ``mm.get_torch_device()``.

    Returns:
        ``self.model`` with every object patch applied.
    """
    wants_full_load = lowvram_model_memory == 0
    # Deliberately ignores the caller's device_to / load_weights.
    target_device = mm.get_torch_device()
    self.load(
        target_device,
        lowvram_model_memory=lowvram_model_memory,
        force_patch_weights=force_patch_weights,
        full_load=wants_full_load,
    )

    # Install object patches. Only the first-seen original of each key is
    # backed up, so re-patching never overwrites the true original.
    backups = self.object_patches_backup
    for key, replacement in self.object_patches.items():
        previous = comfy.utils.set_attr(self.model, key, replacement)
        if key not in backups:
            backups[key] = previous
    return self.model
def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    """Replacement for ``comfy.sd.load_lora_for_models`` that re-compiles patched blocks.

    ``PatchModelPatcherOrder`` installs this for "weight_patch_first".

    Steps:
        1. Every object-patched module is swapped back to its backed-up
           original on the underlying model, so LoRA keys map onto the real
           module tree.
        2. LoRA weight patches are added to clones of ``model`` / ``clip``.
        3. If the model carries ``compile_settings`` (set by the TorchCompile*
           nodes in this file), each ``diffusion_model.*`` object patch is
           re-compiled from the original block with those settings.
           The result is attached to the returned clone.

    Args:
        model: ModelPatcher to receive LoRA weight patches, or None.
        clip: CLIP object to receive LoRA weight patches, or None.
        lora: LoRA state dict, passed to ``comfy.lora.load_lora``.
        strength_model: strength for the model patches.
        strength_clip: strength for the CLIP patches.

    Returns:
        ``(new_modelpatcher, new_clip)``. Each is None when the matching input was None.
    """
    # Guard against model=None: the old code dereferenced it here and crashed
    # on CLIP-only LoRA loads.
    patch_keys = list(model.object_patches_backup.keys()) if model is not None else []
    for key in patch_keys:
        comfy.utils.set_attr(model.model, key, model.object_patches_backup[key])

    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    loaded = comfy.lora.load_lora(lora, key_map)

    if model is not None:
        new_modelpatcher = model.clone()
        model_keys = new_modelpatcher.add_patches(loaded, strength_model)
    else:
        new_modelpatcher = None
        model_keys = ()

    if clip is not None:
        new_clip = clip.clone()
        clip_keys = new_clip.add_patches(loaded, strength_clip)
    else:
        new_clip = None
        clip_keys = ()

    applied = set(model_keys) | set(clip_keys)
    for x in loaded:
        if x not in applied:
            print("NOT LOADED {}".format(x))

    # Only re-compile when compile settings exist. Previously a missing
    # attribute left `compile_settings` unbound.
    compile_settings = getattr(model.model, "compile_settings", None) if patch_keys else None
    if compile_settings is not None:
        print("compile_settings: ", compile_settings)
        diffusion_model = model.get_model_object("diffusion_model")
        for key in patch_keys:
            if "diffusion_model." not in key:
                continue
            # Walk e.g. "diffusion_model.double_blocks.3" down to the original block.
            block = diffusion_model
            for attr in key.replace('diffusion_model.', '').split('.'):
                block = block[int(attr)] if attr.isdigit() else getattr(block, attr)
            compiled_block = torch.compile(
                block,
                mode=compile_settings["mode"],
                dynamic=compile_settings["dynamic"],
                fullgraph=compile_settings["fullgraph"],
                backend=compile_settings["backend"],
            )
            # Attach to the patcher we return. Previously this went onto the
            # input `model`, which left the returned clone without the
            # re-compiled block.
            new_modelpatcher.add_object_patch(key, compiled_block)

    return (new_modelpatcher, new_clip)
class PatchModelPatcherOrder:
    """Node that globally swaps comfy's patch order between object and weight patches."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    # Correct spelling, which is what ComfyUI actually reads.
    DESCRIPTION = "Patch the comfy patch_model function patching order, useful for torch.compile (used as object_patch) as it should come last if you want to use LoRAs with compile"
    # Misspelled name kept as an alias for backward compatibility.
    DESCTIPTION = DESCRIPTION
    EXPERIMENTAL = True

    def patch(self, model, patch_order):
        """Install (or revert) the patched ``patch_model`` / ``load_lora_for_models``.

        The change is global: it rebinds ``ModelPatcher.patch_model`` and
        ``comfy.sd.load_lora_for_models``. ``model`` is passed through unchanged.
        """
        comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {}
        if patch_order == "weight_patch_first":
            comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
            comfy.sd.load_lora_for_models = patched_load_lora_for_models
        else:
            comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
            comfy.sd.load_lora_for_models = original_load_lora_for_models
        return model,
class TorchCompileModelFluxAdvanced:
    """Compile selected Flux double/single blocks with torch.compile.

    Compiled blocks are added as object patches on a clone of the input model.
    The compile settings are stored on ``m.model.compile_settings`` so they can
    be reused later (``patched_load_lora_for_models`` in this file reads them).
    """

    def __init__(self):
        # Set after a successful compile. Later calls on this node instance
        # skip compilation and return a plain clone.
        self._compiled = False

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "backend": (["inductor", "cudagraphs"],),
            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
            "double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
            "single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
            "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True

    def parse_blocks(self, blocks_str):
        """Parse a block spec such as ``"0-3, 7, 10-12"`` into a list of indices.

        Entries are separated by commas and/or newlines (the input widget is
        multiline). Ranges ``a-b`` are inclusive. Empty entries, such as an
        empty field or a trailing comma, are ignored instead of raising.
        """
        blocks = []
        for line in blocks_str.splitlines():
            for part in line.split(','):
                part = part.strip()
                if not part:
                    continue
                if '-' in part:
                    start, end = map(int, part.split('-'))
                    blocks.extend(range(start, end + 1))
                else:
                    blocks.append(int(part))
        return blocks

    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
        """Return ``(m,)``, a clone of ``model`` with the selected blocks compiled.

        Raises:
            RuntimeError: compilation failed; the original error is chained as ``__cause__``.
        """
        # Sets give O(1) membership checks in the loops below.
        single_block_set = set(self.parse_blocks(single_blocks))
        double_block_set = set(self.parse_blocks(double_blocks))
        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")

        if not self._compiled:
            try:
                for i, block in enumerate(diffusion_model.double_blocks):
                    if i in double_block_set:
                        m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
                for i, block in enumerate(diffusion_model.single_blocks):
                    if i in single_block_set:
                        m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
                compile_settings = {
                    "backend": backend,
                    "mode": mode,
                    "fullgraph": fullgraph,
                    "dynamic": dynamic,
                }
                setattr(m.model, "compile_settings", compile_settings)
            except Exception as e:
                raise RuntimeError("Failed to compile model") from e
            # Only mark as compiled once everything above succeeded.
            self._compiled = True
        return (m, )


# Other Flux layers that this node leaves uncompiled; they could be compiled
# the same way:
# diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
class TorchCompileVAE:
    """Compile a VAE's encoder and/or decoder in place with torch.compile.

    Unlike the model nodes, this replaces attributes on ``vae.first_stage_model``
    directly instead of using object patches.
    """

    def __init__(self):
        # Once set, later calls on this node instance skip compiling that half.
        self._compiled_encoder = False
        self._compiled_decoder = False

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "vae": ("VAE",),
            "backend": (["inductor", "cudagraphs"],),
            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
            "compile_encoder": ("BOOLEAN", {"default": True, "tooltip": "Compile encoder"}),
            "compile_decoder": ("BOOLEAN", {"default": True, "tooltip": "Compile decoder"}),
        }}

    RETURN_TYPES = ("VAE",)
    FUNCTION = "compile"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True

    @staticmethod
    def _compile_submodule(vae, default_name, taesd_name, backend, mode, fullgraph):
        """Replace one sub-network of ``vae.first_stage_model`` with its compiled wrapper.

        If the stage model has ``taesd_name``, that attribute is compiled;
        otherwise ``default_name`` is.

        Raises:
            RuntimeError: compilation failed; the original error is chained as ``__cause__``.
        """
        stage_model = vae.first_stage_model
        name = taesd_name if hasattr(stage_model, taesd_name) else default_name
        try:
            setattr(
                stage_model,
                name,
                torch.compile(
                    getattr(stage_model, name),
                    mode=mode,
                    fullgraph=fullgraph,
                    backend=backend,
                ),
            )
        except Exception as e:
            raise RuntimeError("Failed to compile model") from e

    def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
        """Compile the requested halves of ``vae`` (at most once each) and return ``(vae,)``."""
        if compile_encoder and not self._compiled_encoder:
            self._compile_submodule(vae, "encoder", "taesd_encoder", backend, mode, fullgraph)
            self._compiled_encoder = True
        if compile_decoder and not self._compiled_decoder:
            self._compile_submodule(vae, "decoder", "taesd_decoder", backend, mode, fullgraph)
            self._compiled_decoder = True
        return (vae, )
class TorchCompileControlNet:
    """Compile a ControlNet's ``control_model`` in place with torch.compile."""

    def __init__(self):
        # Once set, later calls on this node instance skip compilation.
        self._compiled = False

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "controlnet": ("CONTROL_NET",),
            "backend": (["inductor", "cudagraphs"],),
            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
        }}

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "compile"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True

    def compile(self, controlnet, backend, mode, fullgraph):
        """Replace ``controlnet.control_model`` with its compiled wrapper and return ``(controlnet,)``.

        The whole control model is compiled at once. An earlier, disabled
        approach compiled each ``control_model.double_blocks`` entry separately.

        Raises:
            RuntimeError: compilation failed; the original error is chained as ``__cause__``.
        """
        if not self._compiled:
            try:
                controlnet.control_model = torch.compile(controlnet.control_model, mode=mode, fullgraph=fullgraph, backend=backend)
            except Exception as e:
                raise RuntimeError("Failed to compile model") from e
            self._compiled = True
        return (controlnet, )
class TorchCompileLTXModel:
    """Compile every ``diffusion_model.transformer_blocks`` entry with torch.compile.

    Compiled blocks are added as object patches on a clone of the input model.
    The compile settings are stored on ``m.model.compile_settings``, where
    ``patched_load_lora_for_models`` in this file reads them back.
    """

    def __init__(self):
        # Set after a successful compile. Later calls on this node instance
        # skip compilation and return a plain clone.
        self._compiled = False

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "backend": (["inductor", "cudagraphs"],),
            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
            "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True

    def patch(self, model, backend, mode, fullgraph, dynamic):
        """Return ``(m,)``, a clone of ``model`` whose transformer blocks are compiled.

        Raises:
            RuntimeError: compilation failed; the original error is chained as ``__cause__``.
        """
        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        if not self._compiled:
            try:
                for i, block in enumerate(diffusion_model.transformer_blocks):
                    compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
                    m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block)
                compile_settings = {
                    "backend": backend,
                    "mode": mode,
                    "fullgraph": fullgraph,
                    "dynamic": dynamic,
                }
                setattr(m.model, "compile_settings", compile_settings)
            except Exception as e:
                raise RuntimeError("Failed to compile model") from e
            # Only mark as compiled once everything above succeeded.
            self._compiled = True
        return (m, )
class TorchCompileCosmosModel:
def __init__(self):
self._compiled = False
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
"dynamo_cache_size_limit": ("INT", {"default": 64, "tooltip": "Set the dynamo cache size limit"}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
def patch(self, model, backend, mode, fullgraph, dynamic, dynamo_cache_size_limit):
m = model.clone()
diffusion_model = m.get_model_object("diffusion_model")
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
if not self._compiled:
try:
for name, block in diffusion_model.blocks.items():
#print(f"Compiling block {name}")
compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
m.add_object_patch(f"diffusion_model.blocks.{name}", compiled_block)
#diffusion_model.blocks[name] = compiled_block
self._compiled = True
compile_settings = {
"backend": backend,
"mode": mode,
"fullgraph": fullgraph,
"dynamic": dynamic,
}
setattr(m.model, "compile_settings", compile_settings)
except:
raise RuntimeError("Failed to compile model")
return (m, )