from copy import copy

import comfy.latent_formats
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.sd
import comfy.supported_models_base
import comfy.utils
import torch
from diffusers.image_processor import VaeImageProcessor
from ltx_video.models.autoencoders.vae_encode import (
    get_vae_size_scale_factor,
    vae_decode,
    vae_encode,
)

from .nodes_registry import comfy_node


class LTXVVAE(comfy.sd.VAE):
    def __init__(self, decode_timestep=0.05, decode_noise_scale=0.025, seed=42):
        self.device = comfy.model_management.vae_device()
        self.offload_device = comfy.model_management.vae_offload_device()
        self.decode_timestep = decode_timestep
        self.decode_noise_scale = decode_noise_scale
        self.seed = seed

    @classmethod
    def from_pretrained(cls, vae_class, model_path, dtype=torch.bfloat16):
        instance = cls()
        model = vae_class.from_pretrained(
            pretrained_model_name_or_path=model_path,
            revision=None,
            torch_dtype=dtype,
            load_in_8bit=False,
        )
        instance._finalize_model(model)
        return instance

    @classmethod
    def from_config_and_state_dict(
        cls, vae_class, config, state_dict, dtype=torch.bfloat16
    ):
        instance = cls()
        model = vae_class.from_config(config)
        model.load_state_dict(state_dict)
        model.to(dtype)
        instance._finalize_model(model)
        return instance

    def _finalize_model(self, model):
        self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
            model
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.first_stage_model = model.eval().to(self.device)

    # Assumes that the input samples have dimensions in the following order:
    # (batch, channels, frames, height, width).
    def decode(self, samples_in):
        is_video = samples_in.shape[2] > 1
        decode_timestep = self.decode_timestep
        # Only timestep-conditioned decoders consume a decode timestep; for them,
        # noise is mixed into the latents before decoding.
        if getattr(self.first_stage_model.decoder, "timestep_conditioning", False):
            samples_in = self.add_noise(
                decode_timestep, self.decode_noise_scale, self.seed, samples_in
            )
        else:
            decode_timestep = None

        result = vae_decode(
            samples_in.to(self.device),
            vae=self.first_stage_model,
            is_video=is_video,
            vae_per_channel_normalize=True,
            timestep=decode_timestep,
        )
        result = self.image_processor.postprocess(
            result, output_type="pt", do_denormalize=[True]
        )
        return result.squeeze(0).permute(1, 2, 3, 0).to(torch.float32)

    @staticmethod
    def add_noise(decode_timestep, decode_noise_scale, seed, latents):
        generator = torch.Generator(device="cpu").manual_seed(seed)
        noise = torch.randn(
            latents.size(),
            generator=generator,
            device=latents.device,
            dtype=latents.dtype,
        )

        # Broadcast scalar settings to one value per batch element.
        if not isinstance(decode_timestep, list):
            decode_timestep = [decode_timestep] * latents.shape[0]

        if decode_noise_scale is None:
            decode_noise_scale = decode_timestep
        elif not isinstance(decode_noise_scale, list):
            decode_noise_scale = [decode_noise_scale] * latents.shape[0]

        decode_timestep = torch.tensor(decode_timestep).to(latents.device)
        decode_noise_scale = torch.tensor(decode_noise_scale).to(latents.device)[
            :, None, None, None, None
        ]
        latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale
        return latents
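
    # Example: with the default decode_noise_scale of 0.025, the decoder input is
    # 0.975 * latents + 0.025 * noise, a convex blend whose weight matches the
    # "scale" input exposed by the DecoderNoise node below.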
    # The underlying VAE expects (batch, channels, frames, height, width) order
    # and a specific dtype, whereas the Comfy convention is
    # (frames, height, width, channels).
    def encode(self, pixel_samples):
        preprocessed = self.image_processor.preprocess(
            pixel_samples.permute(3, 0, 1, 2)
        )
        # Renamed from `input` to avoid shadowing the builtin.
        pixel_input = preprocessed.unsqueeze(0).to(torch.bfloat16).to(self.device)
        latents = vae_encode(
            pixel_input, self.first_stage_model, vae_per_channel_normalize=True
        ).to(comfy.model_management.get_torch_device())
        return latents


@comfy_node(name="Add VAE Decoder Noise")
class DecoderNoise:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vae": ("VAE",),
                "timestep": (
                    "FLOAT",
                    {
                        "default": 0.05,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "The timestep used for decoding the noise.",
                    },
                ),
                "scale": (
                    "FLOAT",
                    {
                        "default": 0.025,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "The scale of the noise added to the decoder.",
                    },
                ),
                "seed": (
                    "INT",
                    {
                        "default": 42,
                        "min": 0,
                        "max": 0xFFFFFFFFFFFFFFFF,
                        "tooltip": "The random seed used for creating the noise.",
                    },
                ),
            }
        }

    FUNCTION = "add_noise"
    RETURN_TYPES = ("VAE",)
    CATEGORY = "lightricks/LTXV"

    def add_noise(self, vae, timestep, scale, seed):
        # Shallow-copy the VAE so the incoming node's decode settings are preserved.
        result = copy(vae)
        result.decode_timestep = timestep
        result.decode_noise_scale = scale
        result.seed = seed
        return (result,)
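

# Usage sketch: a minimal illustration of how these pieces compose, assuming a
# ComfyUI runtime with the ltx_video package installed. CausalVideoAutoencoder
# is one plausible vae_class choice, and the checkpoint path is hypothetical.
if __name__ == "__main__":
    from ltx_video.models.autoencoders.causal_video_autoencoder import (
        CausalVideoAutoencoder,
    )

    vae = LTXVVAE.from_pretrained(
        CausalVideoAutoencoder, "PATH/TO/ltxv_vae"  # hypothetical checkpoint path
    )

    # DecoderNoise returns a shallow copy, so the original VAE's settings
    # remain untouched.
    (noisy_vae,) = DecoderNoise().add_noise(vae, timestep=0.1, scale=0.05, seed=7)
    assert noisy_vae.decode_timestep == 0.1
    assert vae.decode_timestep == 0.05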