import math

import torch
from torch import nn

import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from comfy.ldm.lightricks.model import apply_rotary_emb, precompute_freqs_cis, LTXVModel, BasicTransformerBlock

from ..utils.latent_guide import LatentGuide


class LTXModifiedCrossAttention(nn.Module):
    def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
        context = x if context is None else context
        context_v = x if context is None else context

        step = transformer_options.get('step', -1)
        total_steps = transformer_options.get('total_steps', 0)
        attn_bank = transformer_options.get('attn_bank', None)
        sample_mode = transformer_options.get('sample_mode', None)

        if attn_bank is not None and self.idx in attn_bank['block_map']:
            len_conds = len(transformer_options['cond_or_uncond'])
            pred_order = transformer_options['pred_order']
            if sample_mode == 'forward' and total_steps - step - 1 < attn_bank['save_steps']:
                # Forward (inversion) pass: store this block's attention input for later injection.
                step_idx = f'{pred_order}_{total_steps-step-1}'
                attn_bank['block_map'][self.idx][step_idx] = x.cpu()
            elif sample_mode == 'reverse' and step < attn_bank['inject_steps']:
                # Reverse (sampling) pass: re-inject the stored activations into q/k/v as configured.
                step_idx = f'{pred_order}_{step}'
                inject_settings = attn_bank.get('inject_settings', {})
                if len(inject_settings) > 0:
                    inj = attn_bank['block_map'][self.idx][step_idx].to(x.device).repeat(len_conds, 1, 1)
                    if 'q' in inject_settings:
                        x = inj
                    if 'k' in inject_settings:
                        context = inj
                    if 'v' in inject_settings:
                        context_v = inj

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context_v)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe)

        alt_attn_fn = transformer_options.get('patches_replace', {}).get('layer', {}).get(('self_attn', self.idx), None)
        if alt_attn_fn is not None:
            out = alt_attn_fn(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
        elif mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)

        return self.to_out(out)


class LTXModifiedBasicTransformerBlock(BasicTransformerBlock):
    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
            + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)
        ).unbind(dim=2)

        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
        x += self.attn2(x, context=context, mask=attention_mask)

        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
        x += self.ff(y) * gate_mlp

        return x


class LTXVModelModified(LTXVModel):
    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latents={}, transformer_options={}, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})
        guiding_latents = transformer_options.get('patches', {}).get('guiding_latents', None)

        indices_grid = self.patchifier.get_grid(
            orig_num_frames=x.shape[2],
            orig_height=x.shape[3],
            orig_width=x.shape[4],
            batch_size=x.shape[0],
            scale_grid=((1 / frame_rate) * 8, 32, 32),
            device=x.device,
        )

        ts = None
        input_x = None
        if guiding_latent is not None:
            guiding_latents = guiding_latents if guiding_latents is not None else []
            guiding_latents.append(LatentGuide(guiding_latent, 0))

        if guiding_latents is not None:
            # Overwrite the guided frames with their clean latents and zero their timestep
            # so they act as conditioning frames.
            input_x = x.clone()
            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
            ts *= input_ts
            for guide in guiding_latents:
                ts[:, :, guide.index] = 0.0
                x[:, :, guide.index] = guide.latent[:, :, 0]
            timestep = self.patchifier.patchify(ts)

        orig_shape = list(x.shape)
        transformer_options['original_shape'] = orig_shape

        x = self.patchifier.patchify(x)
        x = self.patchify_proj(x)
        timestep = timestep * 1000.0

        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)

        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

        batch_size = x.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=x.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = x.shape[0]
            context = self.caption_projection(context)
            context = context.view(batch_size, -1, x.shape[-1])

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe,
                    transformer_options=transformer_options,
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
            + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        # Modulation
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )

        if guiding_latents is not None:
            # Map the untouched guide frames back into the model's output parameterization.
            for guide in guiding_latents:
                x[:, :, guide.index] = (input_x[:, :, guide.index] - guide.latent[:, :, 0]) / input_ts[:, :, 0]

        return x


def inject_model(diffusion_model):
    # Swap the classes in-place so the modified forward passes are used during sampling,
    # and tag each block/attention layer with its index for the attn_bank lookups.
    diffusion_model.__class__ = LTXVModelModified

    for idx, transformer_block in enumerate(diffusion_model.transformer_blocks):
        transformer_block.__class__ = LTXModifiedBasicTransformerBlock
        transformer_block.idx = idx
        transformer_block.attn1.__class__ = LTXModifiedCrossAttention
        transformer_block.attn1.idx = idx

    return diffusion_model
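
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged example of how inject_model might be wired into a ComfyUI
# node's patch function. The helper name below is hypothetical; the
# `model.clone()` / `model.model.diffusion_model` / `model_options` access
# pattern follows common ComfyUI ModelPatcher conventions, and the attn_bank
# layout is inferred from the keys this module reads, not from a documented API.
def example_inject_into_patcher(model, attn_bank=None):
    """Clone a ComfyUI ModelPatcher and swap in the modified LTXV classes."""
    model_clone = model.clone()
    inject_model(model_clone.model.diffusion_model)

    # LTXModifiedCrossAttention.forward expects an 'attn_bank' dict in
    # transformer_options, roughly shaped like:
    #   {
    #       'block_map': {block_idx: {step_key: saved_tensor, ...}, ...},
    #       'save_steps': int,       # forward (inversion) steps to store
    #       'inject_steps': int,     # reverse (sampling) steps to inject
    #       'inject_settings': {'q', 'k', 'v'},  # any subset
    #   }
    if attn_bank is not None:
        model_clone.model_options.setdefault('transformer_options', {})['attn_bank'] = attn_bank

    return model_clone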