import math
from contextlib import nullcontext

import comfy.latent_formats
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.model_sampling
import comfy.sd
import comfy.supported_models_base
import comfy.utils
import torch
import torch.nn as nn
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel


class LTXVModelConfig:
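    """Minimal ComfyUI model config for LTXV: carries the latent format,
    manual-cast dtype, sampling settings and a memory-usage estimate."""
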
    def __init__(self, latent_channels, dtype):
        self.unet_config = {}
        self.unet_extra_config = {}
        self.latent_format = comfy.latent_formats.LatentFormat()
        self.latent_format.latent_channels = latent_channels
        self.manual_cast_dtype = dtype
        self.sampling_settings = {"multiplier": 1.0}
        self.memory_usage_factor = 2.7

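        # The transformer is built manually elsewhere, so stop ComfyUI from
        # constructing its own UNet from this config.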
        self.unet_config["disable_unet_model_creation"] = True


class LTXVSampling(torch.nn.Module, comfy.model_sampling.CONST):
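    """Flow-matching (CONST) sampling that pins conditioned latent positions:
    wherever ``condition_mask`` is 1 the latent keeps its clean (or guiding)
    value instead of being noised and denoised."""
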
    def __init__(self, condition_mask, guiding_latent=None):
        super().__init__()
        self.condition_mask = condition_mask
        self.guiding_latent = guiding_latent
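        # multiplier=1 keeps timesteps in [0, 1], so sigma == timestep here.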
        self.set_parameters(shift=1.0, multiplier=1)

    def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
        self.shift = shift
        self.multiplier = multiplier
        ts = self.sigma((torch.arange(0, timesteps + 1, 1) / timesteps) * multiplier)
        self.register_buffer("sigmas", ts)

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        return sigma * self.multiplier

    def sigma(self, timestep):
        return timestep

    def percent_to_sigma(self, percent):
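        # Sigma falls linearly from 1.0 (start of sampling) to 0.0 (end).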
        if percent <= 0.0:
            return 1.0
        if percent >= 1.0:
            return 0.0
        return 1.0 - percent

    def calculate_input(self, sigma, noise):
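        # Overwrite conditioned positions of the model input with the guiding latent.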
        if self.guiding_latent is not None:
            noise = (
                noise * (1 - self.condition_mask)
                + self.guiding_latent * self.condition_mask
            )
        return noise

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
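        # Conditioned positions keep the clean latent; everything else gets the
        # usual flow-matching interpolation between latent and noise.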
        self.condition_mask = self.condition_mask.to(latent_image.device)
        scaled = latent_image * (1 - sigma) + noise * sigma
        result = latent_image * self.condition_mask + scaled * (1 - self.condition_mask)
        return result

    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        result = model_input - model_output * sigma

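        # Conditioned positions are never denoised: restore the guiding latent
        # (or the original model input) there.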
        if self.guiding_latent is not None:
            result = (
                result * (1 - self.condition_mask)
                + self.guiding_latent * self.condition_mask
            )
        else:
            result = (
                result * (1 - self.condition_mask) + model_input * self.condition_mask
            )
        return result


class LTXVModel(comfy.model_base.BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
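        # Start with an all-zero condition mask: nothing is pinned until
        # conditioning is actually set.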
        self.model_sampling = LTXVSampling(torch.zeros([1]))


class LTXVTransformer3D(nn.Module):
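    """Wrapper that adapts ``Transformer3DModel`` to ComfyUI's model interface:
    patchifies latents, builds the positional index grid, masks the timesteps
    of conditioned tokens and unpatchifies the prediction."""
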
    def __init__(
        self,
        transformer: Transformer3DModel,
        patchifier: SymmetricPatchifier,
        conditioning_mask,
        latent_frame_rate,
        vae_scale_factor,
    ):
        super().__init__()

        self.dtype = transformer.dtype
        self.transformer = transformer
        self.patchifier = patchifier
        self.conditioning_mask = conditioning_mask
        self.latent_frame_rate = latent_frame_rate
        self.vae_scale_factor = vae_scale_factor

    def indices_grid(
        self,
        latent_shape,
        device,
    ):
        use_rope = self.transformer.use_rope
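        # With RoPE, scale the positional grid from latent units to real
        # time/pixel units via the temporal and spatial compression factors.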
        scale_grid = (
            (1 / self.latent_frame_rate, self.vae_scale_factor, self.vae_scale_factor)
            if use_rope
            else None
        )

        indices_grid = self.patchifier.get_grid(
            orig_num_frames=latent_shape[2],
            orig_height=latent_shape[3],
            orig_width=latent_shape[4],
            batch_size=latent_shape[0],
            scale_grid=scale_grid,
            device=device,
        )

        return indices_grid

    def wrapped_transformer(
        self,
        latent,
        timesteps,
        context,
        indices_grid,
        skip_layer_mask=None,
        skip_layer_strategy=None,
        img_hw=None,
        aspect_ratio=None,
        mixed_precision=True,
        **kwargs,
    ):
        latent = latent.to(self.transformer.dtype)
        latent_patchified = self.patchifier.patchify(latent)
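        # Attention mask over context tokens: 1 wherever any channel is non-zero.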
        context_mask = (context != 0).any(dim=2).to(self.transformer.dtype)

        if mixed_precision:
            context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            context_manager = nullcontext()
        with context_manager:
            noise_pred = self.transformer(
                latent_patchified.to(self.transformer.dtype).to(
                    self.transformer.device
                ),
                indices_grid.to(self.transformer.device),
                encoder_hidden_states=context.to(self.transformer.device),
                encoder_attention_mask=context_mask.to(self.transformer.device).to(
                    torch.int64
                ),
                timestep=timesteps,
                skip_layer_mask=skip_layer_mask,
                skip_layer_strategy=skip_layer_strategy,
                return_dict=False,
            )[0]

        result = self.patchifier.unpatchify(
            latents=noise_pred,
            output_height=latent.shape[3],
            output_width=latent.shape[4],
            output_num_frames=latent.shape[2],
            out_channels=latent.shape[1] // math.prod(self.patchifier.patch_size),
        )
        return result

    def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, **kwargs):
        transformer_options = kwargs.get("transformer_options", {})
        ptb_index = transformer_options.get("ptb_index", None)
        mixed_precision = transformer_options.get("mixed_precision", False)
        cond_or_uncond = transformer_options.get("cond_or_uncond", [])
        skip_block_list = transformer_options.get("skip_block_list", [])
        skip_layer_strategy = transformer_options.get("skip_layer_strategy", None)
        mask = self.patchifier.patchify(self.conditioning_mask).squeeze(-1).to(x.device)
        ndim_mask = mask.ndimension()
        expanded_timesteps = timesteps.view(timesteps.size(0), *([1] * (ndim_mask - 1)))
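        # Zero the timestep at conditioned tokens so the model treats them as clean.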
        timesteps_masked = expanded_timesteps * (1 - mask)
        skip_layer_mask = None
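        # Skip-layer guidance: skip the listed blocks only for the batch entry
        # that corresponds to ptb_index in cond_or_uncond.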
        if ptb_index is not None and ptb_index in cond_or_uncond:
            skip_layer_mask = self.transformer.create_skip_layer_mask(
                skip_block_list,
                1,
                len(cond_or_uncond),
                len(cond_or_uncond) - 1 - cond_or_uncond.index(ptb_index),
            )

        result = self.wrapped_transformer(
            x,
            timesteps_masked,
            context,
            indices_grid=self.indices_grid(x.shape, x.device),
            mixed_precision=mixed_precision,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
        )
        return result