def calc_cond_batch(model, conds, x_in, timestep, model_options):
    """Evaluate the model for every entry of ``conds``, batching compatible entries.

    When ``model_options`` has no ``'tiled_diffusion'`` key, the call is forwarded
    unchanged to ``calc_cond_batch_original_tiled_diffusion_875b8c8d`` (presumably
    the un-patched implementation -- TODO confirm).

    Args:
        model: object exposing ``memory_required(shape)`` and
            ``apply_model(x, t, **c)``.
        conds: list whose items are ``None`` or an iterable of cond entries
            accepted by ``get_area_and_mult``.
        x_in: latent tensor; each accumulator starts as ``torch.zeros_like(x_in)``.
        timestep: timestep/sigma tensor, concatenated once per batched chunk.
        model_options: sampler options dict (``'transformer_options'`` and
            ``'model_function_wrapper'`` keys are honoured).

    Returns:
        A list with one tensor per item of ``conds``. Each tensor is the
        ``mult``-weighted sum of the model outputs for that item, divided by the
        accumulated weights.
    """
    if 'tiled_diffusion' not in model_options:
        return calc_cond_batch_original_tiled_diffusion_875b8c8d(model, conds, x_in, timestep, model_options)

    out_conds = []
    out_counts = []
    to_run = []  # pending (area/mult record, index into conds) pairs

    for cond_idx, cond in enumerate(conds):
        out_conds.append(torch.zeros_like(x_in))
        # Tiny non-zero start so the final division can never divide by zero.
        out_counts.append(torch.ones_like(x_in) * 1e-37)
        if cond is None:
            continue
        for x in cond:
            p = get_area_and_mult(x, x_in, timestep)
            if p is not None:
                to_run.append((p, cond_idx))

    while to_run:
        first = to_run[0]
        first_shape = first[0][0].shape
        # Indices of every pending entry that can be concatenated with the first one.
        to_batch_temp = [idx for idx, item in enumerate(to_run) if can_concat_cond(item[0], first[0])]
        # Descending order: popping by index below then never shifts the indices still to pop.
        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]

        # Try len//1, len//2, len//3, ... entries and keep the first batch whose
        # estimated memory (with 1.5x headroom) fits; otherwise run a single entry.
        free_memory = model_management.get_free_memory(x_in.device)
        for divisor in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp) // divisor]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break

        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for idx in to_batch:
            p, cond_idx = to_run.pop(idx)
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(cond_idx)
            # The last popped entry wins; presumably can_concat_cond only groups
            # entries sharing control/patches -- TODO confirm.
            control = p.control
            patches = p.patches

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        if control is not None:
            # The early return above guarantees 'tiled_diffusion' is in model_options,
            # so the former `control.get_control(...)` fallback was unreachable. The raw
            # control object is always forwarded (presumably the tiled-diffusion wrapper
            # evaluates it itself -- TODO confirm).
            c['control'] = control

        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()

        if patches is not None:
            if "patches" in transformer_options:
                # Merge per-name patch lists without mutating the caller's dict.
                cur_patches = transformer_options["patches"].copy()
                for patch_name in patches:
                    if patch_name in cur_patches:
                        cur_patches[patch_name] = cur_patches[patch_name] + patches[patch_name]
                    else:
                        cur_patches[patch_name] = patches[patch_name]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep
        c['transformer_options'] = transformer_options

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](
                model.apply_model,
                {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond},
            ).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

        for chunk_idx in range(batch_chunks):
            cond_index = cond_or_uncond[chunk_idx]
            a = area[chunk_idx]
            if a is None:
                out_conds[cond_index] += output[chunk_idx] * mult[chunk_idx]
                out_counts[cond_index] += mult[chunk_idx]
            else:
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                # `a` holds lengths in its first half and start offsets in its second
                # half, one pair per spatial dim (dims 2, 3, ...). narrow() returns views,
                # so the in-place += writes into that region of the full accumulators.
                dims = len(a) // 2
                for dim in range(dims):
                    out_c = out_c.narrow(dim + 2, a[dim + dims], a[dim])
                    out_cts = out_cts.narrow(dim + 2, a[dim + dims], a[dim])
                out_c += output[chunk_idx] * mult[chunk_idx]
                out_cts += mult[chunk_idx]

    for out_c, count in zip(out_conds, out_counts):
        out_c /= count  # in place: the list keeps the same tensor objects
    return out_conds


def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blur ``x0`` wherever the thresholded attention mask is set.

    Args:
        x0: 4-D tensor unpacked as ``(b, _, lh, lw)``.
        attn: 3-D tensor unpacked as ``(_, hw1, hw2)``; reshaped to ``(b, -1, hw1, hw2)``.
        sigma: forwarded to ``gaussian_blur_2d`` (kernel size 9).
        threshold: positions whose summed attention exceeds this are blurred.

    Returns:
        Tensor like ``x0``: the blurred values where the mask is 1 and the
        original ``x0`` where it is 0.
    """
    # reshape and GAP the attention map
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool over dim 1, then sum over the hw1 axis -> (b, hw2) bool mask.
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

    def calc_closest_factors(a):
        # Factor pair (small, large) of `a`, where small is its largest divisor <= sqrt(a).
        for small in range(math.isqrt(a), 0, -1):
            if a % small == 0:
                return (small, a // small)

    m = calc_closest_factors(hw1)
    # Orient the factor grid to match the latent's aspect ratio.
    mh = max(m) if lh > lw else min(m)
    mw = m[1] if mh == m[0] else m[0]
    mid_shape = mh, mw

    # Reshape (assumes hw2 == hw1, i.e. self-attention -- TODO confirm).
    mask = (
        mask.reshape(b, *mid_shape)
        .unsqueeze(1)
        .type(attn.dtype)
    )
    # Upsample
    mask = F.interpolate(mask, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred


def pre_run_control(model, conds):
    """Run ``pre_run`` on every control object in ``conds``, resetting tiled diffusion first.

    Before preparing controls, the function walks up the call stack looking for a
    local named ``model`` that is a ``comfy.model_patcher.ModelPatcher``. The walk
    is presumably needed because that patcher is not passed in here -- TODO
    confirm. If the patcher's ``model_function_wrapper`` is an instance of
    ``AbstractDiffusion`` from the loaded tiled_diffusion module, its ``reset()``
    is called.

    Args:
        model: object exposing ``model_sampling.percent_to_sigma``; forwarded to
            each ``pre_run`` call.
        conds: list of cond dicts. Entries with a ``'control'`` key get
            ``cleanup()`` (errors ignored) followed by
            ``pre_run(model, percent_to_timestep_function)``.
    """
    s = model.model_sampling

    def find_outer_instance(target: str, target_type):
        # Inspect up to 7 frames, starting with this one, for a local named `target`
        # that is an instance of `target_type`. Returns it, or None if not found.
        import inspect
        frame = inspect.currentframe()
        try:
            for _ in range(7):
                if frame is None:
                    break
                found = frame.f_locals.get(target, None)
                if found is not None and isinstance(found, target_type):
                    return found
                frame = frame.f_back
            return None
        finally:
            # Drop the frame reference promptly, as the inspect docs recommend.
            del frame

    from comfy.model_patcher import ModelPatcher
    if (_model := find_outer_instance('model', ModelPatcher)) is not None:
        if (model_function_wrapper := _model.model_options.get('model_function_wrapper', None)) is not None:
            import sys
            tiled_diffusion = sys.modules.get('ComfyUI-TiledDiffusion.tiled_diffusion', None)
            if tiled_diffusion is None:
                # Fall back to the first loaded module whose name mentions tiled_diffusion.
                # Iterate over a snapshot so a concurrent import cannot resize the dict.
                tiled_diffusion = next(
                    (mod for key, mod in list(sys.modules.items()) if 'tiled_diffusion' in key),
                    None,
                )
            if (AbstractDiffusion := getattr(tiled_diffusion, 'AbstractDiffusion', None)) is not None:
                if isinstance(model_function_wrapper, AbstractDiffusion):
                    model_function_wrapper.reset()

    # Identical for every cond, so define it once.
    def percent_to_timestep_function(a):
        return s.percent_to_sigma(a)

    for x in conds:
        if 'control' in x:
            try:
                x['control'].cleanup()
            except Exception:
                # Deliberately ignored so that pre_run() below still runs.
                pass
            x['control'].pre_run(model, percent_to_timestep_function)


def _set_position(self, boxes, masks, positive_embeddings):
    """Build a per-transformer-block patch that applies ``self.module_list`` modules.

    Args:
        self: object exposing ``position_net(boxes, masks, positive_embeddings)``
            and an indexable ``module_list``.
        boxes, masks, positive_embeddings: forwarded to ``position_net`` once, up front.

    Returns:
        ``func(x, extra_options)``, which selects
        ``module_list[extra_options["transformer_index"]]`` and applies it to
        ``x`` together with the precomputed position embeddings.
    """
    objs = self.position_net(boxes, masks, positive_embeddings)

    def func(x, extra_options):
        key = extra_options["transformer_index"]
        module = self.module_list[key]
        # Move the precomputed embeddings to x's device and dtype on every call.
        # x may be on another device or in another precision than objs
        # (presumably due to tiling/offloading -- TODO confirm).
        return module(x, objs.to(device=x.device, dtype=x.dtype))
    return func