import inspect
import importlib
from textwrap import dedent, indent
from copy import copy
import types
import functools
import os
import sys
import binascii
from typing import List, NamedTuple
class Hook(NamedTuple):
    """Description of one monkey-patch to apply to an already-imported module."""

    # Replacement callable that gets assigned to `target`.
    fn: object
    # Dotted import name of the module to patch, e.g. 'comfy.samplers'.
    module_name: str
    # Attribute path inside the module; may be dotted (e.g. 'Gligen._set_position').
    target: str
    # Attribute path under which the unpatched object is stored before patching.
    orig_key: str
    # `module_name` with dots turned into path separators and normalised;
    # used to match `sys.modules` keys by suffix.
    module_name_path: str
def gen_id():
    """Return a random identifier made of 8 lowercase hexadecimal characters.

    Four random bytes encode to exactly eight hex digits, so there is no need
    to draw a large buffer and slice a window out of it.
    """
    return os.urandom(4).hex()
def hook_calc_cond_uncond_batch():
# Build a patched copy of the batch-conditioning function from comfy.samplers
# (via inject_code) and describe it as a Hook for module 'comfy.samplers'.
#
# NOTE(review): this block looks incomplete in the current view -- the `try:`
# that pairs with `except Exception:` and the brackets that close/separate the
# payload entries are not visible. Confirm against the full file before
# changing any code here.
from comfy.samplers import calc_cond_batch
calc_cond_batch_ = calc_cond_batch
except Exception:
# Fallback: import the function under its other name.
from comfy.samplers import calc_cond_uncond_batch
calc_cond_batch_ = calc_cond_uncond_batch
# Attribute name used to keep the unpatched function; the suffix comes from gen_id().
orig_key = f"{calc_cond_batch_.__name__}_original_tiled_diffusion_{gen_id()}"
# Edits handed to inject_code:
#  - replace the text 'control.get_control' with a conditional expression that
#    yields `control` itself when 'tiled_diffusion' is in model_options;
#  - insert, at the line containing the function's own name, an early return
#    that calls the object stored under `orig_key` when 'tiled_diffusion' is
#    not in model_options.
payload = [{
"mode": "replace",
"target_line": 'control.get_control',
"code_to_insert": """control if 'tiled_diffusion' in model_options else control.get_control"""
"dedent": False,
"target_line": calc_cond_batch_.__name__,
"code_to_insert": f"""
if 'tiled_diffusion' not in model_options:
return {orig_key}{inspect.signature(calc_cond_batch_)}"""
# 'w' is forwarded by inject_code as the open() mode for the patches file.
fn = inject_code(calc_cond_batch_, payload, 'w')
return create_hook(fn, 'comfy.samplers', orig_key=orig_key)
def hook_sag_create_blur_map():
# Produce a Hook for 'comfy_extras.nodes_sag' whose function is built from a
# regex-rewritten copy of nodes_sag.create_blur_map's source: the span matched
# by r"ratio =.*\s+mid_shape =.*" is substituted with `replace_str`.
# Returns None early when the `imported` flag is still False.
#
# NOTE(review): lines appear to be missing from this view -- the `try:` that
# pairs with `except Exception: ...`, the assignment of `source`, and the
# opening of the `replace_str` literal. The lines from
# `def calc_closest_factors(a):` down to `mid_shape = mh, mw"""` look like the
# tail of that literal rather than code of this function. Confirm against the
# full file before editing.
imported = False
import comfy_extras
from comfy_extras import nodes_sag
imported = True
except Exception: ...
if not imported: return
import comfy_extras
from comfy_extras import nodes_sag
import re
def calc_closest_factors(a):
for b in range(int(math.sqrt(a)), 0, -1):
if a%b == 0:
c = a // b
return (b,c)
m = calc_closest_factors(hw1)
mh = max(m) if lh > lw else min(m)
mw = m[1] if mh == m[0] else m[0]
mid_shape = mh, mw"""
modified_source = re.sub(r"ratio =.*\s+mid_shape =.*", replace_str, source, flags=re.MULTILINE)
fn = write_to_file_and_return_fn(nodes_sag.create_blur_map, modified_source)
return create_hook(fn, 'comfy_extras.nodes_sag')
def hook_samplers_pre_run_control():
# Build a patched copy of comfy.samplers.pre_run_control with inject_code and
# describe it as a Hook for module 'comfy.samplers'.
#
# The payload shows two insertions (no "mode" key, so inject_code's default
# "insert" applies):
#  - at the line containing "if 'control' in x:", a try/except that calls
#    x['control'].cleanup();
#  - at the line containing "s = model.model_sampling", a snippet that walks up
#    to 7 caller frames looking for a local named 'model' that is a
#    ModelPatcher, reads its 'model_function_wrapper' option, and looks up
#    `AbstractDiffusion` in a sys.modules entry whose key contains
#    'tiled_diffusion'.
#
# NOTE(review): the separators between payload entries, the statement that
# follows the final `isinstance(...)` check, and the end of the second
# "code_to_insert" literal are not visible in this view. Confirm against the
# full file before editing.
from comfy.samplers import pre_run_control
payload = [{
"dedent": False,
"target_line": "if 'control' in x:",
"code_to_insert": """ try: x['control'].cleanup()\n except Exception: ..."""
"target_line": "s = model.model_sampling",
"code_to_insert": """
def find_outer_instance(target:str, target_type):
import inspect
frame = inspect.currentframe()
i = 0
while frame and i < 7:
if (found:=frame.f_locals.get(target, None)) is not None:
if isinstance(found, target_type):
return found
frame = frame.f_back
i += 1
return None
from comfy.model_patcher import ModelPatcher
if (_model:=find_outer_instance('model', ModelPatcher)) is not None:
if (model_function_wrapper:=_model.model_options.get('model_function_wrapper', None)) is not None:
import sys
tiled_diffusion = sys.modules.get('ComfyUI-TiledDiffusion.tiled_diffusion', None)
if tiled_diffusion is None:
for key in sys.modules:
if 'tiled_diffusion' in key:
tiled_diffusion = sys.modules[key]
if (AbstractDiffusion:=getattr(tiled_diffusion, 'AbstractDiffusion', None)) is not None:
if isinstance(model_function_wrapper, AbstractDiffusion):
fn = inject_code(pre_run_control, payload)
return create_hook(fn, 'comfy.samplers')
def hook_gligen__set_position():
# Produce a Hook for module 'comfy.gligen' with target 'Gligen._set_position'.
# The hooked function is built from a copy of the method's source in which the
# first occurrence of " return module(x, objs)" is replaced by `replace_str`,
# then dedented and passed to write_to_file_and_return_fn.
#
# NOTE(review): the assignment of `source` and the opening of the
# `replace_str` literal are not visible in this view; the lines from
# `nonlocal objs` to `return module(x, _objs)"""` look like the tail of that
# literal (and an `else:` before `_objs = objs` may be missing). Confirm
# against the full file before editing.
from comfy.gligen import Gligen
nonlocal objs
if x.shape[0] > objs.shape[0]:
_objs = objs.repeat(-(x.shape[0] // -objs.shape[0]),1,1)
_objs = objs
return module(x, _objs)"""
modified_source = dedent(source.replace(" return module(x, objs)", replace_str, 1))
fn = write_to_file_and_return_fn(Gligen._set_position, modified_source)
return create_hook(fn, 'comfy.gligen', 'Gligen._set_position')
def create_hook(fn, module_name: str, target=None, orig_key=None):
    """Bundle a replacement function and its patch location into a ``Hook``.

    Args:
        fn: Replacement callable.
        module_name: Dotted name of the module to patch.
        target: Attribute path inside the module to overwrite. Defaults to
            ``fn.__name__``.
        orig_key: Attribute path under which the unpatched object is kept.
            Defaults to ``f'{target}_original'``.

    Returns:
        A ``Hook`` whose ``module_name_path`` is ``module_name`` with dots
        replaced by ``/`` and passed through ``os.path.normpath``.
    """
    if target is None:
        target = fn.__name__
    if orig_key is None:
        orig_key = f'{target}_original'
    # Path-style spelling of the module name, matched by suffix against
    # sys.modules keys when the hooks are applied.
    module_name_path = os.path.normpath(module_name.replace('.', '/'))
    return Hook(fn, module_name, target, orig_key, module_name_path)
def _getattr(obj, name: str, default=None):
    """Multi-level ``getattr``.

    Resolves a dotted attribute path such as ``'Cls.method'`` starting at
    ``obj``.

    Args:
        obj: Object to start the lookup from.
        name: Attribute name, optionally dotted.
        default: Value returned when any segment of the path is missing.

    Returns:
        The resolved attribute, or ``default``.
    """
    missing = object()
    for attr in name.split('.'):
        obj = getattr(obj, attr, missing)
        if obj is missing:
            # Stop at the first missing segment; continuing the walk on
            # `default` could resolve later segments against the default
            # value itself (e.g. 'x.upper' with default 'abc').
            return default
    return obj
def _hasattr(obj, name: str):
    """Multi-level ``hasattr``.

    Returns True only when every segment of the dotted path ``name`` resolves
    to a value that is not ``None``. An attribute that exists but holds
    ``None`` is therefore reported as absent.
    """
    for attr in name.split('.'):
        obj = getattr(obj, attr, None)
        if obj is None:
            # Bail out immediately: walking on from None could otherwise
            # "find" attributes that belong to None itself (e.g. __class__).
            return False
    return True
def _setattr(obj, name: str, value=None):
    """Multi-level ``setattr``.

    Walks every segment of the dotted path ``name`` except the last one
    starting at ``obj``, then assigns ``value`` to the last segment on the
    object reached.

    Args:
        obj: Object to start from.
        name: Attribute name, optionally dotted.
        value: Value to assign (defaults to ``None``).

    Returns:
        ``None`` (the result of ``setattr``).

    Raises:
        AttributeError: If an intermediate segment of the path does not exist;
            the message names the missing segment.
    """
    *parents, leaf = name.split('.')
    for attr in parents:
        # No default here: a missing intermediate should fail with an error
        # naming that attribute instead of failing later on None.
        obj = getattr(obj, attr)
    return setattr(obj, leaf, value)
def hook_all(restore=False, hooks=None):
# Apply every Hook in `hooks` to the matching modules found in sys.modules.
#
# NOTE(review): the contents and closing bracket of the default `hooks` list
# are not visible in this view, and as shown both `_setattr` calls at the end
# would run one after the other -- presumably an `else:` between them is
# missing. Confirm against the full file before editing.
if hooks is None:
hooks: List[Hook] = [
for key, module in sys.modules.items():
for hook in hooks:
# A module matches by exact name, or when its key ends with the path-style name.
if key == hook.module_name or key.endswith(hook.module_name_path):
if _hasattr(module, hook.target):
# Keep the current target under orig_key, but only if nothing is stored there yet.
if not _hasattr(module, hook.orig_key):
if (orig_fn:=_getattr(module, hook.target, None)) is not None:
_setattr(module, hook.orig_key, orig_fn)
# restore=True writes the object saved under orig_key back onto the target.
if restore:
_setattr(module, hook.target, _getattr(module, hook.orig_key, None))
# Point the target at the replacement function.
_setattr(module, hook.target, hook.fn)
def inject_code(original_func, data, mode='a'):
# Edit the source text of `original_func` according to `data` and return the
# function produced by write_to_file_and_return_fn for the edited source.
#
# data: iterable of dicts with keys
#   'target_line'    - substring used to locate a source line,
#   'code_to_insert' - text to insert or substitute,
#   'mode'           - "insert" (default) or "replace",
#   'dedent'         - when true (default) the snippet is dedented before
#                      being re-indented.
# mode: forwarded unchanged to write_to_file_and_return_fn.
#
# NOTE(review): indentation is lost in this view and no `break`/`else` lines
# are visible, so the exact loop nesting (first vs. last matching line,
# leading spaces vs. every space of the line) cannot be confirmed from here.
# Check the full file before editing.
original_source = inspect.getsource(original_func)
lines = original_source.split("\n")
for item in data:
target_line_number = None
for i, line in enumerate(lines):
if item['target_line'] not in line: continue
# Index just after the matched line; used as the insertion point below.
target_line_number = i + 1
if item.get("mode","insert") == "replace":
# Replace mode: substitute the target text within the matched line.
lines[i] = lines[i].replace(item['target_line'], item['code_to_insert'])
# Build an indentation string out of space characters of the matched line.
indentation = ''
for char in line:
if char == ' ':
indentation += char
code_to_insert = item['code_to_insert']
if item.get("dedent",True):
code_to_insert = dedent(item['code_to_insert'])
code_to_insert = indent(code_to_insert, indentation)
# Insert mode: place the prepared snippet right after the matched line.
if item.get("mode","insert") == "insert" and target_line_number is not None:
lines.insert(target_line_number, code_to_insert)
# Reassemble the text and strip the common indentation so it can stand at top level.
modified_source = "\n".join(lines)
modified_source = dedent(modified_source.strip("\n"))
return write_to_file_and_return_fn(original_func, modified_source, mode)
def write_to_file_and_return_fn(original_func, source:str, mode='a'):
# Load a module from the '.patches.py' file located next to this file, fetch
# the attribute named like `original_func` from it, and return a copy of
# `original_func` that runs that attribute's code object.
#
# source: source text for the patched function.
# mode:   open() mode for the patches file ('a' by default).
#
# NOTE(review): in this view nothing writes `source` to the opened file and
# nothing executes the module before the getattr below -- presumably those
# lines are missing here. Also, only `import importlib` is visible at the top
# of the file while `importlib.util` is used; confirm the submodule is
# imported somewhere. Check the full file before editing.
custom_name = ".patches.py"
current_dir = os.path.dirname(os.path.abspath(__file__))
temp_file_path = os.path.join(current_dir, custom_name)
with open(temp_file_path, mode) as temp_file:
MODULE_PATH = temp_file.name
# Module name: first component of this module's name with '-' replaced by '_', plus a fixed suffix.
MODULE_NAME = __name__.split('.')[0].replace('-','_') + "_patch_modules"
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
modified_function = getattr(module, original_func.__name__)
# Helper: build a new function object from `f`, optionally swapping its globals, code object or __module__.
def copy_func(f, globals=None, module=None, code=None, update_wrapper=True):
if globals is None: globals = f.__globals__
if code is None: code = f.__code__
g = types.FunctionType(code, globals, name=f.__name__,
argdefs=f.__defaults__, closure=f.__closure__)
if update_wrapper: g = functools.update_wrapper(g, f)
if module is not None: g.__module__ = module
g.__kwdefaults__ = copy(f.__kwdefaults__)
return g
# Keep original_func's globals, defaults and closure; only the code object comes from the patched module.
return copy_func(original_func, code=modified_function.__code__, update_wrapper=False)