import torch

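# Flux has 19 double-stream and 38 single-stream transformer blocks; by
# default the regional mask is applied to every other block in each stack.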
DEFAULT_REGIONAL_ATTN = {
    'double': [i for i in range(1, 19, 2)],
    'single': [i for i in range(1, 38, 2)]
}


class RegionalMask(torch.nn.Module):
    """Returns the regional attention mask only while sampling progress is
    inside [start_percent, end_percent); returns None otherwise."""

    def __init__(self, mask: torch.Tensor, start_percent: float, end_percent: float) -> None:
        super().__init__()
        self.register_buffer('mask', mask)
        self.start_percent = start_percent
        self.end_percent = end_percent

    def __call__(self, q, transformer_options, *args, **kwargs):
        # Flux sigmas run from 1 to 0, so 1 - sigma is the fraction of
        # sampling that has completed.
        if self.start_percent <= 1 - transformer_options['sigmas'][0] < self.end_percent:
            return self.mask
        return None


class RegionalConditioning(torch.nn.Module):
    """Returns the concatenated regional text conditioning only while sampling
    progress is inside [start_percent, end_percent); returns None otherwise."""

    def __init__(self, region_cond: torch.Tensor, start_percent: float, end_percent: float) -> None:
        super().__init__()
        self.register_buffer('region_cond', region_cond)
        self.start_percent = start_percent
        self.end_percent = end_percent

    def __call__(self, transformer_options, *args, **kwargs):
        if self.start_percent <= 1 - transformer_options['sigmas'][0] < self.end_percent:
            return self.region_cond
        return None


class CreateRegionalCondNode:
    """Appends one (mask, cond) region to a REGION_COND list that
    ApplyRegionalCondsNode consumes."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "cond": ("CONDITIONING",),
            "mask": ("MASK",),
        }, "optional": {
            "prev_regions": ("REGION_COND",),
        }}

    RETURN_TYPES = ("REGION_COND",)
    FUNCTION = "create"

    CATEGORY = "fluxtapoz"

    def create(self, cond, mask, prev_regions=None):
        # Copy the incoming list so upstream nodes are not mutated.
        prev_regions = [*prev_regions] if prev_regions is not None else []
        prev_regions.append({
            'mask': mask,
            'cond': cond[0][0]  # raw text-embedding tensor from the CONDITIONING tuple
        })
        return (prev_regions,)


class ApplyRegionalCondsNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "region_conds": ("REGION_COND",),
            "latent": ("LATENT",),
            "start_percent": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}),
            "end_percent": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}),
        }, "optional": {
            "attn_override": ("ATTN_OVERRIDE",)
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "fluxtapoz"

    def patch(self, model, region_conds, latent, start_percent, end_percent, attn_override=DEFAULT_REGIONAL_ATTN):
        model = model.clone()

        # Flux packs 2x2 latent patches into single tokens, so the token grid
        # is half the latent resolution in each dimension.
        latent = latent['samples']
        b, c, h, w = latent.shape
        h //= 2
        w //= 2

        img_len = h * w

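        # Flux attends over the joint sequence [text tokens | image tokens].
        # The 256 below assumes the base prompt's T5 conditioning is padded to
        # 256 tokens; the regional prompts are concatenated after it.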
        regional_conditioning = torch.cat([region_cond['cond'] for region_cond in region_conds], dim=1)
        text_len = 256 + regional_conditioning.shape[1]

        # Boolean attention mask over the joint sequence; True lets a
        # query/key pair attend.
        regional_mask = torch.zeros((text_len + img_len, text_len + img_len), dtype=torch.bool)

        self_attend_masks = torch.zeros((img_len, img_len), dtype=torch.bool)
        union_masks = torch.zeros((img_len, img_len), dtype=torch.bool)

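        # Prepend a full-frame "background" entry standing in for the base
        # prompt: an all-ones mask plus a dummy cond of 256 tokens x 4096
        # channels (the T5-XXL hidden size). Only the dummy's sequence length
        # is used below; its values never reach the model.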
        region_conds = [
            {
                'mask': torch.ones((1, h, w), dtype=torch.float16),
                'cond': torch.ones((1, 256, 4096), dtype=torch.float16)
            },
            *region_conds
        ]

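        # Walk the text sequence region by region, carving out the blocks of
        # the joint attention mask that each region's tokens may use.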
        current_seq_len = 0
        for region_cond_dict in region_conds:
            region_cond = region_cond_dict['cond']
            # Invert the mask and resize it to the token grid, then expand to
            # one column per text token of this region.
            region_mask = 1 - region_cond_dict['mask'][0]
            region_mask = torch.nn.functional.interpolate(region_mask[None, None, :, :], (h, w), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, region_cond.size(1))
            next_seq_len = current_seq_len + region_cond.shape[1]

            # This region's text tokens attend to each other.
            regional_mask[current_seq_len:next_seq_len, current_seq_len:next_seq_len] = True

            # Text-to-image attention for this region.
            regional_mask[current_seq_len:next_seq_len, text_len:] = region_mask.transpose(-1, -2)

            # Image-to-text attention for this region.
            regional_mask[text_len:, current_seq_len:next_seq_len] = region_mask

            # Accumulate image-to-image pairs: self_attend_masks collects pairs
            # falling in the same region; union_masks collects every pair
            # covered by any region at all.
            img_size_masks = region_mask[:, :1].repeat(1, img_len)
            img_size_masks_transpose = img_size_masks.transpose(-1, -2)
            self_attend_masks = torch.logical_or(self_attend_masks,
                                                 torch.logical_and(img_size_masks, img_size_masks_transpose))
            union_masks = torch.logical_or(union_masks,
                                           torch.logical_or(img_size_masks, img_size_masks_transpose))

            current_seq_len = next_seq_len

        # Image-token pairs untouched by any region (background) keep full
        # self-attention; all other pairs attend only within their own region.
        background_masks = torch.logical_not(union_masks)
        background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks)
        regional_mask[text_len:, text_len:] = background_and_self_attend_masks

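        # Wrap the mask and conditioning so they are active only inside the
        # configured denoising window, then attach them to the model.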
        regional_mask = RegionalMask(regional_mask, start_percent, end_percent)
        regional_conditioning = RegionalConditioning(regional_conditioning, start_percent, end_percent)

        model.set_model_patch(regional_conditioning, 'regional_conditioning')

        # Install the mask on the selected double- and single-stream blocks.
        for block_idx in attn_override['double']:
            model.set_model_patch_replace(regional_mask, "double", "mask_fn", int(block_idx))

        for block_idx in attn_override['single']:
            model.set_model_patch_replace(regional_mask, "single", "mask_fn", int(block_idx))

        return (model,)
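

# Registration sketch (an assumption, not part of the original file): ComfyUI
# exposes custom nodes through NODE_CLASS_MAPPINGS, but the display names used
# here are illustrative only.
NODE_CLASS_MAPPINGS = {
    "CreateRegionalCond": CreateRegionalCondNode,
    "ApplyRegionalConds": ApplyRegionalCondsNode,
}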