from ..utils.attn_bank import AttentionBank class LTXAttentionBankNode: @classmethod def INPUT_TYPES(s): return {"required": { "save_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), "blocks": ("STRING", { "multiline": True }) }} RETURN_TYPES = ("ATTN_BANK",) FUNCTION = "build" CATEGORY = "ltxtricks" def build(self, save_steps, blocks=''): block_map = {} block_list = blocks.split(',') for block in block_list: block_idx = int(block) block_map[block_idx] = {} bank = AttentionBank(save_steps, block_map) return (bank, ) class LTXPrepareAttnInjectionsNode: @classmethod def INPUT_TYPES(s): return {"required": { "latent": ("LATENT",), "attn_bank": ("ATTN_BANK",), "query": ("BOOLEAN", { "default": False }), "key": ("BOOLEAN", { "default": False }), "value": ("BOOLEAN", { "default": False }), "inject_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), }, "optional": { "blocks": ("LTX_BLOCKS",) }} RETURN_TYPES = ("LATENT", "ATTN_INJ") FUNCTION = "prepare" CATEGORY = "fluxtapoz" def prepare(self, latent, attn_bank, query, key, value, inject_steps, blocks=None): if inject_steps > attn_bank['save_steps']: raise ValueError(f"Can not inject more steps than were saved.") attn_bank = AttentionBank(attn_bank['save_steps'], attn_bank['block_map'], inject_steps) attn_bank['inject_settings'] = set([]) if query: attn_bank['inject_settings'].add('q') if key: attn_bank['inject_settings'].add('k') if value: attn_bank['inject_settings'].add('v') if blocks is not None: attn_bank['block_map'] = {**attn_bank['block_map']} for key in list(attn_bank['block_map'].keys()): if key not in blocks: del attn_bank['block_map'][key] # Hack to force order of operations in ComfyUI graph return (latent, attn_bank) class LTXAttentioOverrideNode: @classmethod def INPUT_TYPES(s): return {"required": { "blocks": ("STRING", { "multiline": True }) }} RETURN_TYPES = ("LTX_BLOCKS",) FUNCTION = "build" CATEGORY = "ltxtricks" def build(self, blocks=''): block_set = set(list(int(block) for block in blocks.split(','))) return (block_set, )