|
store = {}
|
|
|
|
|
|
|
|
import comfy.samplers
|
|
|
|
def patch1(fn_name):
|
|
def calc_cond_batch(*args, **kwargs):
|
|
x_in = kwargs['x_in'] if 'x_in' in kwargs else args[2]
|
|
model_options = kwargs['model_options'] if 'model_options' in kwargs else args[4]
|
|
if not hasattr(x_in, 'model_options'):
|
|
x_in.model_options = model_options
|
|
return store[fn_name](*args, **kwargs)
|
|
return calc_cond_batch
|
|
|
|
def patch2(fn_name):
|
|
def get_area_and_mult(*args, **kwargs):
|
|
x_in = kwargs['x_in'] if 'x_in' in kwargs else args[1]
|
|
conds = kwargs['conds'] if 'conds' in kwargs else args[0]
|
|
if (model_options:=getattr(x_in, 'model_options', None)) is not None and 'tiled_diffusion' in model_options:
|
|
if 'control' in conds:
|
|
control = conds['control']
|
|
if not hasattr(control, 'get_control_orig'):
|
|
control.get_control_orig = control.get_control
|
|
control.get_control = lambda *a, **kw: control
|
|
else:
|
|
if 'control' in conds:
|
|
control = conds['control']
|
|
if hasattr(control, 'get_control_orig') and control.get_control != control.get_control_orig:
|
|
control.get_control = control.get_control_orig
|
|
return store[fn_name](*args, **kwargs)
|
|
return get_area_and_mult
|
|
|
|
patches = [
|
|
(comfy.samplers, 'calc_cond_batch', patch1),
|
|
(comfy.samplers, 'get_area_and_mult', patch2),
|
|
]
|
|
for parent, fn_name, create_patch in patches:
|
|
store[fn_name] = getattr(parent, fn_name)
|
|
setattr(parent, fn_name, create_patch(fn_name))
|
|
|
|
|
|
|
|
|
|
def pre_run_control(model, conds):
|
|
s = model.model_sampling
|
|
for t in range(len(conds)):
|
|
x = conds[t]
|
|
|
|
timestep_start = None
|
|
timestep_end = None
|
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
|
if 'control' in x:
|
|
try: x['control'].cleanup()
|
|
except Exception: ...
|
|
x['control'].pre_run(model, percent_to_timestep_function)
|
|
comfy.samplers.pre_run_control = pre_run_control
|
|
|
|
|
|
|
|
import math
|
|
import torch.nn.functional as F
|
|
import comfy_extras.nodes_sag
|
|
from comfy_extras.nodes_sag import gaussian_blur_2d
|
|
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
|
|
|
_, hw1, hw2 = attn.shape
|
|
b, _, lh, lw = x0.shape
|
|
attn = attn.reshape(b, -1, hw1, hw2)
|
|
|
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
|
def calc_closest_factors(a):
|
|
for b in range(int(math.sqrt(a)), 0, -1):
|
|
if a % b == 0:
|
|
c = a // b
|
|
return (b,c)
|
|
m = calc_closest_factors(hw1)
|
|
mh = max(m) if lh > lw else min(m)
|
|
mw = m[1] if mh == m[0] else m[0]
|
|
mid_shape = mh, mw
|
|
|
|
|
|
mask = (
|
|
mask.reshape(b, *mid_shape)
|
|
.unsqueeze(1)
|
|
.type(attn.dtype)
|
|
)
|
|
|
|
mask = F.interpolate(mask, (lh, lw))
|
|
|
|
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
|
blurred = blurred * mask + x0 * (1 - mask)
|
|
return blurred
|
|
comfy_extras.nodes_sag.create_blur_map = create_blur_map
|
|
|
|
|
|
|
|
def _set_position(self, boxes, masks, positive_embeddings):
|
|
objs = self.position_net(boxes, masks, positive_embeddings)
|
|
def func(x, extra_options):
|
|
key = extra_options["transformer_index"]
|
|
module = self.module_list[key]
|
|
nonlocal objs
|
|
_objs = objs.repeat(-(x.shape[0] // -objs.shape[0]),1,1) if x.shape[0] > objs.shape[0] else objs
|
|
return module(x, _objs.to(device=x.device, dtype=x.dtype))
|
|
return func
|
|
|
|
import comfy.gligen
|
|
comfy.gligen.Gligen._set_position = _set_position
|
|
|