import importlib
import json

import comfy.sd
import folder_paths

from . import utils


def _blank_to_none(text):
    """Return *text* stripped of surrounding whitespace, or None if blank.

    Non-string values are passed through the same truthiness check the
    caller previously applied (falsy -> None, truthy -> unchanged).
    """
    if isinstance(text, str):
        text = text.strip()
    return text or None


class EnhancedLoadDiffusionModel:
    """Node that loads a diffusion model file with a selectable weight dtype."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs.

        ``unet_name`` is a choice among the files listed by
        ``folder_paths`` under ``"diffusion_models"``; the remaining
        required inputs are whatever ``utils.get_weight_dtype_inputs()``
        returns (presumably the ``weight_dtype`` selector consumed by
        ``load_unet`` -- confirm against ``utils``).
        """
        return {
            "required": {
                "unet_name": (folder_paths.get_filename_list("diffusion_models"),),
                **utils.get_weight_dtype_inputs(),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"

    CATEGORY = "wavespeed"

    def load_unet(self, unet_name, weight_dtype):
        """Load the diffusion model named *unet_name*.

        Args:
            unet_name: File name resolved inside the ``"diffusion_models"``
                folder via ``folder_paths.get_full_path_or_raise``.
            weight_dtype: Value handed to ``utils.parse_weight_dtype`` to
                fill in the model options.

        Returns:
            A 1-tuple holding the object returned by
            ``comfy.sd.load_diffusion_model``.
        """
        model_options = {}
        model_options = utils.parse_weight_dtype(model_options, weight_dtype)

        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
        return (model,)


class EnhancedCompileModel:
    """Node that replaces one object of a model patcher with a compiled version.

    The compile function is looked up dynamically from the dotted path given
    in the ``compiler`` input (``"torch.compile"`` by default).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs; every input listed here is required."""
        return {
            "required": {
                "model": (utils.any_typ,),
                "is_patcher": (
                    "BOOLEAN",
                    {
                        "default": True,
                    },
                ),
                "object_to_patch": (
                    "STRING",
                    {
                        "default": "diffusion_model",
                    },
                ),
                "compiler": (
                    "STRING",
                    {
                        "default": "torch.compile",
                    },
                ),
                "fullgraph": (
                    "BOOLEAN",
                    {
                        "default": False,
                    },
                ),
                "dynamic": ("BOOLEAN", {"default": False}),
                "mode": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "",
                    },
                ),
                "options": (
                    "STRING",
                    {
                        "multiline": True,
                        # "default": "{}",
                    },
                ),
                "disable": (
                    "BOOLEAN",
                    {
                        "default": False,
                    },
                ),
                "backend": (
                    "STRING",
                    {
                        "default": "inductor",
                    },
                ),
            }
        }

    RETURN_TYPES = (utils.any_typ,)
    FUNCTION = "patch"

    CATEGORY = "wavespeed"

    def patch(
        self,
        model,
        is_patcher,
        object_to_patch,
        compiler,
        fullgraph,
        dynamic,
        mode,
        options,
        disable,
        backend,
    ):
        """Compile ``object_to_patch`` and install it as an object patch.

        Args:
            model: Either a patcher itself (``is_patcher`` true) or an object
                exposing one as ``model.patcher``.
            is_patcher: Selects how *model* is interpreted (see above).
            object_to_patch: Name passed to ``get_model_object`` /
                ``add_object_patch`` on the patcher.
            compiler: Dotted ``module.function`` path of the compile function.
            fullgraph, dynamic, disable, backend: Forwarded unchanged as
                keyword arguments to the compile function.
            mode: Compile mode; blank / whitespace-only text means ``None``.
            options: JSON text decoded into the ``options`` keyword argument;
                blank / whitespace-only text means ``None``.

        Returns:
            ``(patcher,)`` with the cloned, patched patcher when
            ``is_patcher`` is true; otherwise ``(model,)`` after
            ``model.patcher`` has been replaced by the patched clone.

        Raises:
            ValueError: If *compiler* is not of the form ``module.function``,
                or (as ``json.JSONDecodeError``) if *options* is not valid
                JSON.
        """
        utils.patch_optimized_module()
        utils.patch_same_meta()

        if isinstance(compiler, str):
            compiler = compiler.strip()
        import_path, dot, function_name = compiler.rpartition(".")
        if not dot or not import_path:
            # Fail with an actionable message instead of an unpacking error.
            raise ValueError(
                f"compiler must be a dotted 'module.function' path, got {compiler!r}"
            )
        module = importlib.import_module(import_path)
        compile_function = getattr(module, function_name)

        # Both values come from multiline text widgets, so a stray newline or
        # space must count as "not set" rather than reach the compiler/JSON
        # parser as a bogus value.
        mode = _blank_to_none(mode)
        options = _blank_to_none(options)
        if options is not None:
            options = json.loads(options)

        # TODO: Fix this. Recorded failure for compiler == "torch.compile",
        # backend == "inductor" and dynamic=True:
        #   File "pytorch/torch/_inductor/fx_passes/post_grad.py", line 643, in same_meta
        #     and statically_known_true(sym_eq(val1.size(), val2.size()))
        #   AttributeError: 'SymInt' object has no attribute 'size'

        # Always work on a clone so the incoming patcher is left untouched.
        source_patcher = model if is_patcher else model.patcher
        patcher = source_patcher.clone()

        patcher.add_object_patch(
            object_to_patch,
            compile_function(
                patcher.get_model_object(object_to_patch),
                fullgraph=fullgraph,
                dynamic=dynamic,
                mode=mode,
                options=options,
                disable=disable,
                backend=backend,
            ),
        )

        if is_patcher:
            return (patcher,)

        # The wrapper object itself is mutated to point at the patched clone.
        model.patcher = patcher
        return (model,)