from typing import Callable, Union
import comfy.hooks
import comfy.model_patcher
import comfy.patcher_extension
import comfy.sample
import comfy.samplers
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
from comfy.ldm.modules.attention import BasicTransformerBlock
from .control import convert_all_to_advanced, restore_all_controlnet_conns
from .control_reference import (ReferenceAdvanced, ReferenceInjections,
                                RefBasicTransformerBlock, RefTimestepEmbedSequential,
                                InjectionBasicTransformerBlockHolder, InjectionTimestepEmbedSequentialHolder,
                                _forward_inject_BasicTransformerBlock,
                                handle_context_ref_setup, handle_reference_injection,
                                REF_CONTROL_LIST_ALL, CONTEXTREF_CLEAN_FUNC)
from .dinklink import get_dinklink
from .utils import torch_dfs, WrapperConsts
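
# Advertise the outer-sample wrapper through DinkLink, the shared registry used to
# exchange callables between related nodepacks (e.g. AnimateDiff-Evolved), so they
# can look up and register Advanced-ControlNet's sampler wrapper on their own.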
def prepare_dinklink_acn_wrapper():
    # expose acn_sampler_sample_wrapper
    d = get_dinklink()
    link_acn = d.setdefault(WrapperConsts.ACN, {})
    link_acn[WrapperConsts.VERSION] = 10000
    link_acn[WrapperConsts.ACN_CREATE_SAMPLER_SAMPLE_WRAPPER] = (comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
                                                                 WrapperConsts.ACN_OUTER_SAMPLE_WRAPPER_KEY,
                                                                 acn_outer_sample_wrapper)
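
# Convert all ControlNets referenced by the conds to their Advanced equivalents so
# they behave correctly with sliding context windows; reports whether anything was
# actually modified so the change can be undone after sampling.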
def support_sliding_context_windows(conds) -> tuple[bool, list[dict]]:
    # convert to advanced, with report if anything was actually modified
    modified, new_conds = convert_all_to_advanced(conds)
    return modified, new_conds
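
# True if AnimateDiff-Evolved context options with a set context_length are attached
# to the model, meaning sampling will run over sliding context windows.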
def has_sliding_context_windows(model: ModelPatcher):
    params = model.get_attachment("ADE_params")
    if params is None:
        # backwards compatibility
        params = getattr(model, "motion_injection_params", None)
        if params is None:
            return False
    context_options = getattr(params, "context_options")
    return context_options.context_length is not None
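
# Fetch the ContextRef extra from the model's context options, or None if context
# options (or the extra itself) are absent.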
def get_contextref_obj(model: ModelPatcher):
    params = model.get_attachment("ADE_params")
    if params is None:
        # backwards compatibility
        params = getattr(model, "motion_injection_params", None)
        if params is None:
            return None
    context_options = getattr(params, "context_options")
    extras = getattr(context_options, "extras", None)
    if extras is None:
        return None
    return getattr(extras, "context_ref", None)
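
# Recursively collect ReferenceAdvanced controlnets from a ControlNet chain (walking
# previous_controlnet links), assigning each a decreasing negative 'order' so the
# original chain order can be recovered later by sorting.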
def get_refcn(control: ControlBase, order: int=-1):
    ref_set: set[ReferenceAdvanced] = set()
    if control is None:
        return ref_set
    if type(control) == ReferenceAdvanced and not control.is_context_ref:
        control.order = order
        order -= 1
        ref_set.add(control)
    ref_set.update(get_refcn(control.previous_controlnet, order=order))
    return ref_set
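
# custom_should_register callback for the WrapperHook below; skips registration when
# the outer-sample wrapper is already present in model_options, preventing
# double-wrapping of the sampling call.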
def should_register_outer_sample_wrapper(hook, model, model_options: dict, target, registered: list):
    wrappers = comfy.patcher_extension.get_wrappers_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
                                                             WrapperConsts.ACN_OUTER_SAMPLE_WRAPPER_KEY,
                                                             model_options, is_model_options=True)
    return len(wrappers) == 0
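
# Build a HookGroup containing a single WrapperHook that registers
# acn_outer_sample_wrapper at sampling time. A caller would typically attach the
# returned group to conds (e.g. via comfy.hooks.set_hooks_for_conditioning --
# assumed usage, not shown in this module).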
def create_wrapper_hooks():
    wrappers = {}
    comfy.patcher_extension.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
                                                 WrapperConsts.ACN_OUTER_SAMPLE_WRAPPER_KEY,
                                                 acn_outer_sample_wrapper,
                                                 transformer_options=wrappers)
    hooks = comfy.hooks.HookGroup()
    hook = comfy.hooks.WrapperHook(wrappers)
    hook.hook_id = WrapperConsts.ACN_OUTER_SAMPLE_WRAPPER_KEY
    hook.custom_should_register = should_register_outer_sample_wrapper
    hooks.add(hook)
    return hooks
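
# Wrapper around CFGGuider's OUTER_SAMPLE stage: detects sliding context windows and
# Reference/ContextRef controlnets, injects the required attn/gn module patches
# before deferring to the wrapped sampling call, and restores everything afterwards
# via the nested try/finally blocks.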
def acn_outer_sample_wrapper(executor: comfy.patcher_extension.WrapperExecutor, *args, **kwargs):
    controlnets_modified = False
    guider: comfy.samplers.CFGGuider = executor.class_obj
    model = guider.model_patcher
    orig_conds = guider.conds
    orig_model_options = guider.model_options
    try:
        new_model_options = orig_model_options
        # if context options present, perform some special actions that may be required
        context_refs = []
        if has_sliding_context_windows(guider.model_patcher):
            new_model_options = comfy.model_patcher.create_model_options_clone(new_model_options)
            # convert all CNs to Advanced if needed
            controlnets_modified, conds = support_sliding_context_windows(orig_conds)
            if controlnets_modified:
                guider.conds = conds
            # enable ContextRef, if requested
            existing_contextref_obj = get_contextref_obj(guider.model_patcher)
            if existing_contextref_obj is not None:
                context_refs = handle_context_ref_setup(existing_contextref_obj, new_model_options["transformer_options"], guider.conds)
                controlnets_modified = True
        # look for Advanced ControlNets that will require intervention to work
        ref_set = set()
        for outer_cond in guider.conds.values():
            for cond in outer_cond:
                if "control" in cond:
                    ref_set.update(get_refcn(cond["control"]))
        # if no ref cn found, do original function immediately
        if len(ref_set) == 0 and len(context_refs) == 0:
            return executor(*args, **kwargs)
        # otherwise, injection time
        try:
            # inject
            # storage for all Reference-related injections
            reference_injections = ReferenceInjections()
            # first, handle attn module injection
            all_modules = torch_dfs(model.model)
            attn_modules: list[RefBasicTransformerBlock] = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            for i, module in enumerate(attn_modules):
                injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
                injection_holder.attn_weight = float(i) / float(len(attn_modules))
                if hasattr(module, "_forward"):  # backward compatibility
                    module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
                else:
                    module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
                module.injection_holder = injection_holder
                reference_injections.attn_modules.append(module)
            # figure out which module is middle block
            if hasattr(model.model.diffusion_model, "middle_block"):
                mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
                mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
                for module in mid_attn_modules:
                    module.injection_holder.is_middle = True
            # next, handle gn module injection (TimestepEmbedSequential)
            # TODO: figure out the logic behind these hardcoded indexes; they are
            # likely carried over from the original sd-webui reference implementation
            if type(model.model).__name__ == "SDXL":
                input_block_indices = [4, 5, 7, 8]
                output_block_indices = [0, 1, 2, 3, 4, 5]
            else:
                input_block_indices = [4, 5, 7, 8, 10, 11]
                output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
            if hasattr(model.model.diffusion_model, "middle_block"):
                module = model.model.diffusion_model.middle_block
                injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
                injection_holder.gn_weight = 0.0
                module.injection_holder = injection_holder
                reference_injections.gn_modules.append(module)
            for w, i in enumerate(input_block_indices):
                module = model.model.diffusion_model.input_blocks[i]
                injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
                injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
                module.injection_holder = injection_holder
                reference_injections.gn_modules.append(module)
            for w, i in enumerate(output_block_indices):
                module = model.model.diffusion_model.output_blocks[i]
                injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
                injection_holder.gn_weight = float(w) / float(len(output_block_indices))
                module.injection_holder = injection_holder
                reference_injections.gn_modules.append(module)
            # hack gn_module forwards and update weights
            for module in reference_injections.gn_modules:
                module.injection_holder.gn_weight *= 2
            # store ordered ref cns in model's transformer options
            new_model_options = comfy.model_patcher.create_model_options_clone(new_model_options)
            # handle diffusion_model forward injection
            handle_reference_injection(new_model_options, reference_injections)
            ref_list: list[ReferenceAdvanced] = list(ref_set)
            new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
            new_model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC] = reference_injections.clean_contextref_module_mem
            guider.model_options = new_model_options
            # continue with original function
            return executor(*args, **kwargs)
        finally:
            # cleanup injections
            # restore attn modules
            attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
            for module in attn_modules:
                module.injection_holder.restore(module)
                module.injection_holder.clean_all()
                del module.injection_holder
            del attn_modules
            # restore gn modules
            gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
            for module in gn_modules:
                module.injection_holder.restore(module)
                module.injection_holder.clean_all()
                del module.injection_holder
            del gn_modules
            # cleanup
            reference_injections.cleanup()
    finally:
        # restore model_options
        guider.model_options = orig_model_options
        # restore guider.conds
        guider.conds = orig_conds
        # restore controlnets in conds, if needed
        if controlnets_modified:
            restore_all_controlnet_conns(guider.conds)
        del orig_conds
        del orig_model_options
        del model
        del guider