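"""ComfyUI loader node for LTX-Video checkpoints.

Reads a single safetensors checkpoint, builds the Transformer3DModel and the
CausalVideoAutoencoder it contains, and returns them as ComfyUI MODEL and VAE
outputs.
"""
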
import json
from pathlib import Path

import comfy
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import folder_paths
import safetensors.torch
import torch
from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from safetensors import safe_open

from .model import LTXVModel, LTXVModelConfig, LTXVTransformer3D
from .nodes_registry import comfy_node
from .vae import LTXVVAE


@comfy_node(name="LTXVLoader")
class LTXVLoader:
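    """ComfyUI node that loads an LTX-Video checkpoint into a model and a VAE."""
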
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (
                    folder_paths.get_filename_list("checkpoints"),
                    {"tooltip": "The name of the checkpoint (model) to load."},
                ),
                "dtype": (["bfloat16", "float32"], {"default": "bfloat16"}),
            }
        }

    RETURN_TYPES = ("MODEL", "VAE")
    RETURN_NAMES = ("model", "vae")
    FUNCTION = "load"
    CATEGORY = "lightricks/LTXV"
    TITLE = "LTXV Loader"
    OUTPUT_NODE = False

    def load(self, ckpt_name, dtype):
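        """Load the checkpoint and return the (model, vae) outputs."""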
        dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()

        ckpt_path = Path(folder_paths.get_full_path("checkpoints", ckpt_name))

        # Read the VAE and transformer configs embedded in the safetensors
        # metadata, if present; otherwise the defaults in _load_vae and
        # _load_unet are used.
        vae_config = None
        unet_config = None
        with safe_open(ckpt_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            if metadata is not None:
                config_metadata = metadata.get("config", None)
                if config_metadata is not None:
                    config_metadata = json.loads(config_metadata)
                    vae_config = config_metadata.get("vae", None)
                    unet_config = config_metadata.get("transformer", None)

        weights = safetensors.torch.load_file(ckpt_path, device="cpu")

        vae = self._load_vae(weights, vae_config)
        num_latent_channels = vae.first_stage_model.config.latent_channels

        model = self._load_unet(
            load_device,
            offload_device,
            weights,
            num_latent_channels,
            dtype=dtype_map[dtype],
            config=unet_config,
        )

        return (model, vae)

    def _load_vae(self, weights, config=None):
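        """Build the causal video VAE from the weights under the 'vae.' prefix."""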
        if config is None:
            config = {
                "_class_name": "CausalVideoAutoencoder",
                "dims": 3,
                "in_channels": 3,
                "out_channels": 3,
                "latent_channels": 128,
                "blocks": [
                    ["res_x", 4],
                    ["compress_all", 1],
                    ["res_x_y", 1],
                    ["res_x", 3],
                    ["compress_all", 1],
                    ["res_x_y", 1],
                    ["res_x", 3],
                    ["compress_all", 1],
                    ["res_x", 3],
                    ["res_x", 4],
                ],
                "scaling_factor": 1.0,
                "norm_layer": "pixel_norm",
                "patch_size": 4,
                "latent_log_var": "uniform",
                "use_quant_conv": False,
                "causal_decoder": False,
            }
        vae_prefix = "vae."
        vae = LTXVVAE.from_config_and_state_dict(
            vae_class=CausalVideoAutoencoder,
            config=config,
            state_dict={
                key.removeprefix(vae_prefix): value
                for key, value in weights.items()
                if key.startswith(vae_prefix)
            },
        )
        return vae

    def _load_unet(
        self,
        load_device,
        offload_device,
        weights,
        num_latent_channels,
        dtype,
        config=None,
    ):
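        """Build the Transformer3DModel and wrap it in a ComfyUI ModelPatcher."""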
        if config is None:
            config = {
                "_class_name": "Transformer3DModel",
                "_diffusers_version": "0.25.1",
                "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
                "activation_fn": "gelu-approximate",
                "attention_bias": True,
                "attention_head_dim": 64,
                "attention_type": "default",
                "caption_channels": 4096,
                "cross_attention_dim": 2048,
                "double_self_attention": False,
                "dropout": 0.0,
                "in_channels": 128,
                "norm_elementwise_affine": False,
                "norm_eps": 1e-06,
                "norm_num_groups": 32,
                "num_attention_heads": 32,
                "num_embeds_ada_norm": 1000,
                "num_layers": 28,
                "num_vector_embeds": None,
                "only_cross_attention": False,
                "out_channels": 128,
                "project_to_2d_pos": True,
                "upcast_attention": False,
                "use_linear_projection": False,
                "qk_norm": "rms_norm",
                "standardization_norm": "rms_norm",
                "positional_embedding_type": "rope",
                "positional_embedding_theta": 10000.0,
                "positional_embedding_max_pos": [20, 2048, 2048],
                "timestep_scale_multiplier": 1000,
            }

        transformer = Transformer3DModel.from_config(config)

        # Keep only the weights under the diffusion-model prefix and strip it
        # before loading them into the transformer.
        unet_prefix = "model.diffusion_model."
        transformer.load_state_dict(
            {
                key.removeprefix(unet_prefix): value
                for key, value in weights.items()
                if key.startswith(unet_prefix)
            }
        )
        transformer.to(dtype).to(load_device).eval()

        patchifier = SymmetricPatchifier(1)
        diffusion_model = LTXVTransformer3D(transformer, patchifier, None, None, None)

        model = LTXVModel(
            LTXVModelConfig(num_latent_channels, dtype=dtype),
            model_type=comfy.model_base.ModelType.FLOW,
            device=comfy.model_management.get_torch_device(),
        )
        model.diffusion_model = diffusion_model

        patcher = comfy.model_patcher.ModelPatcher(model, load_device, offload_device)

        return patcher