import folder_paths
import importlib
import json
import comfy.sd
from . import utils
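

# ComfyUI node: loads a standalone diffusion model (UNet) checkpoint from the
# "diffusion_models" folder, with the weight dtype options supplied by utils.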
class EnhancedLoadDiffusionModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "unet_name": (folder_paths.get_filename_list("diffusion_models"),),
                **utils.get_weight_dtype_inputs(),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "wavespeed"

    def load_unet(self, unet_name, weight_dtype):
        model_options = {}
        model_options = utils.parse_weight_dtype(model_options, weight_dtype)
        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
        return (model,)
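

# ComfyUI node: wraps a sub-module of a model (by default "diffusion_model")
# with a compiler callable such as torch.compile, registered as an object patch
# on a cloned patcher so the underlying weights stay untouched.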
class EnhancedCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (utils.any_typ,),
                "is_patcher": (
                    "BOOLEAN",
                    {
                        "default": True,
                    },
                ),
                "object_to_patch": (
                    "STRING",
                    {
                        "default": "diffusion_model",
                    },
                ),
                "compiler": (
                    "STRING",
                    {
                        "default": "torch.compile",
                    },
                ),
                "fullgraph": (
                    "BOOLEAN",
                    {
                        "default": False,
                    },
                ),
                "dynamic": ("BOOLEAN", {"default": False}),
                "mode": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "",
                    },
                ),
                "options": (
                    "STRING",
                    {
                        "multiline": True,
                        # "default": "{}",
                    },
                ),
                "disable": (
                    "BOOLEAN",
                    {
                        "default": False,
                    },
                ),
                "backend": (
                    "STRING",
                    {
                        "default": "inductor",
                    },
                ),
            }
        }

    RETURN_TYPES = (utils.any_typ,)
    FUNCTION = "patch"
    CATEGORY = "wavespeed"

    def patch(
        self,
        model,
        is_patcher,
        object_to_patch,
        compiler,
        fullgraph,
        dynamic,
        mode,
        options,
        disable,
        backend,
    ):
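        # Assumption (inferred from the helper names in utils): pre-apply
        # workarounds for torch._dynamo / inductor quirks; patch_same_meta
        # appears to target the inductor "same_meta" failure noted in the
        # TODO below.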
        utils.patch_optimized_module()
        utils.patch_same_meta()
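
        # Resolve the compiler callable from its dotted path,
        # e.g. "torch.compile" -> torch.compile.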
        import_path, function_name = compiler.rsplit(".", 1)
        module = importlib.import_module(import_path)
        compile_function = getattr(module, function_name)
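
        # Empty strings from the UI mean "unset"; options is a JSON dict when given.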
        mode = mode if mode else None
        options = json.loads(options) if options else None

        if compiler == "torch.compile" and backend == "inductor" and dynamic:
            # TODO: Fix this
            # File "pytorch/torch/_inductor/fx_passes/post_grad.py", line 643, in same_meta
            #   and statically_known_true(sym_eq(val1.size(), val2.size()))
            # AttributeError: 'SymInt' object has no attribute 'size'
            pass
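
        # Clone before patching so the compiled wrapper does not modify the
        # caller's model in place; a non-patcher input carries its own .patcher.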
        if is_patcher:
            patcher = model.clone()
        else:
            patcher = model.patcher
            patcher = patcher.clone()
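
        # Register the compiled version of the target sub-module (by default
        # "diffusion_model") as an object patch on the cloned patcher.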
        patcher.add_object_patch(
            object_to_patch,
            compile_function(
                patcher.get_model_object(object_to_patch),
                fullgraph=fullgraph,
                dynamic=dynamic,
                mode=mode,
                options=options,
                disable=disable,
                backend=backend,
            ),
        )
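
        # Return the patched clone directly, or hand back the original wrapper
        # with its patcher swapped for the patched clone.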
        if is_patcher:
            return (patcher,)
        else:
            model.patcher = patcher
            return (model,)