# IF_TrellisCheckpointLoader.py
import os
import sys
import importlib
import torch
import logging
import folder_paths
from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path
import json
from trellis_model_manager import TrellisModelManager
from trellis.pipelines.trellis_image_to_3d import TrellisImageTo3DPipeline
from trellis.modules import set_attention_backend
from typing import Literal, Optional
from trellis.modules.attention_utils import enable_sage_attention, disable_sage_attention
logger = logging.getLogger("IF_Trellis")
def set_backend(backend: Literal['spconv', 'torchsparse']):
    """Select the sparse-tensor backend used by trellis.modules.sparse.

    Thin wrapper around the library's global set_backend. Callers that pick
    'spconv' may also want to set os.environ['SPCONV_ALGO'] beforehand.
    """
    from trellis.modules.sparse import set_backend as _set_sparse_backend
    _set_sparse_backend(backend)
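
# Hedged usage sketch: how a caller might pick the sparse backend before any
# Trellis modules are instantiated. The helper name is local to this sketch;
# the 'implicit_gemm' value mirrors the defaults used elsewhere in this file.
# Nothing here runs at import time.
def _example_select_sparse_backend():
    os.environ['SPCONV_ALGO'] = 'implicit_gemm'  # algo should be set before the first sparse op
    set_backend('spconv')                        # then select the backend itself
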
class TrellisConfig:
"""Global configuration for Trellis"""
def __init__(self):
self.logger = logger
self.attention_backend = "sage"
self.spconv_algo = "implicit_gemm"
self.smooth_k = True
self.device = "cuda"
self.use_fp16 = True
        # Key/value configuration with sensible defaults
        self._config = {
            "dinov2_size": "large",               # default model size
            "dinov2_model": "dinov2_vitl14_reg"   # default DINOv2 variant (matches the node default)
        }
def get(self, key, default=None):
"""Get configuration value with fallback"""
return self._config.get(key, default)
def set(self, key, value):
"""Set configuration value"""
self._config[key] = value
    def setup_environment(self):
        """Set up all environment variables and backends."""
        from trellis.modules.sparse import set_backend as set_sparse_backend
        # Set attention backend
        os.environ['ATTN_BACKEND'] = self.attention_backend
        set_attention_backend(self.attention_backend)
        # Set smooth k for sage attention
        os.environ['SAGEATTN_SMOOTH_K'] = '1' if self.smooth_k else '0'
        # Set spconv algorithm
        os.environ['SPCONV_ALGO'] = self.spconv_algo
        # Always use spconv as the sparse backend for now
        set_sparse_backend('spconv')
        logger.info(f"Environment configured - Backend: spconv, "
                    f"Attention: {self.attention_backend}, "
                    f"Smooth K: {self.smooth_k}, "
                    f"SpConv Algo: {self.spconv_algo}")
# Global config instance
TRELLIS_CONFIG = TrellisConfig()
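
# Hedged sketch: TRELLIS_CONFIG is a process-wide singleton. Nodes read and
# write extra settings through get()/set() rather than touching _config
# directly; the key below is one this file actually uses. The function name
# is illustrative only and is never called at import time.
def _example_configure_trellis():
    TRELLIS_CONFIG.set("dinov2_model", "dinov2_vitl14_reg")
    return TRELLIS_CONFIG.get("dinov2_size", "large")  # falls back to the default
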
class IF_TrellisCheckpointLoader:
"""
Node to manage the loading of the TRELLIS model.
Follows ComfyUI conventions for model management.
"""
def __init__(self):
self.logger = logger
self.model_manager = None
# Check for available devices
self.device = self._get_device()
def _get_device(self):
"""Determine the best available device."""
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "mps"
return "cpu"
@classmethod
def INPUT_TYPES(cls):
"""Define input types with device-specific options."""
device_options = []
if torch.cuda.is_available():
device_options.append("cuda")
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device_options.append("mps")
device_options.append("cpu")
return {
"required": {
"model_name": (["TRELLIS-image-large"],),
"dinov2_model": (["dinov2_vitl14_reg", "dinov2_vitg14_reg"], {"default": "dinov2_vitl14_reg", "tooltip": "Select the Dinov2 model to use for the image to 3D conversion. Smaller models work but better results with larger models."}),
"use_fp16": ("BOOLEAN", {"default": True}),
"attn_backend": (["sage", "xformers", "flash_attn", "sdpa", "naive"], {"default": "sage", "tooltip": "Select the attention backend to use for the image to 3D conversion. Sage is experimental but faster"}),
"smooth_k": ("BOOLEAN", {"default": True, "tooltip": "Smooth k for sage attention. This is a hyperparameter that controls the smoothness of the attention distribution. It is a boolean value that determines whether to use smooth k or not. Smooth k is a hyperparameter that controls the smoothness of the attention distribution. It is a boolean value that determines whether to use smooth k or not."}),
"spconv_algo": (["implicit_gemm", "native"], {"default": "implicit_gemm", "tooltip": "Select the spconv algorithm to use for the image to 3D conversion. Implicit gemm is the best but slower. Native is the fastest but less accurate."}),
"main_device": (device_options, {"default": device_options[0]}),
},
}
RETURN_TYPES = ("TRELLIS_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "ImpactFrames💥🎞️/Trellis"
@classmethod
def _check_backend_availability(cls, backend: str) -> bool:
"""Check if a specific attention backend is available"""
try:
if backend == 'sage':
import sageattention
elif backend == 'xformers':
import xformers.ops
elif backend == 'flash_attn':
import flash_attn
elif backend in ['sdpa', 'naive']:
# These are always available in PyTorch
pass
else:
return False
return True
except ImportError:
return False
@classmethod
    def _initialize_backend(cls, requested_backend: Optional[str] = None) -> str:
"""Initialize attention backend with fallback logic"""
# Priority order for backends
backend_priority = ['sage', 'flash_attn', 'xformers', 'sdpa']
# If a specific backend is requested, try it first
if requested_backend:
if cls._check_backend_availability(requested_backend):
logger.info(f"Using requested attention backend: {requested_backend}")
return requested_backend
else:
logger.warning(f"Requested backend '{requested_backend}' not available, falling back")
# Try backends in priority order
for backend in backend_priority:
if cls._check_backend_availability(backend):
logger.info(f"Using attention backend: {backend}")
return backend
# Final fallback to SDPA
logger.info("All optimized attention backends unavailable, using PyTorch SDPA")
return 'sdpa'
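
    # Resolution sketch (not executed):
    #   IF_TrellisCheckpointLoader._initialize_backend('sage')
    #   -> 'sage' if the sageattention package imports, else the first
    #      importable entry of ['sage', 'flash_attn', 'xformers', 'sdpa'],
    #      else 'sdpa' as the unconditional last resort.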
    def _setup_environment(self):
        """
        Set up environment variables based on the global TRELLIS_CONFIG.
        """
        from trellis.modules.sparse import set_backend as set_sparse_backend
        # Set attention backend
        os.environ['ATTN_BACKEND'] = TRELLIS_CONFIG.attention_backend
        set_attention_backend(TRELLIS_CONFIG.attention_backend)
        # Set smooth k for sage attention
        os.environ['SAGEATTN_SMOOTH_K'] = '1' if TRELLIS_CONFIG.smooth_k else '0'
        # Set spconv algorithm
        os.environ['SPCONV_ALGO'] = TRELLIS_CONFIG.spconv_algo
        # Always use spconv as the sparse backend for now
        set_sparse_backend('spconv')
        logger.info(f"Environment configured - Backend: spconv, "
                    f"Attention: {TRELLIS_CONFIG.attention_backend}, "
                    f"Smooth K: {TRELLIS_CONFIG.smooth_k}, "
                    f"SpConv Algo: {TRELLIS_CONFIG.spconv_algo}")
def optimize_pipeline(self, pipeline, use_fp16=True, attn_backend='sage'):
"""Apply optimizations to the pipeline if available"""
if self.device == "cuda":
try:
if hasattr(pipeline, 'cuda'):
pipeline.cuda()
if use_fp16:
if hasattr(pipeline, 'enable_attention_slicing'):
pipeline.enable_attention_slicing()
if hasattr(pipeline, 'half'):
pipeline.half()
# Only enable xformers if using xformers backend
if attn_backend == 'xformers' and hasattr(pipeline, 'enable_xformers_memory_efficient_attention'):
pipeline.enable_xformers_memory_efficient_attention()
except Exception as e:
logger.warning(f"Some optimizations failed: {str(e)}")
return pipeline
    def load_model(self, model_name, dinov2_model="dinov2_vitl14_reg", attn_backend="sage", use_fp16=True,
                   smooth_k=True, spconv_algo="implicit_gemm", main_device="cuda"):
"""Load and configure the TRELLIS model."""
        try:
            # Resolve the attention backend first, falling back to an
            # available one if the requested backend is not installed
            attn_backend = self._initialize_backend(attn_backend)
            # Update global config
            TRELLIS_CONFIG.attention_backend = attn_backend
            TRELLIS_CONFIG.spconv_algo = spconv_algo
            TRELLIS_CONFIG.smooth_k = smooth_k
            TRELLIS_CONFIG.device = main_device
            TRELLIS_CONFIG.use_fp16 = use_fp16
            TRELLIS_CONFIG.set("dinov2_model", dinov2_model)
            # Set up environment (also applies the attention backend)
            self._setup_environment()
if attn_backend == 'sage':
enable_sage_attention()
else:
disable_sage_attention()
# Get model path
model_path = folder_paths.get_full_path("checkpoints", model_name)
if model_path is None:
model_path = os.path.join(folder_paths.models_dir, "checkpoints", model_name)
# Create pipeline with specified dinov2 model
pipeline = TrellisImageTo3DPipeline.from_pretrained(model_path, dinov2_model=dinov2_model)
# Configure pipeline after loading
pipeline._device = torch.device(main_device)
pipeline.attention_backend = attn_backend
# Store configuration in pipeline
pipeline.config = {
'device': main_device,
'use_fp16': use_fp16,
'attention_backend': attn_backend,
'dinov2_model': dinov2_model,
'spconv_algo': spconv_algo,
'smooth_k': smooth_k
}
# Apply optimizations
pipeline = self.optimize_pipeline(pipeline, use_fp16, attn_backend)
return (pipeline,)
except Exception as e:
logger.error(f"Error loading TRELLIS model: {str(e)}")
raise
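
# Minimal smoke-test sketch (an assumption, not part of the node contract):
# it presumes a ComfyUI environment where the TRELLIS-image-large checkpoint
# already sits under models/checkpoints. 'sdpa' is chosen because it needs no
# optional attention packages.
if __name__ == "__main__":
    loader = IF_TrellisCheckpointLoader()
    (pipeline,) = loader.load_model(
        model_name="TRELLIS-image-large",
        dinov2_model="dinov2_vitl14_reg",
        attn_backend="sdpa",
        use_fp16=True,
        smooth_k=True,
        spconv_algo="implicit_gemm",
        main_device=loader.device,
    )
    print("Loaded TRELLIS pipeline on", pipeline.config["device"])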