import json
from pathlib import Path

import safetensors.torch
import torch
from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from safetensors import safe_open

import comfy
import comfy.model_base  # explicit import: _load_unet uses comfy.model_base.ModelType
import comfy.model_management
import comfy.model_patcher
import folder_paths

from .model import LTXVModel, LTXVModelConfig, LTXVTransformer3D
from .nodes_registry import comfy_node
from .vae import LTXVVAE


@comfy_node(name="LTXVLoader")
class LTXVLoader:
    """ComfyUI node that loads an LTX-Video checkpoint from a single
    safetensors file and returns a (MODEL, VAE) pair.

    The checkpoint may embed its model configuration as JSON in the
    safetensors metadata under the "config" key (with "vae" and
    "transformer" sub-dicts); when absent, hard-coded default configs
    matching the reference LTX-Video release are used.
    """

    @classmethod
    def INPUT_TYPES(s):
        # Standard ComfyUI input declaration: checkpoint picker + load dtype.
        return {
            "required": {
                "ckpt_name": (
                    folder_paths.get_filename_list("checkpoints"),
                    {"tooltip": "The name of the checkpoint (model) to load."},
                ),
                "dtype": (["bfloat16", "float32"], {"default": "bfloat16"}),
            }
        }

    RETURN_TYPES = ("MODEL", "VAE")
    RETURN_NAMES = ("model", "vae")
    FUNCTION = "load"
    CATEGORY = "lightricks/LTXV"
    TITLE = "LTXV Loader"
    OUTPUT_NODE = False

    def load(self, ckpt_name, dtype):
        """Load the named checkpoint and build the patched diffusion model
        and VAE.

        Args:
            ckpt_name: File name from the "checkpoints" folder list.
            dtype: "bfloat16" or "float32" — dtype the transformer is cast to.

        Returns:
            (model, vae): a comfy ModelPatcher wrapping the LTXV transformer,
            and the LTXVVAE.

        Raises:
            FileNotFoundError: if the checkpoint cannot be resolved to a path.
        """
        dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()

        full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        if full_path is None:
            # get_full_path returns None for unknown files; fail with a clear
            # error instead of the opaque TypeError Path(None) would raise.
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_name}")
        ckpt_path = Path(full_path)

        # Configs may be embedded in the safetensors metadata; None triggers
        # the hard-coded defaults in _load_vae / _load_unet.
        vae_config = None
        unet_config = None
        with safe_open(ckpt_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            if metadata is not None:
                config_metadata = metadata.get("config", None)
                if config_metadata is not None:
                    config_metadata = json.loads(config_metadata)
                    vae_config = config_metadata.get("vae", None)
                    unet_config = config_metadata.get("transformer", None)

        weights = safetensors.torch.load_file(ckpt_path, device="cpu")

        vae = self._load_vae(weights, vae_config)
        # The transformer's in/out channel count must match the VAE latent space.
        num_latent_channels = vae.first_stage_model.config.latent_channels

        model = self._load_unet(
            load_device,
            offload_device,
            weights,
            num_latent_channels,
            dtype=dtype_map[dtype],
            config=unet_config,
        )
        return (model, vae)

    def _load_vae(self, weights, config=None):
        """Build the causal video VAE from checkpoint weights.

        Args:
            weights: Full checkpoint state dict; VAE tensors use the "vae." prefix.
            config: Optional VAE config dict from the checkpoint metadata;
                falls back to the reference LTX-Video VAE configuration.

        Returns:
            LTXVVAE instance.
        """
        if config is None:
            config = {
                "_class_name": "CausalVideoAutoencoder",
                "dims": 3,
                "in_channels": 3,
                "out_channels": 3,
                "latent_channels": 128,
                "blocks": [
                    ["res_x", 4],
                    ["compress_all", 1],
                    ["res_x_y", 1],
                    ["res_x", 3],
                    ["compress_all", 1],
                    ["res_x_y", 1],
                    ["res_x", 3],
                    ["compress_all", 1],
                    ["res_x", 3],
                    ["res_x", 4],
                ],
                "scaling_factor": 1.0,
                "norm_layer": "pixel_norm",
                "patch_size": 4,
                "latent_log_var": "uniform",
                "use_quant_conv": False,
                "causal_decoder": False,
            }
        vae_prefix = "vae."
        vae = LTXVVAE.from_config_and_state_dict(
            vae_class=CausalVideoAutoencoder,
            config=config,
            # Strip the "vae." prefix so keys match the autoencoder's own names.
            state_dict={
                key.removeprefix(vae_prefix): value
                for key, value in weights.items()
                if key.startswith(vae_prefix)
            },
        )
        return vae

    def _load_unet(
        self,
        load_device,
        offload_device,
        weights,
        num_latent_channels,
        dtype,
        config=None,
    ):
        """Build the LTXV transformer and wrap it in a comfy ModelPatcher.

        Args:
            load_device: Device the transformer is moved to for inference.
            offload_device: Device weights are offloaded to when idle.
            weights: Full checkpoint state dict; transformer tensors use the
                "model.diffusion_model." prefix.
            num_latent_channels: Latent channel count taken from the VAE.
            dtype: torch dtype to cast the transformer to.
            config: Optional transformer config dict from the checkpoint
                metadata; falls back to the reference LTX-Video configuration.

        Returns:
            comfy.model_patcher.ModelPatcher wrapping the LTXVModel.
        """
        if config is None:
            config = {
                "_class_name": "Transformer3DModel",
                "_diffusers_version": "0.25.1",
                "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
                "activation_fn": "gelu-approximate",
                "attention_bias": True,
                "attention_head_dim": 64,
                "attention_type": "default",
                "caption_channels": 4096,
                "cross_attention_dim": 2048,
                "double_self_attention": False,
                "dropout": 0.0,
                "in_channels": 128,
                "norm_elementwise_affine": False,
                "norm_eps": 1e-06,
                "norm_num_groups": 32,
                "num_attention_heads": 32,
                "num_embeds_ada_norm": 1000,
                "num_layers": 28,
                "num_vector_embeds": None,
                "only_cross_attention": False,
                "out_channels": 128,
                "project_to_2d_pos": True,
                "upcast_attention": False,
                "use_linear_projection": False,
                "qk_norm": "rms_norm",
                "standardization_norm": "rms_norm",
                "positional_embedding_type": "rope",
                "positional_embedding_theta": 10000.0,
                "positional_embedding_max_pos": [20, 2048, 2048],
                "timestep_scale_multiplier": 1000,
            }

        transformer = Transformer3DModel.from_config(config)
        unet_prefix = "model.diffusion_model."
        # Strip the prefix so keys match the transformer's own parameter names.
        transformer.load_state_dict(
            {
                key.removeprefix(unet_prefix): value
                for key, value in weights.items()
                if key.startswith(unet_prefix)
            }
        )
        transformer.to(dtype).to(load_device).eval()

        # Patchifier with patch size 1: latents are flattened token-wise.
        patchifier = SymmetricPatchifier(1)
        diffusion_model = LTXVTransformer3D(transformer, patchifier, None, None, None)
        model = LTXVModel(
            LTXVModelConfig(num_latent_channels, dtype=dtype),
            model_type=comfy.model_base.ModelType.FLOW,
            device=comfy.model_management.get_torch_device(),
        )
        model.diffusion_model = diffusion_model

        patcher = comfy.model_patcher.ModelPatcher(model, load_device, offload_device)
        return patcher