import copy |
|
import importlib |
|
import inspect |
|
import logging |
|
import os |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
import trimesh |
|
import yaml |
|
from PIL import Image |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from tqdm import tqdm |
|
|
|
from accelerate import init_empty_weights |
|
from accelerate.utils import set_module_tensor_to_device |
|
|
|
from comfy.utils import ProgressBar |
|
import comfy.model_management as mm |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def retrieve_timesteps( |
|
scheduler, |
|
num_inference_steps: Optional[int] = None, |
|
device: Optional[Union[str, torch.device]] = None, |
|
timesteps: Optional[List[int]] = None, |
|
sigmas: Optional[List[float]] = None, |
|
**kwargs, |
|
): |
|
""" |
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
|
|
|
Args: |
|
scheduler (`SchedulerMixin`): |
|
The scheduler to get timesteps from. |
|
num_inference_steps (`int`): |
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` |
|
must be `None`. |
|
device (`str` or `torch.device`, *optional*): |
|
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
timesteps (`List[int]`, *optional*): |
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, |
|
`num_inference_steps` and `sigmas` must be `None`. |
|
sigmas (`List[float]`, *optional*): |
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, |
|
`num_inference_steps` and `timesteps` must be `None`. |
|
|
|
Returns: |
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
|
second element is the number of inference steps. |
|
""" |
|
if timesteps is not None and sigmas is not None: |
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") |
|
if timesteps is not None: |
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
|
if not accepts_timesteps: |
|
raise ValueError( |
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
|
f" timestep schedules. Please check whether you are using the correct scheduler." |
|
) |
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
num_inference_steps = len(timesteps) |
|
elif sigmas is not None: |
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
|
if not accept_sigmas: |
|
raise ValueError( |
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
|
f" sigmas schedules. Please check whether you are using the correct scheduler." |
|
) |
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
num_inference_steps = len(timesteps) |
|
else: |
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
return timesteps, num_inference_steps |
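

# Minimal usage sketch for retrieve_timesteps (illustration only; it is not called by the
# pipeline). It assumes diffusers' FlowMatchEulerDiscreteScheduler is importable; the
# scheduler actually used at runtime is whatever the YAML config instantiates.
def _retrieve_timesteps_example():
    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    # Let the scheduler build its own 50-step schedule (the pipeline below instead
    # passes a custom `sigmas` array).
    timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
    assert num_steps == len(timesteps) == 50
    return timesteps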
|
|
|
|
|
def export_to_trimesh(mesh_output):
    # Convert mesh output(s) (objects exposing mesh_v / mesh_f arrays) into
    # trimesh.Trimesh instances. Reversing the face index order flips the winding
    # (and therefore the face normals).
    if isinstance(mesh_output, list):
        outputs = []
        for mesh in mesh_output:
            if mesh is None:
                outputs.append(None)
            else:
                mesh.mesh_f = mesh.mesh_f[:, ::-1]
                outputs.append(trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f))
        return outputs
    else:
        mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]
        return trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)
|
|
|
|
|
def get_obj_from_str(string, reload=False):
    # Resolve a dotted "module.ClassName" path to the class object. The wrapper's
    # package directory name is passed as `package` so that relative module paths in
    # the YAML config (those starting with a dot) resolve against this node pack.
    package_directory_name = os.path.basename(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    )

    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=package_directory_name), cls)
|
|
|
|
|
def instantiate_from_config(config, **kwargs):
    # Instantiate config["target"] with config["params"]; params from the config take
    # precedence over keyword arguments supplied by the caller.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    params = config.get("params", dict())
    kwargs.update(params)
    instance = cls(**kwargs)
    return instance
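

# Illustrative sketch of the config-driven instantiation above. The target shown here is
# a stock diffusers scheduler picked for the example; the real targets come from the YAML
# config shipped with the model, not from this snippet.
def _instantiate_from_config_example():
    config = {
        "target": "diffusers.FlowMatchEulerDiscreteScheduler",
        "params": {"num_train_timesteps": 1000},
    }
    return instantiate_from_config(config)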
|
|
|
|
|
class Hunyuan3DDiTPipeline: |
|
@classmethod |
|
def from_single_file( |
|
cls, |
|
ckpt_path, |
|
config_path, |
|
device='cuda', |
|
offload_device=torch.device('cpu'), |
|
dtype=torch.float16, |
|
use_safetensors=None, |
|
compile_args=None, |
|
attention_mode="sdpa", |
|
**kwargs, |
|
): |
|
|
|
with open(config_path, 'r') as f: |
|
config = yaml.safe_load(f) |
|
|
|
|
|
if use_safetensors: |
|
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors') |
|
if not os.path.exists(ckpt_path): |
|
raise FileNotFoundError(f"Model file {ckpt_path} not found") |
|
logger.info(f"Loading model from {ckpt_path}") |
|
|
|
if use_safetensors: |
|
|
|
import safetensors.torch |
|
safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu') |
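            # Single-file checkpoints store keys as "<submodel>.<param>"
            # (e.g. "model.", "vae.", "conditioner."); split them into
            # per-submodel state dicts keyed by the submodel name.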
|
ckpt = {} |
|
for key, value in safetensors_ckpt.items(): |
|
model_name = key.split('.')[0] |
|
new_key = key[len(model_name) + 1:] |
|
if model_name not in ckpt: |
|
ckpt[model_name] = {} |
|
ckpt[model_name][new_key] = value |
|
else: |
|
ckpt = torch.load(ckpt_path, map_location='cpu') |
|
|
|
|
|
|
|
if "guidance_in.in_layer.bias" in ckpt['model']: |
|
logger.info("Model has guidance_in, setting guidance_embed to True") |
|
config['model']['params']['guidance_embed'] = True |
|
config['model']['params']['attention_mode'] = attention_mode |
|
config['vae']['params']['attention_mode'] = attention_mode |
|
|
|
        # Build the modules with empty (meta) weights, then materialize the checkpoint
        # tensors directly on the offload device in the requested dtype.
        with init_empty_weights():
            model = instantiate_from_config(config['model'])
            vae = instantiate_from_config(config['vae'])
            conditioner = instantiate_from_config(config['conditioner'])
|
|
|
for name, param in model.named_parameters(): |
|
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name]) |
|
|
|
for name, param in vae.named_parameters(): |
|
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name]) |
|
|
|
if 'conditioner' in ckpt: |
|
|
|
for name, param in conditioner.named_parameters(): |
|
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name]) |
|
|
|
image_processor = instantiate_from_config(config['image_processor']) |
|
scheduler = instantiate_from_config(config['scheduler']) |
|
|
|
if compile_args is not None: |
|
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] |
|
if compile_args["compile_transformer"]: |
|
model = torch.compile(model) |
|
if compile_args["compile_vae"]: |
|
vae = torch.compile(vae) |
|
|
|
model_kwargs = dict( |
|
|
|
model=model, |
|
scheduler=scheduler, |
|
conditioner=conditioner, |
|
image_processor=image_processor, |
|
device=device, |
|
offload_device=offload_device, |
|
dtype=dtype, |
|
) |
|
model_kwargs.update(kwargs) |
|
|
|
return cls(**model_kwargs), vae |
|
def __init__( |
|
self, |
|
|
|
model, |
|
scheduler, |
|
conditioner, |
|
image_processor, |
|
device=torch.device('cuda'), |
|
offload_device=torch.device('cpu'), |
|
dtype=torch.float16, |
|
**kwargs |
|
): |
|
|
|
self.model = model |
|
self.scheduler = scheduler |
|
self.conditioner = conditioner |
|
self.image_processor = image_processor |
|
|
|
self.main_device = device |
|
self.offload_device = offload_device |
|
|
|
self.to(offload_device, dtype) |
|
|
|
def to(self, device=None, dtype=None): |
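        # Note: only the DiT model and the conditioner live on this pipeline; the VAE is
        # returned separately by from_single_file and is moved/cast by the caller.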
|
if device is not None: |
|
|
|
self.model.to(device) |
|
self.conditioner.to(device) |
|
if dtype is not None: |
|
self.dtype = dtype |
|
|
|
self.model.to(dtype=dtype) |
|
self.conditioner.to(dtype=dtype) |
|
|
|
def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance): |
|
self.conditioner.to(self.main_device) |
|
bsz = image.shape[0] |
|
cond = self.conditioner(image=image, mask=mask) |
|
|
|
        if do_classifier_free_guidance:
            # Conditional and unconditional embeddings are concatenated along the batch
            # axis below, so CFG needs only a single model forward per step.
            un_cond = self.conditioner.unconditional_embedding(bsz)
|
|
|
            if dual_guidance:
                # Second unconditional branch: drop the main image condition but keep the
                # 'additional' condition.
                un_cond_drop_main = copy.deepcopy(un_cond)
                un_cond_drop_main['additional'] = cond['additional']
|
|
|
def cat_recursive(a, b, c): |
|
|
|
if isinstance(a, torch.Tensor): |
|
return torch.cat([a, b, c], dim=0).to(self.dtype) |
|
out = {} |
|
for k in a.keys(): |
|
out[k] = cat_recursive(a[k], b[k], c[k]) |
|
return out |
|
|
|
cond = cat_recursive(cond, un_cond_drop_main, un_cond) |
|
else: |
|
un_cond = self.conditioner.unconditional_embedding(bsz) |
|
|
|
def cat_recursive(a, b): |
|
if isinstance(a, torch.Tensor): |
|
return torch.cat([a, b], dim=0).to(self.dtype) |
|
out = {} |
|
for k in a.keys(): |
|
out[k] = cat_recursive(a[k], b[k]) |
|
return out |
|
|
|
cond = cat_recursive(cond, un_cond) |
|
self.conditioner.to(self.offload_device) |
|
return cond |
|
|
|
    def prepare_extra_step_kwargs(self, generator, eta):
        # Prepare extra kwargs for the scheduler step, since not all schedulers accept the
        # same arguments. eta (η) is only used with the DDIMScheduler and corresponds to η
        # in the DDIM paper (https://arxiv.org/abs/2010.02502); it should be in [0, 1] and
        # is ignored by other schedulers.
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
|
|
|
|
|
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
|
if accepts_generator: |
|
extra_step_kwargs["generator"] = generator |
|
return extra_step_kwargs |
|
|
|
    def prepare_latents(self, batch_size, dtype, device, generator, latents=None):
        # Fixed latent layout expected by the Hunyuan3D shape VAE: 3072 latent tokens,
        # 64 channels each.
        num_latents = 3072
        embed_dim = 64
        shape = (batch_size, num_latents, embed_dim)
|
if isinstance(generator, list) and len(generator) != batch_size: |
|
raise ValueError( |
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
|
) |
|
|
|
if latents is None: |
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
|
else: |
|
latents = latents.to(device) |
|
|
|
|
|
latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0) |
|
return latents |
|
|
|
def prepare_image(self, image): |
|
if isinstance(image, str) and not os.path.exists(image): |
|
raise FileNotFoundError(f"Couldn't find image at path {image}") |
|
|
|
if not isinstance(image, list): |
|
image = [image] |
|
image_pts = [] |
|
mask_pts = [] |
|
for img in image: |
|
image_pt, mask_pt = self.image_processor(img, return_mask=True) |
|
image_pts.append(image_pt) |
|
mask_pts.append(mask_pt) |
|
|
|
image_pts = torch.cat(image_pts, dim=0).to(self.main_device, dtype=self.dtype) |
|
if mask_pts[0] is not None: |
|
mask_pts = torch.cat(mask_pts, dim=0).to(self.main_device, dtype=self.dtype) |
|
else: |
|
mask_pts = None |
|
return image_pts, mask_pts |
|
|
|
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): |
|
""" |
|
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 |
|
|
|
Args: |
|
            w (`torch.Tensor`):
                guidance scale values (one per batch element) to generate embedding vectors for
|
embedding_dim (`int`, *optional*, defaults to 512): |
|
dimension of the embeddings to generate |
|
dtype: |
|
data type of the generated embeddings |
|
|
|
Returns: |
|
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
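
        Example (illustrative; `pipe` stands for an already constructed pipeline):

            >>> w = torch.tensor([5.0, 7.5])
            >>> emb = pipe.get_guidance_scale_embedding(w, embedding_dim=512)
            >>> emb.shape
            torch.Size([2, 512])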
|
""" |
|
assert len(w.shape) == 1 |
|
w = w * 1000.0 |
|
|
|
half_dim = embedding_dim // 2 |
|
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) |
|
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) |
|
emb = w.to(dtype)[:, None] * emb[None, :] |
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) |
|
if embedding_dim % 2 == 1: |
|
emb = torch.nn.functional.pad(emb, (0, 1)) |
|
assert emb.shape == (w.shape[0], embedding_dim) |
|
return emb |
|
|
class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline): |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
image: torch.Tensor, |
|
mask: Optional[torch.Tensor] = None, |
|
num_inference_steps: int = 50, |
|
timesteps: List[int] = None, |
|
sigmas: List[float] = None, |
|
|
|
guidance_scale: float = 7.5, |
|
generator=None, |
|
|
|
|
|
|
|
|
|
|
|
|
|
enable_pbar=True, |
|
**kwargs, |
|
    ) -> torch.Tensor:
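        """
        Run the flow-matching sampler conditioned on the input image tensor and return the
        raw shape latents. Decoding the latents into a mesh is handled separately with the
        VAE returned by `from_single_file` (see `export_to_trimesh`).
        """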
|
callback = kwargs.pop("callback", None) |
|
callback_steps = kwargs.pop("callback_steps", None) |
|
|
|
device = self.main_device |
|
dtype = self.dtype |
|
        # Classifier-free guidance is only applied when the model does not embed the
        # guidance scale itself (guidance_embed checkpoints take the scale as an input).
        do_classifier_free_guidance = guidance_scale >= 0 and not (
            hasattr(self.model, 'guidance_embed') and
            self.model.guidance_embed is True
        )
|
|
|
|
|
|
|
cond = self.encode_cond( |
|
image=image, |
|
mask=mask, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
dual_guidance=False, |
|
) |
|
batch_size = image.shape[0] |
|
|
|
|
|
|
|
        # NOTE: the default sigma schedule runs from 0 to 1.
        sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
|
timesteps, num_inference_steps = retrieve_timesteps( |
|
self.scheduler, |
|
num_inference_steps, |
|
device, |
|
sigmas=sigmas, |
|
) |
|
latents = self.prepare_latents(batch_size, dtype, device, generator) |
|
|
|
        guidance = None
        if hasattr(self.model, 'guidance_embed') and \
                self.model.guidance_embed is True:
            # Guidance-embedded models receive the CFG scale as a direct model input.
            guidance = torch.tensor([guidance_scale] * batch_size, device=device, dtype=dtype)
            logger.debug(f"guidance: {guidance}")
|
|
|
comfy_pbar = ProgressBar(num_inference_steps) |
|
for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:")): |
|
|
|
if do_classifier_free_guidance: |
|
latent_model_input = torch.cat([latents] * 2) |
|
else: |
|
latent_model_input = latents |
|
|
|
|
|
            # The model is fed timesteps normalized to [0, 1].
            timestep = t.expand(latent_model_input.shape[0]).to(
                latents.dtype) / self.scheduler.config.num_train_timesteps
            noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)
|
|
|
            if do_classifier_free_guidance:
                # encode_cond concatenated [cond, uncond], so the first chunk is the
                # conditional prediction.
                noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
|
|
|
|
outputs = self.scheduler.step(noise_pred, t, latents) |
|
latents = outputs.prev_sample |
|
|
|
            if callback is not None and callback_steps is not None and i % callback_steps == 0:
|
step_idx = i // getattr(self.scheduler, "order", 1) |
|
callback(step_idx, t, outputs) |
|
comfy_pbar.update(1) |
|
print("latents shape: ", latents.shape) |
|
return latents |
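

# End-to-end usage sketch (illustration only). All file names below are placeholders; the
# checkpoint and config come from the Hunyuan3D-2 release used with this wrapper, and the
# input image path is hypothetical. Decoding the returned latents into a mesh is done with
# the VAE returned alongside the pipeline (see export_to_trimesh).
def _pipeline_usage_example():
    pipe, vae = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
        ckpt_path="model.fp16.safetensors",  # placeholder path
        config_path="config.yaml",           # placeholder path
        device="cuda",
        use_safetensors=True,
    )
    pipe.to(device="cuda")
    image, mask = pipe.prepare_image("input.png")  # placeholder image path
    latents = pipe(image=image, mask=mask, num_inference_steps=50, guidance_scale=5.0)
    return latents, vae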