import math

import torch
import torch.nn.functional as F

# Imports assumed from a stock ComfyUI checkout; this module overrides
# functions normally defined in comfy.samplers and comfy_extras.nodes_sag.
from comfy import model_management
from comfy.samplers import can_concat_cond, cond_cat, get_area_and_mult
from comfy_extras.nodes_sag import gaussian_blur_2d


def calc_cond_batch(model, conds, x_in, timestep, model_options):
    if 'tiled_diffusion' not in model_options:
        # Not a tiled run: fall through to the stashed original implementation.
        return calc_cond_batch_original_tiled_diffusion_875b8c8d(model, conds, x_in, timestep, model_options)

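    # One weighted-output buffer and one weight-count buffer per cond list;
    # counts start at a tiny epsilon so the final division never hits zero.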
    out_conds = []
    out_counts = []
    to_run = []

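    # Collect every applicable cond entry as a (precomputed job, cond index) pair.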
    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        out_counts.append(torch.ones_like(x_in) * 1e-37)

        cond = conds[i]
        if cond is not None:
            for x in cond:
                p = get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue

                to_run += [(p, i)]

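    # Batch as many mutually concatenable jobs as free VRAM allows.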
    while len(to_run) > 0:
        first = to_run[0]
        first_shape = first[0][0].shape
        to_batch_temp = []
        for x in range(len(to_run)):
            if can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]

        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]

        free_memory = model_management.get_free_memory(x_in.device)
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break

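        # Pop the chosen jobs (indices are in descending order, so earlier
        # pops do not shift later ones) and gather their tensors.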
        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            control = p.control
            patches = p.patches

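        # Concatenate along the batch dim and repeat the timestep per chunk.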
        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        if control is not None:
            # On the tiled path the raw control object is stored as-is; the
            # tiled wrapper calls get_control itself, per tile.
            c['control'] = control if 'tiled_diffusion' in model_options else control.get_control(input_x, timestep_, c, len(cond_or_uncond))

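        # Build transformer_options from a copy of the shared options and
        # merge in any per-job attention patches.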
        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()

        if patches is not None:
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep

        c['transformer_options'] = transformer_options

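        # Run the model, through model_function_wrapper when installed, then
        # split the batched output back into per-job chunks.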
        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

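        # Scatter each chunk into its cond's accumulators; jobs with an area
        # write only into a narrowed view of the buffers.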
        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                out_c += output[o] * mult[o]
                out_cts += mult[o]

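    # Weighted average: divide accumulated outputs by accumulated weights.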
    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]

    return out_conds


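# A minimal sketch of how this override is typically installed (an assumption
# based on the naming; `calc_cond_batch_original_tiled_diffusion_875b8c8d` is
# the stashed original created at patch time, not defined in this module):
#
#   import comfy.samplers
#   if not hasattr(comfy.samplers, 'calc_cond_batch_original_tiled_diffusion_875b8c8d'):
#       comfy.samplers.calc_cond_batch_original_tiled_diffusion_875b8c8d = comfy.samplers.calc_cond_batch
#   comfy.samplers.calc_cond_batch = calc_cond_batch


# Self-Attention Guidance helper: derives a binary mask from the attention
# map and blends a gaussian-blurred x0 back in where attention is strong.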
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)

    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

    def calc_closest_factors(a):
        # Factor pair of a closest to a square, e.g. 4096 -> (64, 64).
        for b in range(int(math.sqrt(a)), 0, -1):
            if a % b == 0:
                c = a // b
                return (b, c)

    # Orient the factor pair to follow the latent's aspect ratio.
    m = calc_closest_factors(hw1)
    mh = max(m) if lh > lw else min(m)
    mw = m[1] if mh == m[0] else m[0]
    mid_shape = mh, mw

    mask = (
        mask.reshape(b, *mid_shape)
        .unsqueeze(1)
        .type(attn.dtype)
    )

    mask = F.interpolate(mask, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred


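# Patched pre_run_control: identical to upstream except that it also resets a
# Tiled Diffusion model_function_wrapper (located via the call stack) before
# controlnets are prepared.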
def pre_run_control(model, conds):
    s = model.model_sampling

    def find_outer_instance(target: str, target_type):
        # Walk up a few caller frames looking for a local named `target`
        # that is an instance of `target_type`.
        import inspect
        frame = inspect.currentframe()
        i = 0
        while frame and i < 7:
            if (found := frame.f_locals.get(target, None)) is not None:
                if isinstance(found, target_type):
                    return found
            frame = frame.f_back
            i += 1
        return None

    from comfy.model_patcher import ModelPatcher
    if (_model := find_outer_instance('model', ModelPatcher)) is not None:
        if (model_function_wrapper := _model.model_options.get('model_function_wrapper', None)) is not None:
            import sys
            tiled_diffusion = sys.modules.get('ComfyUI-TiledDiffusion.tiled_diffusion', None)
            if tiled_diffusion is None:
                # Module name varies with how the custom node was installed.
                for key in sys.modules:
                    if 'tiled_diffusion' in key:
                        tiled_diffusion = sys.modules[key]
                        break
            if (AbstractDiffusion := getattr(tiled_diffusion, 'AbstractDiffusion', None)) is not None:
                if isinstance(model_function_wrapper, AbstractDiffusion):
                    model_function_wrapper.reset()

    for t in range(len(conds)):
        x = conds[t]

        timestep_start = None
        timestep_end = None
        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
        if 'control' in x:
            try:
                x['control'].cleanup()
            except Exception:
                pass
            x['control'].pre_run(model, percent_to_timestep_function)


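# Patched GLIGEN _set_position: computes the position embeddings once and
# returns a closure that applies the transformer-specific module each call.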
def _set_position(self, boxes, masks, positive_embeddings):
    objs = self.position_net(boxes, masks, positive_embeddings)

    def func(x, extra_options):
        key = extra_options["transformer_index"]
        module = self.module_list[key]
        return module(x, objs.to(device=x.device, dtype=x.dtype))
    return func