jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
from ..utils.attn_bank import AttentionBank
class LTXAttentionBankNode:
    """ComfyUI node that creates an ``AttentionBank`` for a set of blocks.

    The ``blocks`` widget holds integer block indices separated by commas
    and/or whitespace (e.g. ``"0, 4, 8"`` or one index per line). Each index
    gets its own empty dict entry in the bank's block map.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "save_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "blocks": ("STRING", {"multiline": True}),
        }}

    RETURN_TYPES = ("ATTN_BANK",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    def build(self, save_steps, blocks=''):
        """Build an ``AttentionBank`` covering the requested blocks.

        Args:
            save_steps: Step count forwarded to ``AttentionBank``.
            blocks: Block indices separated by commas and/or whitespace.
                Empty tokens (blank input, trailing commas, blank lines)
                are ignored.

        Returns:
            A one-element tuple holding the new ``AttentionBank``.

        Raises:
            ValueError: If a non-empty token is not an integer.
        """
        # Treat commas as whitespace so "1,2,", "1, 2" and one-per-line input
        # all parse; str.split() with no argument drops empty tokens, which
        # previously made int('') blow up (including on the '' default).
        block_map = {int(token): {} for token in blocks.replace(',', ' ').split()}
        bank = AttentionBank(save_steps, block_map)
        return (bank, )
class LTXPrepareAttnInjectionsNode:
    """ComfyUI node that configures attention injection from a saved bank.

    Takes an ``ATTN_BANK`` from earlier in the graph and returns a fresh
    ``AttentionBank`` carrying the injection step count and the set of
    projections to inject (any of ``'q'``, ``'k'``, ``'v'``). It can also
    restrict the bank's block map to the blocks given in ``blocks``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "latent": ("LATENT",),
            "attn_bank": ("ATTN_BANK",),
            "query": ("BOOLEAN", {"default": False}),
            "key": ("BOOLEAN", {"default": False}),
            "value": ("BOOLEAN", {"default": False}),
            "inject_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
        }, "optional": {
            "blocks": ("LTX_BLOCKS",),
        }}

    RETURN_TYPES = ("LATENT", "ATTN_INJ")
    FUNCTION = "prepare"
    # Match the other nodes in this file ("fluxtapoz" was a copy-paste leftover).
    CATEGORY = "ltxtricks"

    def prepare(self, latent, attn_bank, query, key, value, inject_steps, blocks=None):
        """Create an injection-configured copy of ``attn_bank``.

        Args:
            latent: Passed through unchanged (see note at the return).
            attn_bank: Source bank; its ``'save_steps'`` and ``'block_map'``
                entries are read but the bank itself is not modified.
            query: Inject saved queries (adds ``'q'`` to inject settings).
            key: Inject saved keys (adds ``'k'``).
            value: Inject saved values (adds ``'v'``).
            inject_steps: Step count forwarded to the new ``AttentionBank``;
                must not exceed the source bank's ``save_steps``.
            blocks: Optional collection of block indices; when given, only
                block-map entries whose key is in ``blocks`` are kept.

        Returns:
            ``(latent, bank)`` where ``bank`` is the new ``AttentionBank``.

        Raises:
            ValueError: If ``inject_steps`` exceeds the saved step count.
        """
        save_steps = attn_bank['save_steps']
        if inject_steps > save_steps:
            raise ValueError("Can not inject more steps than were saved.")

        injected = AttentionBank(save_steps, attn_bank['block_map'], inject_steps)
        injected['inject_settings'] = {
            name
            for name, enabled in (('q', query), ('k', key), ('v', value))
            if enabled
        }

        if blocks is not None:
            # Build a filtered copy instead of deleting in place, so the
            # incoming bank's block map is left untouched.
            injected['block_map'] = {
                block_idx: entry
                for block_idx, entry in injected['block_map'].items()
                if block_idx in blocks
            }

        # Returning ``latent`` unchanged is a hack to force order of
        # operations in the ComfyUI graph.
        return (latent, injected)
class LTXAttentioOverrideNode:
    """ComfyUI node that parses block indices into an ``LTX_BLOCKS`` set.

    The ``blocks`` widget holds integer indices separated by commas and/or
    whitespace. The resulting set can be fed to the optional ``blocks``
    input of the injection-preparation node to restrict which blocks are used.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "blocks": ("STRING", {"multiline": True}),
        }}

    RETURN_TYPES = ("LTX_BLOCKS",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    def build(self, blocks=''):
        """Parse ``blocks`` into a set of ints.

        Args:
            blocks: Indices separated by commas and/or whitespace; empty
                tokens (blank input, trailing commas, blank lines) are ignored.

        Returns:
            A one-element tuple holding the set of parsed indices.

        Raises:
            ValueError: If a non-empty token is not an integer.
        """
        # Commas are treated as whitespace; str.split() drops empty tokens,
        # which previously made int('') raise on the '' default.
        block_set = {int(token) for token in blocks.replace(',', ' ').split()}
        return (block_set, )