|
import torch |
|
|
|
import comfy.model_sampling |
|
|
|
from ..utils.mask_utils import consolidate_masks |
|
|
|
|
|
# Default transformer block indices that receive the regional attention mask:
# all double-stream blocks 1..99, and every other single-stream block 1..99.
DEFAULT_REGIONAL_ATTN = {
    'double': list(range(1, 100)),
    'single': list(range(1, 100, 2)),
}
|
|
|
|
|
class RegionalMask(torch.nn.Module):
    """Wraps a precomputed attention mask that is only active for a window
    of the sampling schedule.

    The mask is stored as a module buffer so it follows the model across
    devices/dtypes when the patcher moves it.
    """

    def __init__(self, mask: torch.Tensor, start_percent: float, end_percent: float) -> None:
        super().__init__()
        self.register_buffer('mask', mask)
        self.start_percent = start_percent
        self.end_percent = end_percent

    def __call__(self, q, transformer_options, *args, **kwargs):
        # Sampling progress goes 0 -> 1 as sigmas decay toward 0.
        progress = 1 - transformer_options['sigmas'][0]
        if not (self.start_percent <= progress < self.end_percent):
            # Outside the active window: signal "no mask override".
            return None
        return self.mask
|
|
|
|
|
class RegionalConditioning(torch.nn.Module):
    """Appends the concatenated regional text embeddings to the main context.

    Active only inside [start_percent, end_percent) of the sampling schedule,
    unless ``always_included`` forces it on for every step. The main context
    is scaled by ``main_cond_strength`` whenever the regions are appended.
    """

    def __init__(self, region_cond: torch.Tensor, start_percent: float, end_percent: float, main_cond_strength: float, always_included: bool) -> None:
        super().__init__()
        # Buffer registration keeps the embeddings on the model's device/dtype.
        self.register_buffer('region_cond', region_cond)
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.main_cond_strength = main_cond_strength
        self.always_included = always_included

    def __call__(self, context, transformer_options, *args, **kwargs):
        # Sampling progress goes 0 -> 1 as sigmas decay toward 0.
        progress = 1 - transformer_options['sigmas'][0]
        in_window = self.start_percent <= progress < self.end_percent
        if not (in_window or self.always_included):
            return context
        scaled_main = context * self.main_cond_strength
        regions = self.region_cond.to(context.dtype)
        return torch.cat([scaled_main, regions], dim=1)
|
|
|
|
|
class HYCreateRegionalCondNode:
    """ComfyUI node that appends one masked regional prompt to a REGION_COND list.

    Each region entry carries the (pre-scaled) text embedding, its spatial
    mask, and the strategy used later to consolidate the mask temporally.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "cond": ("CONDITIONING",),
            "mask": ("MASK",),
            "cond_strength": ("FLOAT", {"default": 1, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}),
            "mask_consolidation": (["first_only", "select_first", "select_last", "union"], { "tooltip": "How to choose the masking when they are compressed for the latents temporally."}),
        }, "optional": {
            "prev_regions": ("REGION_COND",),
        }}

    RETURN_TYPES = ("REGION_COND",)
    FUNCTION = "create"

    CATEGORY = "hunyuanloom"

    def create(self, cond, mask, cond_strength, mask_consolidation, prev_regions=None):
        """Return a one-tuple holding prev_regions plus this new region.

        The incoming list is copied, never mutated, so upstream nodes that
        cached their output are unaffected.
        """
        # Fix: use a None sentinel instead of a mutable [] default argument
        # (shared-default pitfall). Passing a list explicitly still works.
        regions = list(prev_regions) if prev_regions is not None else []
        regions.append({
            'mask': mask,
            # cond is ComfyUI conditioning: [[embedding, extras], ...];
            # pre-scale the embedding by the region's strength here.
            'cond': cond[0][0] * cond_strength,
            'mask_consolidation': mask_consolidation,
        })

        return (regions,)
|
|
|
|
|
class HYApplyRegionalCondsNode:
    # ComfyUI node: clones the model and patches it so each regional prompt's
    # text tokens and its masked video area are linked in attention, while the
    # regional embeddings are appended to the main conditioning.

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "cond": ("CONDITIONING",),
            "region_conds": ("REGION_COND",),
            "latent": ("LATENT",),
            "start_percent": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}),
            "end_percent": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}),
            "main_cond_strength": ("FLOAT", {"default": 1, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}),
            "always_included": ("BOOLEAN", { "default": False, "tooltip": "Whether to keep the prompts always but allow them to affect the entire video (unmasked) outside of end/start percents." }),
        }, "optional": {
            "attn_override": ("ATTN_OVERRIDE",)
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "hunyuanloom"

    def patch(self, model, cond, region_conds, latent, start_percent, end_percent, main_cond_strength, always_included, attn_override=DEFAULT_REGIONAL_ATTN):
        """Build the block-structured regional attention mask and install it,
        plus the regional-conditioning patch, on a clone of the model.

        NOTE(review): attn_override defaults to a shared module-level dict; it
        is only iterated here, never mutated, so the shared default is safe.
        """
        model = model.clone()

        latent = latent['samples']
        b, c, f, h, w = latent.shape
        # Halve spatial dims — presumably matching the model's 2x2 spatial
        # patchify, so f*h*w below equals the image token count. TODO confirm.
        h //=2
        w //=2

        # Number of video (image) tokens in the attention sequence.
        img_len = h*w*f

        # Main prompt embedding: ComfyUI conditioning is [[embedding, extras], ...].
        main_cond = cond[0][0]
        main_cond_size = main_cond.shape[1]

        # All regional embeddings concatenated along the token axis; total text
        # length is main prompt tokens + all region tokens.
        regional_conditioning = torch.cat([region_cond['cond'] for region_cond in region_conds], dim=1)
        text_len = main_cond_size + regional_conditioning.shape[1]

        # Full attention mask over the combined [text | image] sequence.
        # NOTE(review): True entries appear to mark allowed attention pairs
        # (text self-attention blocks are set True below) — confirm against
        # the attention implementation that consumes "mask_fn".
        regional_mask = torch.zeros((text_len + img_len, text_len + img_len), dtype=torch.bool)

        # Accumulators over image-image pairs, built up region by region.
        self_attend_masks = torch.zeros((img_len, img_len), dtype=torch.bool)
        union_masks = torch.zeros((img_len, img_len), dtype=torch.bool)

        # Pseudo-region for the main prompt. With nonzero strength its mask is
        # all zeros; with strength 0 it is all ones (inverted below), which
        # changes how the main prompt couples to the image tokens.
        main_mask = torch.zeros((f, h, w), dtype=torch.float16)
        if main_cond_strength == 0:
            main_mask = torch.ones((f, h, w), dtype=torch.float16)

        # Prepend the main-prompt pseudo-region; its placeholder cond only
        # supplies the token count (main_cond_size) for the loop below.
        # NOTE(review): 4096 presumably matches the text embedding dim — confirm.
        region_conds = [
            {
                'mask': main_mask,
                'cond': torch.ones((1, main_cond_size, 4096), dtype=torch.float16),
                'mask_consolidation': 'first_only',
            },
            *region_conds
        ]

        current_seq_len = 0
        for region_cond_dict in region_conds:
            region_cond = region_cond_dict['cond']
            region_mask = region_cond_dict['mask']
            # Collapse the temporal mask stack down to f latent frames.
            region_mask = consolidate_masks(region_mask, f, region_cond_dict['mask_consolidation'])
            # Invert: after this, 1 marks pixels OUTSIDE the drawn mask.
            region_mask = 1 - region_mask
            # Downscale to latent-token resolution, flatten to img_len, and
            # broadcast across this region's text-token columns.
            region_mask = torch.nn.functional.interpolate(region_mask[None, :, :, :], (h, w), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, region_cond.size(1))
            next_seq_len = current_seq_len + region_cond.shape[1]

            # Text tokens of this region attend among themselves.
            regional_mask[current_seq_len:next_seq_len, current_seq_len:next_seq_len] = True

            # Region text rows <-> image columns, gated by the (inverted) mask.
            regional_mask[current_seq_len:next_seq_len, text_len:] = region_mask.transpose(-1, -2)

            # Image rows <-> region text columns (symmetric counterpart).
            regional_mask[text_len:, current_seq_len:next_seq_len] = region_mask

            # Per-pixel mask value replicated across all image tokens, used to
            # build the image-image coupling terms below.
            img_size_masks = region_mask[:, :1].repeat(1, img_len)
            img_size_masks_transpose = img_size_masks.transpose(-1, -2)
            # Pairs where BOTH tokens carry this region's mask value.
            self_attend_masks = torch.logical_or(self_attend_masks,
                                                 torch.logical_and(img_size_masks, img_size_masks_transpose))

            # Pairs where EITHER token carries this region's mask value.
            union_masks = torch.logical_or(union_masks,
                                           torch.logical_or(img_size_masks, img_size_masks_transpose))

            current_seq_len = next_seq_len

        # Image-image attention: allowed for background pairs (touched by no
        # region) plus pairs that fall inside the same region.
        background_masks = torch.logical_not(union_masks)
        background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks)
        regional_mask[text_len:, text_len:] = background_and_self_attend_masks

        # Split the mask into its four quadrants so the token order can be
        # swapped for the double-stream blocks.
        text_rows, img_rows = regional_mask.split([text_len, img_len], dim=0)

        text_rows_text_cols, text_rows_img_cols = text_rows.split([text_len, img_len], dim=1)
        img_rows_text_cols, img_rows_img_cols = img_rows.split([text_len, img_len], dim=1)

        # Double-stream layout: image tokens first, then text — presumably the
        # ordering HunyuanVideo's double blocks use internally. TODO confirm.
        double_regional_mask = torch.cat([
            torch.cat([img_rows_img_cols, img_rows_text_cols], dim=1),
            torch.cat([text_rows_img_cols, text_rows_text_cols], dim=1)
        ], dim=0)

        # Wrap masks so they are only returned inside the percent window, and
        # wrap the regional embeddings for injection into the context.
        double_regional_mask = RegionalMask(double_regional_mask, start_percent, end_percent)
        single_regional_mask = RegionalMask(regional_mask, start_percent, end_percent)
        regional_conditioning = RegionalConditioning(regional_conditioning, start_percent, end_percent, main_cond_strength, always_included)

        model.set_model_patch(regional_conditioning, 'regional_conditioning')

        # Install the mask provider on the selected double-stream blocks.
        for block_idx in attn_override['double']:
            model.set_model_patch_replace(double_regional_mask, f"double", "mask_fn", int(block_idx))

        # And on the selected single-stream blocks (text-first layout).
        for block_idx in attn_override['single']:
            model.set_model_patch_replace(single_regional_mask, f"single", "mask_fn", int(block_idx))

        return (model,)
|
|