Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import torch.nn as nn | |
import yaml | |
from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder | |
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors | |
from .volume_decoders import VanillaVolumeDecoder | |
from ...utils import logger, synchronize_timer | |
class VectsetVAE(nn.Module):
    """Base VAE that converts latent vector sets into mesh outputs.

    Decoding is split into two pluggable stages: ``volume_decoder`` produces
    grid logits from the latents and ``self.geo_decoder``, and
    ``surface_extractor`` turns those grid logits into the final outputs.
    """

    @classmethod
    def from_single_file(
        cls,
        ckpt_path,
        config_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        """Build a model from a YAML config file and a single weights file.

        Args:
            ckpt_path: Path to the checkpoint. When ``use_safetensors`` is
                truthy, a trailing ``.ckpt`` extension is swapped for
                ``.safetensors``.
            config_path: YAML file whose ``params`` mapping is passed as
                keyword arguments to the constructor.
            device: Device the model is moved to after loading.
            dtype: Dtype the model is cast to after loading.
            use_safetensors: If truthy, load with
                ``safetensors.torch.load_file``; otherwise use ``torch.load``.
            **kwargs: Constructor keyword arguments that override entries in
                the config's ``params``.

        Returns:
            An instance of ``cls`` with the checkpoint weights loaded.

        Raises:
            FileNotFoundError: If the config file or the checkpoint is missing.
        """
        # load config
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # load ckpt
        if use_safetensors:
            # Swap only the file extension; str.replace would also rewrite a
            # '.ckpt' that happens to appear in a directory name.
            root, ext = os.path.splitext(ckpt_path)
            if ext == '.ckpt':
                ckpt_path = root + '.safetensors'
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Model file {ckpt_path} not found")

        logger.info(f"Loading model from {ckpt_path}")
        if use_safetensors:
            import safetensors.torch
            ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
        else:
            # weights_only=True restricts unpickling to tensors/primitive types.
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

        # Copy so caller overrides never mutate the parsed config mapping.
        model_kwargs = dict(config['params'])
        model_kwargs.update(kwargs)
        model = cls(**model_kwargs)
        model.load_state_dict(ckpt)
        model.to(device=device, dtype=dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=True,
        variant='fp16',
        subfolder='hunyuan3d-vae-v2-0',
        **kwargs,
    ):
        """Load a model from the local model cache or the Hugging Face Hub.

        The directory ``$HY3DGEN_MODELS/<model_path>/<subfolder>`` is tried
        first (``HY3DGEN_MODELS`` defaults to ``~/.cache/hy3dgen``). If it does
        not exist, ``model_path`` is treated as a Hub repo id and fetched with
        ``huggingface_hub.snapshot_download``.

        The resolved directory must contain ``config.yaml`` and a checkpoint
        named ``model[.<variant>].safetensors`` (``.ckpt`` when
        ``use_safetensors`` is false). A falsy ``variant`` means no variant
        suffix.

        Raises:
            RuntimeError: If the local directory is missing and
                ``huggingface_hub`` is not installed.
            FileNotFoundError: If the resolved model directory does not exist.
        """
        original_model_path = model_path
        # try local path
        base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
        model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
        logger.info(f'Try to load model from local path: {model_path}')
        if not os.path.exists(model_path):
            logger.info('Model path not exists, try to download from huggingface')
            try:
                import huggingface_hub
            except ImportError as err:
                logger.warning(
                    "You need to install HuggingFace Hub to load models from the hub."
                )
                raise RuntimeError(f"Model path {model_path} not found") from err
            # download from huggingface
            path = huggingface_hub.snapshot_download(repo_id=original_model_path)
            model_path = os.path.join(path, subfolder)

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model path {original_model_path} not found")

        extension = 'safetensors' if use_safetensors else 'ckpt'
        # An empty variant would otherwise produce 'model..<ext>'.
        variant = f'.{variant}' if variant else ''
        ckpt_name = f'model{variant}.{extension}'
        config_path = os.path.join(model_path, 'config.yaml')
        ckpt_path = os.path.join(model_path, ckpt_name)
        return cls.from_single_file(
            ckpt_path,
            config_path,
            device=device,
            dtype=dtype,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def __init__(
        self,
        volume_decoder=None,
        surface_extractor=None
    ):
        """Store the two decoding stages.

        Args:
            volume_decoder: Called as ``volume_decoder(latents, geo_decoder,
                **kwargs)`` to produce grid logits. Defaults to
                ``VanillaVolumeDecoder()``.
            surface_extractor: Called as ``surface_extractor(grid_logits,
                **kwargs)``. Defaults to ``MCSurfaceExtractor()``.
        """
        super().__init__()
        if volume_decoder is None:
            volume_decoder = VanillaVolumeDecoder()
        if surface_extractor is None:
            surface_extractor = MCSurfaceExtractor()
        self.volume_decoder = volume_decoder
        self.surface_extractor = surface_extractor

    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
        """Decode latents to grid logits, then run surface extraction.

        ``self.geo_decoder`` is not created by this class; subclasses such as
        ``ShapeVAE`` must define it. ``kwargs`` are forwarded to both the
        volume decoder and the surface extractor.

        Returns:
            Whatever ``self.surface_extractor`` returns.
        """
        with synchronize_timer('Volume decoding'):
            grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
        with synchronize_timer('Surface extraction'):
            outputs = self.surface_extractor(grid_logits, **kwargs)
        return outputs
class ShapeVAE(VectsetVAE):
    """Vector-set shape VAE decoder.

    ``forward`` lifts latents from ``embed_dim`` to ``width`` with a linear
    layer and refines them with a transformer. ``geo_decoder`` is a
    cross-attention decoder with a single output channel that shares this
    model's Fourier embedder.
    """

    def __init__(
        self,
        *,
        num_latents: int,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        # Submodules are created in the same order as before so parameter
        # registration (and therefore state_dict key order) is unchanged.
        self.fourier_embedder = FourierEmbedder(
            num_freqs=num_freqs,
            include_pi=include_pi,
        )
        self.post_kl = nn.Linear(embed_dim, width)
        self.transformer = Transformer(
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )

        # The geometry decoder runs at width and head count divided by the
        # downsample ratio.
        ratio = geo_decoder_downsample_ratio
        decoder_width = width // ratio
        decoder_heads = heads // ratio
        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=decoder_width,
            heads=decoder_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.scale_factor = scale_factor
        self.latent_shape = (num_latents, embed_dim)

    def forward(self, latents):
        """Project latents to the transformer width, then run the transformer."""
        return self.transformer(self.post_kl(latents))