# Uploaded by jaxmetaverse via huggingface_hub (commit 82ea528, verified)
import functools
import importlib
import importlib.util
import json
import unittest
import unittest.mock

import torch

import comfy.model_management
import comfy.model_patcher
import comfy.sd
import folder_paths

from . import patchers, utils
# True when the optional "xelerate" acceleration package (branded "velocator"
# in the node names below) is importable; asserted before any quantize/compile
# path that depends on it runs.
HAS_VELOCATOR = importlib.util.find_spec("xelerate") is not None
def get_quant_inputs():
    """Shared ComfyUI input declarations for the quantization nodes.

    Returns a dict meant to be spliced into an ``INPUT_TYPES`` "required"
    mapping: a quant-type dropdown plus three STRING widgets — the filter
    function name, its JSON-encoded kwargs, and extra JSON-encoded kwargs
    forwarded to the quantizer.
    """
    quant_choices = [
        "int8_dynamic",
        "e4m3_e4m3_dynamic",
        "e4m3_e4m3_dynamic_per_tensor",
        "int8_weightonly",
        "e4m3_weightonly",
        "e4m3_e4m3_weightonly",
        "e4m3_e4m3_weightonly_per_tensor",
        "nf4_weightonly",
        "af4_weightonly",
        "int4_weightonly",
    ]
    inputs = {}
    inputs["quant_type"] = (quant_choices,)
    inputs["filter_fn"] = ("STRING", {"default": "fnmatch_matches_fqn"})
    inputs["filter_fn_kwargs"] = (
        "STRING",
        {
            "multiline": True,
            "default": '{"pattern": ["*"]}',
        },
    )
    # Deliberately no "default" key: ComfyUI treats a missing STRING default
    # as the empty string.
    inputs["kwargs"] = ("STRING", {"multiline": True})
    return inputs
class VelocatorLoadAndQuantizeDiffusionModel:
    """ComfyUI node: load a diffusion model from disk, optionally quantizing
    its weights via the velocator ("xelerate") package during the load.

    The load runs with ``comfy.model_patcher.ModelPatcher`` temporarily
    replaced by ``patchers.QuantizedModelPatcher`` (via ``mock.patch.object``),
    so the loader constructs quantized patchers instead of plain ones.
    """

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "unet_name": (folder_paths.get_filename_list("diffusion_models"),),
        }
        required.update(utils.get_weight_dtype_inputs())
        required.update(
            {
                "lowvram": ("BOOLEAN", {"default": True}),
                "full_load": ("BOOLEAN", {"default": True}),
                "quantize": ("BOOLEAN", {"default": True}),
                "quantize_on_load_device": ("BOOLEAN", {"default": True}),
            }
        )
        required.update(get_quant_inputs())
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "wavespeed/velocator"

    def load_unet(
        self,
        unet_name,
        weight_dtype,
        lowvram,
        full_load,
        quantize,
        quantize_on_load_device,
        quant_type,
        filter_fn,
        filter_fn_kwargs,
        kwargs,
    ):
        """Load (and optionally quantize) the diffusion model *unet_name*.

        ``filter_fn_kwargs`` and ``kwargs`` are JSON strings ("" → {}).
        Returns a one-tuple holding the resulting MODEL patcher.
        """
        model_options = {}
        if lowvram:
            # Start on CPU so weights are not materialized on the GPU up front.
            model_options["initial_device"] = torch.device("cpu")
        model_options = utils.parse_weight_dtype(model_options, weight_dtype)
        unet_path = folder_paths.get_full_path_or_raise(
            "diffusion_models", unet_name
        )

        quantize_fn = None
        if quantize:
            assert HAS_VELOCATOR, "velocator is not installed"
            from xelerate.ao.quant import quantize

            extra_kwargs = json.loads(kwargs) if kwargs else {}
            if lowvram and quantize_on_load_device:
                # Quantize each tensor on the patcher's load device, then
                # park the result back on CPU to keep VRAM usage low.
                def _to_load_device(t):
                    load_device = patchers.QuantizedModelPatcher._load_device
                    return t if load_device is None else t.to(load_device)

                def _to_cpu(t):
                    return t.to(torch.device("cpu"))

                extra_kwargs["preprocessor"] = _to_load_device
                extra_kwargs["postprocessor"] = _to_cpu
            quantize_fn = functools.partial(
                quantize,
                quant_type=quant_type,
                filter_fn=filter_fn,
                filter_fn_kwargs=(
                    json.loads(filter_fn_kwargs) if filter_fn_kwargs else {}
                ),
                **extra_kwargs,
            )

        with patchers.QuantizedModelPatcher._override_defaults(
            quantize_fn=quantize_fn,
            lowvram=lowvram,
            full_load=full_load,
        ), utils.disable_load_models_gpu(), unittest.mock.patch.object(
            comfy.model_patcher, "ModelPatcher", patchers.QuantizedModelPatcher
        ):
            model = comfy.sd.load_diffusion_model(
                unet_path, model_options=model_options
            )
        return (model,)
class VelocatorLoadAndQuantizeClip:
    """ComfyUI node: load up to three text encoders as one CLIP, optionally
    quantizing their weights via the velocator ("xelerate") package.

    Like the diffusion-model loader, the load runs with ComfyUI's
    ``ModelPatcher`` temporarily swapped for ``QuantizedModelPatcher``.
    """

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "clip_name1": ([""] + folder_paths.get_filename_list("text_encoders"),),
            "clip_name2": ([""] + folder_paths.get_filename_list("text_encoders"),),
            "clip_name3": ([""] + folder_paths.get_filename_list("text_encoders"),),
            "type": ([member.name.lower() for member in comfy.sd.CLIPType],),
        }
        required.update(utils.get_weight_dtype_inputs())
        required.update(
            {
                "lowvram": ("BOOLEAN", {"default": True}),
                "full_load": ("BOOLEAN", {"default": True}),
                "quantize": ("BOOLEAN", {"default": True}),
                "quantize_on_load_device": ("BOOLEAN", {"default": True}),
            }
        )
        required.update(get_quant_inputs())
        return {"required": required}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"
    CATEGORY = "wavespeed/velocator"

    def load_clip(
        self,
        clip_name1,
        clip_name2,
        clip_name3,
        type,
        weight_dtype,
        lowvram,
        full_load,
        quantize,
        quantize_on_load_device,
        quant_type,
        filter_fn,
        filter_fn_kwargs,
        kwargs,
    ):
        """Load the selected text encoders (optionally quantized).

        ``filter_fn_kwargs`` and ``kwargs`` are JSON strings ("" → {}).
        Returns a one-tuple holding the resulting CLIP object.
        """
        model_options = {}
        if lowvram:
            # Start on CPU so weights are not materialized on the GPU up front.
            model_options["initial_device"] = torch.device("cpu")
        model_options = utils.parse_weight_dtype(model_options, weight_dtype)

        # Map the lowercase widget value back to its CLIPType enum member.
        clip_type = next(
            (member for member in comfy.sd.CLIPType if member.name.lower() == type),
            None,
        )
        assert clip_type is not None, f"Invalid clip type: {type}"

        # Empty selections ("") are skipped; the rest resolve to full paths.
        clip_paths = [
            folder_paths.get_full_path_or_raise("text_encoders", name)
            for name in (clip_name1, clip_name2, clip_name3)
            if name
        ]

        quantize_fn = None
        if quantize:
            assert HAS_VELOCATOR, "velocator is not installed"
            from xelerate.ao.quant import quantize

            extra_kwargs = json.loads(kwargs) if kwargs else {}
            if lowvram and quantize_on_load_device:
                # Quantize each tensor on the patcher's load device, then
                # park the result back on CPU to keep VRAM usage low.
                def _to_load_device(t):
                    load_device = patchers.QuantizedModelPatcher._load_device
                    return t if load_device is None else t.to(load_device)

                def _to_cpu(t):
                    return t.to(torch.device("cpu"))

                extra_kwargs["preprocessor"] = _to_load_device
                extra_kwargs["postprocessor"] = _to_cpu
            quantize_fn = functools.partial(
                quantize,
                quant_type=quant_type,
                filter_fn=filter_fn,
                filter_fn_kwargs=(
                    json.loads(filter_fn_kwargs) if filter_fn_kwargs else {}
                ),
                **extra_kwargs,
            )

        with patchers.QuantizedModelPatcher._override_defaults(
            quantize_fn=quantize_fn,
            lowvram=lowvram,
            full_load=full_load,
        ), utils.disable_load_models_gpu(), unittest.mock.patch.object(
            comfy.model_patcher, "ModelPatcher", patchers.QuantizedModelPatcher
        ):
            clip = comfy.sd.load_clip(
                ckpt_paths=clip_paths,
                embedding_directory=folder_paths.get_folder_paths("embeddings"),
                clip_type=clip_type,
                model_options=model_options,
            )
        return (clip,)
class VelocatorQuantizeModel:
    """ComfyUI node: quantize an already-loaded MODEL.

    Clones the incoming model patcher and replaces ``object_to_patch``
    (e.g. ``diffusion_model``) with a quantized copy produced by the
    velocator ("xelerate") package.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "object_to_patch": (
                    "STRING",
                    {
                        "default": "diffusion_model",
                    },
                ),
                # "quantize" was previously missing here even though patch()
                # requires it, so ComfyUI never supplied the argument.
                # Declared the same way as on the loader nodes.
                "quantize": ("BOOLEAN", {"default": True}),
                **get_quant_inputs(),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "wavespeed/velocator"

    def patch(
        self,
        model,
        object_to_patch,
        quantize,
        quant_type,
        filter_fn,
        filter_fn_kwargs,
        kwargs,
    ):
        """Return a one-tuple with the (optionally) quantized model.

        When ``quantize`` is False the input model is returned untouched.
        ``filter_fn_kwargs`` and ``kwargs`` are JSON strings ("" → {}).
        """
        # Guard clause: skip all work (and the velocator requirement) when
        # quantization is disabled.
        if not quantize:
            return (model,)

        assert HAS_VELOCATOR, "velocator is not installed"
        # Import under an alias inside the branch.  Previously a bare
        # `from xelerate.ao.quant import quantize` ran before `if quantize:`;
        # being a function-scope import it rebound the local name `quantize`,
        # shadowing the boolean parameter, so the check always saw the truthy
        # function and the toggle was silently ignored.
        from xelerate.ao.quant import quantize as quantize_fn

        # Fully load the weights before quantizing them in place.
        comfy.model_management.unload_all_models()
        comfy.model_management.load_models_gpu(
            [model], force_patch_weights=True, force_full_load=True
        )
        filter_fn_kwargs = json.loads(filter_fn_kwargs) if filter_fn_kwargs else {}
        kwargs = json.loads(kwargs) if kwargs else {}
        model = model.clone()
        model.add_object_patch(
            object_to_patch,
            quantize_fn(
                model.get_model_object(object_to_patch),
                quant_type=quant_type,
                filter_fn=filter_fn,
                filter_fn_kwargs=filter_fn_kwargs,
                **kwargs,
            ),
        )
        return (model,)
class VelocatorCompileModel:
    """ComfyUI node: compile ``object_to_patch`` inside a model (or a raw
    patcher) with ``xelerate_compile`` after applying the requested torch
    memory format.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (utils.any_typ,),
                "is_patcher": ("BOOLEAN", {"default": True}),
                "object_to_patch": ("STRING", {"default": "diffusion_model"}),
                "memory_format": (
                    ["channels_last", "contiguous_format", "preserve_format"],
                ),
                "fullgraph": ("BOOLEAN", {"default": False}),
                "dynamic": ("BOOLEAN", {"default": False}),
                "mode": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "cache-all:max-autotune:low-precision",
                    },
                ),
                # Deliberately no default: ComfyUI treats a missing STRING
                # default as the empty string.
                "options": ("STRING", {"multiline": True}),
                "disable": ("BOOLEAN", {"default": False}),
                "backend": ("STRING", {"default": "velocator"}),
            }
        }

    RETURN_TYPES = (utils.any_typ,)
    FUNCTION = "patch"
    CATEGORY = "wavespeed/velocator"

    def patch(
        self,
        model,
        is_patcher,
        object_to_patch,
        memory_format,
        fullgraph,
        dynamic,
        mode,
        options,
        disable,
        backend,
    ):
        """Compile ``object_to_patch`` and return the patched result.

        When *is_patcher* is False, *model* is expected to expose a
        ``.patcher`` attribute; that patcher is cloned, patched, and written
        back onto *model* before returning it.  ``options`` is a JSON string
        ("" → None); an empty ``mode`` also becomes None.
        """
        assert HAS_VELOCATOR, "velocator is not installed"
        from xelerate.compilers.xelerate_compiler import xelerate_compile
        from xelerate.utils.memory_format import apply_memory_format

        # Resolve widget strings to runtime values.
        torch_memory_format = getattr(torch, memory_format)
        compile_mode = mode if mode else None
        compile_options = json.loads(options) if options else None
        # "velocator" is the public alias of the internal "xelerate" backend.
        compile_backend = "xelerate" if backend == "velocator" else backend

        # Always work on a clone so the caller's patcher is untouched.
        patcher = (model if is_patcher else model.patcher).clone()
        compiled = xelerate_compile(
            apply_memory_format(
                patcher.get_model_object(object_to_patch),
                memory_format=torch_memory_format,
            ),
            fullgraph=fullgraph,
            dynamic=dynamic,
            mode=compile_mode,
            options=compile_options,
            disable=disable,
            backend=compile_backend,
        )
        patcher.add_object_patch(object_to_patch, compiled)

        if is_patcher:
            return (patcher,)
        model.patcher = patcher
        return (model,)