Huiwenshi's picture
init
04b20ec
raw
history blame
5.9 kB
import os
import torch
import torch.nn as nn
import yaml
from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
from .volume_decoders import VanillaVolumeDecoder
from ...utils import logger, synchronize_timer
class VectsetVAE(nn.Module):
    """Base VAE wrapper that decodes latents into a mesh and loads pretrained weights.

    This base class owns a volume decoder and a surface extractor. The
    ``geo_decoder`` consumed by :meth:`latents2mesh` is NOT created here; a
    subclass must assign it (``ShapeVAE`` in this module does so in its
    ``__init__``).
    """

    @classmethod
    @synchronize_timer('VectsetVAE Model Loading')
    def from_single_file(
        cls,
        ckpt_path,
        config_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        """Build a model from a YAML config file and a single checkpoint file.

        Args:
            ckpt_path: Path to the checkpoint file. When ``use_safetensors`` is
                truthy and the path's extension is ``.ckpt``, that extension is
                swapped for ``.safetensors`` before loading.
            config_path: Path to a YAML file whose ``params`` mapping is passed
                as keyword arguments to ``cls``.
            device: Device the model is moved to after the weights are loaded.
            dtype: Dtype the model is cast to after the weights are loaded.
            use_safetensors: When truthy, load with
                ``safetensors.torch.load_file``; otherwise use
                ``torch.load(..., weights_only=True)``.
            **kwargs: Extra constructor arguments; they override entries of the
                config's ``params`` mapping.

        Returns:
            The constructed model with weights loaded, on ``device`` in ``dtype``.

        Raises:
            FileNotFoundError: If the (possibly renamed) checkpoint file does
                not exist.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        if use_safetensors:
            # Swap only the file extension. A plain str.replace would also
            # rewrite any '.ckpt' substring elsewhere in the path (e.g. in a
            # directory name).
            root, ext = os.path.splitext(ckpt_path)
            if ext == '.ckpt':
                ckpt_path = root + '.safetensors'
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Model file {ckpt_path} not found")

        logger.info(f"Loading model from {ckpt_path}")
        # Weights are materialized on the CPU first; the model is moved to
        # `device` only after load_state_dict succeeds.
        if use_safetensors:
            import safetensors.torch
            ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
        else:
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

        model_kwargs = config['params']
        model_kwargs.update(kwargs)
        model = cls(**model_kwargs)
        model.load_state_dict(ckpt)
        model.to(device=device, dtype=dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=True,
        variant='fp16',
        subfolder='hunyuan3d-vae-v2-0',
        **kwargs,
    ):
        """Resolve a pretrained model directory and load it via ``from_single_file``.

        The local directory ``<base>/<model_path>/<subfolder>`` is tried first.
        ``<base>`` is ``$HY3DGEN_MODELS`` or ``~/.cache/hy3dgen``, with ``~``
        expanded. If that directory does not exist, ``model_path`` is used as a
        Hugging Face repo id and the repository is fetched with
        ``huggingface_hub.snapshot_download``.

        Inside the resolved directory, two files are passed to
        :meth:`from_single_file`: ``config.yaml`` and
        ``model[.<variant>].<safetensors|ckpt>``.

        Args:
            model_path: Local model name (relative to the base directory) or
                Hugging Face repo id.
            device: Device the model is moved to.
            dtype: Dtype the model is cast to.
            use_safetensors: Choose the ``.safetensors`` (True) or ``.ckpt``
                (False) checkpoint file.
            variant: Optional infix in the checkpoint file name; ``None`` means
                no infix.
            subfolder: Subdirectory holding ``config.yaml`` and the checkpoint.
            **kwargs: Forwarded to :meth:`from_single_file` (constructor
                overrides).

        Raises:
            RuntimeError: If the local path is missing and ``huggingface_hub``
                is not installed.
            FileNotFoundError: If the resolved model directory does not exist.
        """
        original_model_path = model_path
        # Try the local cache first.
        base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
        model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
        logger.info(f'Try to load model from local path: {model_path}')
        if not os.path.exists(model_path):
            logger.info('Model path not exists, try to download from huggingface')
            # Guard only the optional import, so that ImportErrors raised
            # during the download itself are not misreported.
            try:
                import huggingface_hub
            except ImportError as err:
                logger.warning(
                    "You need to install HuggingFace Hub to load models from the hub."
                )
                raise RuntimeError(f"Model path {model_path} not found") from err
            # Download errors propagate to the caller unchanged.
            path = huggingface_hub.snapshot_download(repo_id=original_model_path)
            model_path = os.path.join(path, subfolder)

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model path {original_model_path} not found")

        extension = 'safetensors' if use_safetensors else 'ckpt'
        variant_suffix = '' if variant is None else f'.{variant}'
        ckpt_name = f'model{variant_suffix}.{extension}'
        config_path = os.path.join(model_path, 'config.yaml')
        ckpt_path = os.path.join(model_path, ckpt_name)
        return cls.from_single_file(
            ckpt_path,
            config_path,
            device=device,
            dtype=dtype,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def __init__(
        self,
        volume_decoder=None,
        surface_extractor=None
    ):
        """Store the volume decoder and surface extractor, creating defaults if omitted.

        Args:
            volume_decoder: Module that turns latents into grid logits;
                defaults to ``VanillaVolumeDecoder()``.
            surface_extractor: Module that turns grid logits into a surface;
                defaults to ``MCSurfaceExtractor()``.
        """
        super().__init__()
        if volume_decoder is None:
            volume_decoder = VanillaVolumeDecoder()
        if surface_extractor is None:
            surface_extractor = MCSurfaceExtractor()
        self.volume_decoder = volume_decoder
        self.surface_extractor = surface_extractor

    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
        """Decode latents into grid logits, then extract a surface from them.

        Requires ``self.geo_decoder``, which this base class does not define;
        a subclass must set it. The same ``kwargs`` are forwarded to both the
        volume decoder and the surface extractor.

        Returns:
            Whatever ``self.surface_extractor`` returns.
        """
        with synchronize_timer('Volume decoding'):
            grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
        with synchronize_timer('Surface extraction'):
            outputs = self.surface_extractor(grid_logits, **kwargs)
        return outputs
class ShapeVAE(VectsetVAE):
    """Latent-to-geometry decoder stack built on top of ``VectsetVAE``.

    Latents of size ``embed_dim`` are projected to ``width`` by ``post_kl`` (a
    linear layer) and refined by a ``Transformer``. A ``CrossAttentionDecoder``
    is stored as ``geo_decoder``; ``VectsetVAE.latents2mesh`` hands it to the
    volume decoder.
    """

    def __init__(
        self,
        *,
        num_latents: int,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        """Build the projection, transformer and geometry decoder.

        The geometry decoder runs at ``width // geo_decoder_downsample_ratio``
        channels with ``heads // geo_decoder_downsample_ratio`` heads. It
        shares the same ``FourierEmbedder`` instance stored on this module.
        """
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        # Attention options common to the transformer and the geometry decoder.
        attn_opts = dict(qkv_bias=qkv_bias, qk_norm=qk_norm)
        ratio = geo_decoder_downsample_ratio
        decoder_width = width // ratio
        decoder_heads = heads // ratio

        embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
        self.fourier_embedder = embedder
        self.post_kl = nn.Linear(embed_dim, width)
        self.transformer = Transformer(
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            drop_path_rate=drop_path_rate,
            **attn_opts,
        )
        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=embedder,
            out_channels=1,
            num_latents=num_latents,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=ratio,
            enable_ln_post=geo_decoder_ln_post,
            width=decoder_width,
            heads=decoder_heads,
            label_type=label_type,
            **attn_opts,
        )

        # Stored for callers; not read anywhere inside this class.
        self.scale_factor = scale_factor
        self.latent_shape = (num_latents, embed_dim)

    def forward(self, latents):
        """Project latents to the transformer width, then run the transformer."""
        projected = self.post_kl(latents)
        return self.transformer(projected)