from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm, LayerNorm
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import repeat
from tqdm import tqdm
from torch_cluster import fps

from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
from ..embeddings import FrequencyPositionalEmbedding
from ..transformers.triposg_transformer import DiTBlock
from .vae import DiagonalGaussianDistribution

import subprocess
import sys


def install_package(package_name):
    """Install `package_name` with pip at runtime; return True on success, False otherwise."""
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        return True
    except subprocess.CalledProcessError:
        return False
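
# Example (sketch, hypothetical usage): `install_package` can pull in an optional
# dependency at runtime on hosts where it is not preinstalled, e.g.:
#
#   if not install_package("torch-cluster"):
#       raise RuntimeError("torch-cluster is required but could not be installed")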

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class TripoSGEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        dim: int = 512,
        num_attention_heads: int = 8,
        num_layers: int = 8,
    ):
        super().__init__()

        self.proj_in = nn.Linear(in_channels, dim, bias=True)

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim=dim,
                    num_attention_heads=num_attention_heads,
                    use_self_attention=False,
                    use_cross_attention=True,
                    cross_attention_dim=dim,
                    cross_attention_norm_type="layer_norm",
                    activation_fn="gelu",
                    norm_type="fp32_layer_norm",
                    norm_eps=1e-5,
                    qk_norm=False,
                    qkv_bias=False,
                )  # cross attention
            ]
            + [
                DiTBlock(
                    dim=dim,
                    num_attention_heads=num_attention_heads,
                    use_self_attention=True,
                    self_attention_norm_type="fp32_layer_norm",
                    use_cross_attention=False,
                    use_cross_attention_2=False,
                    activation_fn="gelu",
                    norm_type="fp32_layer_norm",
                    norm_eps=1e-5,
                    qk_norm=False,
                    qkv_bias=False,
                )
                for _ in range(num_layers)  # self attention
            ]
        )

        self.norm_out = LayerNorm(dim)

    def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
        hidden_states = self.proj_in(sample_1)
        encoder_hidden_states = self.proj_in(sample_2)

        for layer, block in enumerate(self.blocks):
            if layer == 0:
                hidden_states = block(
                    hidden_states, encoder_hidden_states=encoder_hidden_states
                )
            else:
                hidden_states = block(hidden_states)

        hidden_states = self.norm_out(hidden_states)

        return hidden_states


class TripoSGDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 1,
        dim: int = 512,
        num_attention_heads: int = 8,
        num_layers: int = 16,
        grad_type: str = "analytical",
        grad_interval: float = 0.001,
    ):
        super().__init__()
        if grad_type not in ["numerical", "analytical"]:
            raise ValueError(
                f"grad_type must be one of ['numerical', 'analytical'], got {grad_type!r}."
            )
        self.grad_type = grad_type
        self.grad_interval = grad_interval

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim=dim,
                    num_attention_heads=num_attention_heads,
                    use_self_attention=True,
                    self_attention_norm_type="fp32_layer_norm",
                    use_cross_attention=False,
                    use_cross_attention_2=False,
                    activation_fn="gelu",
                    norm_type="fp32_layer_norm",
                    norm_eps=1e-5,
                    qk_norm=False,
                    qkv_bias=False,
                )
                for _ in range(num_layers)  # self attention
            ]
            + [
                DiTBlock(
                    dim=dim,
                    num_attention_heads=num_attention_heads,
                    use_self_attention=False,
                    use_cross_attention=True,
                    cross_attention_dim=dim,
                    cross_attention_norm_type="layer_norm",
                    activation_fn="gelu",
                    norm_type="fp32_layer_norm",
                    norm_eps=1e-5,
                    qk_norm=False,
                    qkv_bias=False,
                )  # cross attention
            ]
        )

        self.proj_query = nn.Linear(in_channels, dim, bias=True)
        self.norm_out = LayerNorm(dim)
        self.proj_out = nn.Linear(dim, out_channels, bias=True)

    def query_geometry(
        self,
        model_fn: callable,
        queries: torch.Tensor,
        sample: torch.Tensor,
        grad: bool = False,
    ):
        logits = model_fn(queries, sample)
        if grad:
            with torch.autocast(device_type="cuda", dtype=torch.float32):
                if self.grad_type == "numerical":
                    # Central differences: grad_i ≈ (f(x + h * e_i) - f(x - h * e_i)) / (2h),
                    # evaluated independently along the x, y and z axes.
                    interval = self.grad_interval
                    grad_value = []
                    for offset in [
                        (interval, 0, 0),
                        (0, interval, 0),
                        (0, 0, interval),
                    ]:
                        offset_tensor = torch.tensor(offset, device=queries.device)[
                            None, :
                        ]
                        res_p = model_fn(queries + offset_tensor, sample)[..., 0]
                        res_n = model_fn(queries - offset_tensor, sample)[..., 0]
                        grad_value.append((res_p - res_n) / (2 * interval))
                    grad_value = torch.stack(grad_value, dim=-1)
                else:
                    # Analytical gradient via autograd with respect to the query positions.
                    queries_d = torch.clone(queries)
                    queries_d.requires_grad = True
                    with torch.enable_grad():
                        res_d = model_fn(queries_d, sample)
                        grad_value = torch.autograd.grad(
                            res_d,
                            [queries_d],
                            grad_outputs=torch.ones_like(res_d),
                            create_graph=self.training,
                        )[0]
        else:
            grad_value = None

        return logits, grad_value

    def forward(
        self,
        sample: torch.Tensor,
        queries: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
    ):
        if kv_cache is None:
            hidden_states = sample
            for _, block in enumerate(self.blocks[:-1]):
                hidden_states = block(hidden_states)
            kv_cache = hidden_states

        # query grid logits by cross attention
        def query_fn(q, kv):
            q = self.proj_query(q)
            l = self.blocks[-1](q, encoder_hidden_states=kv)
            return self.proj_out(self.norm_out(l))

        logits, grad = self.query_geometry(
            query_fn, queries, kv_cache, grad=self.training
        )
        logits = logits * -1 if not isinstance(logits, tuple) else logits[0] * -1

        return logits, kv_cache
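
# Usage note (sketch): the decoder is meant to be driven in two phases. The first call
# builds `kv_cache` by pushing the latent tokens through the self-attention blocks;
# later calls pass that cache back in so only the final cross-attention block runs per
# chunk of query points, e.g.:
#
#   logits, kv_cache = decoder(latent_tokens, query_chunk_0)            # builds the cache
#   logits, kv_cache = decoder(latent_tokens, query_chunk_1, kv_cache)  # reuses it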


class TripoSGVAEModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,  # NOTE xyz instead of feature dim
        latent_channels: int = 64,
        num_attention_heads: int = 8,
        width_encoder: int = 512,
        width_decoder: int = 1024,
        num_layers_encoder: int = 8,
        num_layers_decoder: int = 16,
        embedding_type: str = "frequency",
        embed_frequency: int = 8,
        embed_include_pi: bool = False,
    ):
        super().__init__()
        self.out_channels = 1

        if embedding_type == "frequency":
            self.embedder = FrequencyPositionalEmbedding(
                num_freqs=embed_frequency,
                logspace=True,
                input_dim=in_channels,
                include_pi=embed_include_pi,
            )
        else:
            raise NotImplementedError(
                f"Embedding type {embedding_type} is not supported."
            )

        self.encoder = TripoSGEncoder(
            in_channels=in_channels + self.embedder.out_dim,
            dim=width_encoder,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers_encoder,
        )
        self.decoder = TripoSGDecoder(
            in_channels=self.embedder.out_dim,
            out_channels=self.out_channels,
            dim=width_decoder,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers_decoder,
        )

        self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
        self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)

        self.use_slicing = False
        self.slicing_length = 1

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError(
                    "`fuse_qkv_projections()` is not supported for models having added KV projections."
                )

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedTripoSGAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
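
    # Usage note (sketch): fusing is an inference-time optimization, e.g.:
    #
    #   vae.fuse_qkv_projections()    # fuse q/k/v projections before inference
    #   out = vae.decode(latents, query_points)
    #   vae.unfuse_qkv_projections()  # restore the original attention processors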

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(TripoSGAttnProcessor2_0())

    def enable_slicing(self, slicing_length: int = 1) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices of
        `slicing_length` samples and compute encoding/decoding in several steps. This is useful to save some memory
        and allow larger batch sizes.
        """
        self.use_slicing = True
        self.slicing_length = slicing_length
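
    # Usage note (sketch): with slicing enabled, batched inputs are processed
    # `slicing_length` samples at a time, trading throughput for peak memory, e.g.:
    #
    #   vae.enable_slicing(slicing_length=1)
    #   posterior = vae.encode(points).latent_dist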

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _sample_features(
        self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
    ):
        """
        Sample points from features of the input point cloud.

        Args:
            x (torch.Tensor): The input point cloud. shape: (B, N, C)
            num_tokens (int, optional): The number of points to sample. Defaults to 2048.
            seed (Optional[int], optional): The random seed. Defaults to None.
        """
        # Randomly pre-select 4x the requested number of points (with replacement only if
        # the cloud has fewer points than that).
        rng = np.random.default_rng(seed)
        indices = rng.choice(
            x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
        )
        selected_points = x[:, indices]

        batch_size, num_points, num_channels = selected_points.shape
        flattened_points = selected_points.view(batch_size * num_points, num_channels)
        batch_indices = (
            torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
        )

        # fps sampling: farthest point sampling on the xyz coordinates keeps 1/4 of the
        # over-sampled points, i.e. `num_tokens` spatially well-spread points per sample.
        sampling_ratio = 1.0 / 4
        sampled_indices = fps(
            flattened_points[:, :3],
            batch_indices,
            ratio=sampling_ratio,
            random_start=self.training,
        )
        sampled_points = flattened_points[sampled_indices].view(
            batch_size, -1, num_channels
        )

        return sampled_points

    def _encode(
        self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
    ):
        position_channels = self.config.in_channels

        # Full point cloud as keys/values: frequency-embedded xyz concatenated with raw features.
        positions, features = x[..., :position_channels], x[..., position_channels:]
        x_kv = torch.cat([self.embedder(positions), features], dim=-1)

        # FPS-downsampled subset as queries.
        sampled_x = self._sample_features(x, num_tokens, seed)
        positions, features = (
            sampled_x[..., :position_channels],
            sampled_x[..., position_channels:],
        )
        x_q = torch.cat([self.embedder(positions), features], dim=-1)

        x = self.encoder(x_q, x_kv)
        x = self.quant(x)

        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True, **kwargs
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of point features into latents.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [
                self._encode(x_slice, **kwargs)
                for x_slice in x.split(self.slicing_length)
            ]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x, **kwargs)

        posterior = DiagonalGaussianDistribution(h, feature_dim=-1)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(
        self,
        z: torch.Tensor,
        sampled_points: torch.Tensor,
        num_chunks: int = 50000,
        to_cpu: bool = False,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, torch.Tensor]:
        xyz_samples = sampled_points

        z = self.post_quant(z)

        num_points = xyz_samples.shape[1]
        kv_cache = None
        dec = []

        # Decode the query points in chunks of `num_chunks`; the kv cache built on the
        # first chunk is reused for all subsequent chunks.
        for i in range(0, num_points, num_chunks):
            queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
            queries = self.embedder(queries)

            z_, kv_cache = self.decoder(z, queries, kv_cache)
            dec.append(z_ if not to_cpu else z_.cpu())

        z = torch.cat(dec, dim=1)

        if not return_dict:
            return (z,)

        return DecoderOutput(sample=z)
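
    # Note (sketch): `num_chunks` bounds how many query points go through the final
    # cross-attention block per pass. With the default of 50000, a grid of 200000
    # queries is decoded in 4 passes; smaller values lower peak memory at the cost
    # of more passes.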

    @apply_forward_hook
    def decode(
        self,
        z: torch.Tensor,
        sampled_points: torch.Tensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[DecoderOutput, torch.Tensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [
                self._decode(z_slice, p_slice, **kwargs).sample
                for z_slice, p_slice in zip(
                    z.split(self.slicing_length),
                    sampled_points.split(self.slicing_length),
                )
            ]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z, sampled_points, **kwargs).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(self, x: torch.Tensor):
        pass
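

# Example (sketch): end-to-end encode/decode with a randomly initialized model. Shapes
# and the channel layout are assumptions read off this module: the encoder consumes xyz
# plus `in_channels` extra feature channels per point (e.g. normals), so the default
# configuration expects (B, N, 6) inputs.
#
#   vae = TripoSGVAEModel()
#   points = torch.randn(1, 8192, 6)             # xyz + per-point features
#   posterior = vae.encode(points).latent_dist   # DiagonalGaussianDistribution
#   latents = posterior.sample()                 # (1, num_tokens, latent_channels)
#   queries = torch.rand(1, 100_000, 3) * 2 - 1  # query positions in [-1, 1]
#   occupancy_logits = vae.decode(latents, queries).sample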