|
from ..utils.attn_bank import AttentionBank |
|
|
|
|
|
class LTXAttentionBankNode:
    """Node that creates an ``AttentionBank`` for a user-chosen set of blocks.

    The bank is constructed from ``save_steps`` and a ``block_map`` that maps
    every requested block index to its own, initially empty, dict.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "save_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "blocks": ("STRING", {"multiline": True}),
        }}

    RETURN_TYPES = ("ATTN_BANK",)
    FUNCTION = "build"

    CATEGORY = "ltxtricks"

    @staticmethod
    def _parse_block_indices(blocks):
        """Parse integer block indices separated by commas and/or whitespace.

        Empty tokens (an empty string, a trailing comma, doubled commas) are
        skipped, so ``""`` yields ``[]``. Order and duplicates are preserved.

        Raises:
            ValueError: if a non-empty token is not an integer.
        """
        indices = []
        # Treat commas as whitespace so "1,2", "1, 2," and one-per-line
        # input from the multiline widget are all accepted.
        for token in blocks.replace(',', ' ').split():
            try:
                indices.append(int(token))
            except ValueError as err:
                raise ValueError(
                    f"Invalid block index {token!r}: expected integers "
                    "separated by commas or whitespace."
                ) from err
        return indices

    def build(self, save_steps, blocks=''):
        """Return a 1-tuple containing a new ``AttentionBank``.

        Args:
            save_steps: forwarded unchanged to ``AttentionBank``.
            blocks: block indices separated by commas and/or whitespace;
                an empty string produces an empty ``block_map``.

        Raises:
            ValueError: if ``blocks`` contains a non-integer token.
        """
        # One fresh dict per block index; duplicate indices collapse into a
        # single entry. Presumably AttentionBank fills these dicts during
        # sampling -- TODO confirm against utils.attn_bank.
        block_map = {idx: {} for idx in self._parse_block_indices(blocks)}
        bank = AttentionBank(save_steps, block_map)
        return (bank, )
|
|
|
|
|
class LTXPrepareAttnInjectionsNode:
    """Node that derives an injection-configured ``AttentionBank`` from a saved one.

    The latent is passed through unchanged. The returned bank is built from
    the input bank's ``save_steps`` and ``block_map`` plus ``inject_steps``,
    records which attention components (q/k/v) to inject, and can optionally
    be restricted to a subset of blocks.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "latent": ("LATENT",),
            "attn_bank": ("ATTN_BANK",),
            "query": ("BOOLEAN", {"default": False}),
            "key": ("BOOLEAN", {"default": False}),
            "value": ("BOOLEAN", {"default": False}),
            "inject_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
        }, "optional": {
            "blocks": ("LTX_BLOCKS",),
        }}

    RETURN_TYPES = ("LATENT", "ATTN_INJ")
    FUNCTION = "prepare"

    # Same category as the other LTX nodes in this module.
    CATEGORY = "ltxtricks"

    @staticmethod
    def _inject_settings(query, key, value):
        """Return the set of component tags ('q', 'k', 'v') whose flag is truthy."""
        flags = (('q', query), ('k', key), ('v', value))
        return {tag for tag, enabled in flags if enabled}

    @staticmethod
    def _select_blocks(block_map, blocks):
        """Return a new dict holding only the ``block_map`` entries whose key is in ``blocks``.

        Insertion order is kept and values are shared by reference, not copied.
        """
        return {idx: entry for idx, entry in block_map.items() if idx in blocks}

    def prepare(self, latent, attn_bank, query, key, value, inject_steps, blocks=None):
        """Return ``(latent, new_bank)`` configured for attention injection.

        Args:
            latent: passed through unchanged.
            attn_bank: source bank; its ``save_steps`` and ``block_map`` are read.
            query, key, value: which attention components to inject.
            inject_steps: number of steps to inject; must not exceed ``save_steps``.
            blocks: optional collection of block indices; when given, only those
                blocks are kept in the new bank's ``block_map`` (an empty
                collection keeps none).

        Raises:
            ValueError: if ``inject_steps`` is greater than the bank's ``save_steps``.
        """
        save_steps = attn_bank['save_steps']
        if inject_steps > save_steps:
            raise ValueError(
                f"Can not inject more steps ({inject_steps}) than were saved ({save_steps})."
            )

        attn_bank = AttentionBank(save_steps, attn_bank['block_map'], inject_steps)
        attn_bank['inject_settings'] = self._inject_settings(query, key, value)

        if blocks is not None:
            # Replace (rather than mutate) block_map. AttentionBank presumably
            # stores the dict it was given by reference, so in-place deletion
            # could alter the source bank -- TODO confirm in utils.attn_bank.
            attn_bank['block_map'] = self._select_blocks(attn_bank['block_map'], blocks)

        return (latent, attn_bank)
|
|
|
|
|
class LTXAttentioOverrideNode:
    """Node that turns a text list of block indices into an ``LTX_BLOCKS`` set.

    NOTE: the class name's spelling ("Attentio") is kept unchanged; renaming
    it would presumably break node registrations or saved workflows that
    reference it.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "blocks": ("STRING", {"multiline": True}),
        }}

    RETURN_TYPES = ("LTX_BLOCKS",)
    FUNCTION = "build"

    CATEGORY = "ltxtricks"

    def build(self, blocks=''):
        """Return a 1-tuple containing the set of integer block indices in ``blocks``.

        Indices may be separated by commas and/or whitespace. Empty tokens
        (an empty string, a trailing comma, doubled commas) are skipped, so
        ``''`` yields an empty set.

        Raises:
            ValueError: if a non-empty token is not an integer.
        """
        block_set = set()
        # Treat commas as whitespace so comma-separated and one-per-line
        # input from the multiline widget are both accepted.
        for token in blocks.replace(',', ' ').split():
            try:
                block_set.add(int(token))
            except ValueError as err:
                raise ValueError(
                    f"Invalid block index {token!r}: expected integers "
                    "separated by commas or whitespace."
                ) from err
        return (block_set, )
|
|