|
""" |
|
Helpers for sampling from a single- or multi-stage point cloud diffusion model. |
|
""" |
|
|
|
from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from point_e.util.point_cloud import PointCloud |
|
|
|
from .gaussian_diffusion import GaussianDiffusion |
|
from .k_diffusion import karras_sample_progressive |
|
|
|
|
|
class PointCloudSampler:
    """
    A wrapper around a model or stack of models that produces conditional or
    unconditional sample tensors.

    Stages run in order. Every stage after the first receives the previous
    stage's output through the ``low_res`` model kwarg, and the points it
    samples are concatenated onto that output along the last dimension.

    If you want to modify the sampler arguments of an existing sampler, call
    with_options().
    """

    def __init__(
        self,
        device: torch.device,
        models: Sequence[nn.Module],
        diffusions: "Sequence[GaussianDiffusion]",
        num_points: Sequence[int],
        aux_channels: Sequence[str],
        model_kwargs_key_filter: Sequence[str] = ("*",),
        guidance_scale: Sequence[float] = (3.0, 3.0),
        clip_denoised: bool = True,
        use_karras: Sequence[bool] = (True, True),
        karras_steps: Sequence[int] = (64, 64),
        sigma_min: Sequence[float] = (1e-3, 1e-3),
        sigma_max: Sequence[float] = (120, 160),
        s_churn: Sequence[float] = (3, 0),
    ):
        """
        Create a sampler with one stage per entry of ``models``.

        :param device: device handed to the sampling loops.
        :param models: one model per stage; must be non-empty.
        :param diffusions: one diffusion per stage.
        :param num_points: number of points sampled by each stage.
        :param aux_channels: names of the channels that follow the three
                             spatial channels in every sample.
        :param model_kwargs_key_filter: per stage, either "*" (pass every
                                        model kwarg) or a comma-separated
                                        list of kwarg names to keep.
        :param guidance_scale: per-stage guidance scale. With several stages
                               and a single scale, only the first stage uses
                               it and the remaining stages get 1.0.
        :param clip_denoised: forwarded to the sampling loops.
        :param use_karras: per stage, whether to use the Karras sampler.
        :param karras_steps: per-stage step count for the Karras sampler.
        :param sigma_min: per-stage ``sigma_min`` for the Karras sampler.
        :param sigma_max: per-stage ``sigma_max`` for the Karras sampler.
        :param s_churn: per-stage ``s_churn`` for the Karras sampler.

        Every per-stage sequence must have one entry per model. For samplers
        with more than one stage, a length-1 sequence is repeated for every
        stage. Note that the defaults have two entries, so samplers with a
        different number of stages must pass these arguments explicitly.
        """
        n = len(models)
        assert n > 0, "at least one model is required"
        # zip() in sample_batch_progressive() would otherwise silently drop
        # stages when these are shorter than `models`.
        assert len(diffusions) == n, f"expected {n} diffusions, got {len(diffusions)}"
        assert len(num_points) == n, f"expected {n} num_points, got {len(num_points)}"

        def per_stage(values, name):
            # A lone value is shared by every stage of a multi-stage sampler.
            if n > 1 and len(values) == 1:
                values = list(values) * n
            assert len(values) == n, f"expected {n} values for {name}, got {len(values)}"
            return values

        if n > 1 and len(guidance_scale) == 1:
            # Don't guide the later stages by default.
            guidance_scale = list(guidance_scale) + [1.0] * (n - 1)
        if len(model_kwargs_key_filter) == 0:
            # No filter given: every stage sees all model kwargs.
            model_kwargs_key_filter = ["*"] * n

        self.device = device
        self.num_points = num_points
        self.aux_channels = aux_channels
        self.model_kwargs_key_filter = per_stage(
            model_kwargs_key_filter, "model_kwargs_key_filter"
        )
        self.guidance_scale = per_stage(guidance_scale, "guidance_scale")
        self.clip_denoised = clip_denoised
        self.use_karras = per_stage(use_karras, "use_karras")
        self.karras_steps = per_stage(karras_steps, "karras_steps")
        self.sigma_min = per_stage(sigma_min, "sigma_min")
        self.sigma_max = per_stage(sigma_max, "sigma_max")
        self.s_churn = per_stage(s_churn, "s_churn")

        self.models = models
        self.diffusions = diffusions

    @property
    def num_stages(self) -> int:
        """The number of stages, i.e. the number of models."""
        return len(self.models)

    def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]) -> torch.Tensor:
        """
        Run every stage to completion and return the final sample tensor.

        :param batch_size: number of samples to produce.
        :param model_kwargs: conditioning kwargs, filtered per stage.
        :return: the last tensor yielded by sample_batch_progressive(), or
                 None if it yielded nothing.
        """
        samples = None
        for x in self.sample_batch_progressive(batch_size, model_kwargs):
            samples = x
        return samples

    def sample_batch_progressive(
        self, batch_size: int, model_kwargs: Dict[str, Any]
    ) -> Iterator[torch.Tensor]:
        """
        Run the stages in order, yielding the intermediate prediction after
        every step of every stage.

        :param batch_size: number of samples to produce.
        :param model_kwargs: conditioning kwargs; the dict itself is not
                             modified.
        :return: an iterator over tensors whose first dimension is
                 batch_size and whose second is 3 + len(aux_channels). From
                 the second stage on, each tensor is the previous stage's
                 output followed by the current prediction along the last
                 dimension.
        """
        samples = None
        for (
            model,
            diffusion,
            stage_num_points,
            stage_guidance_scale,
            stage_use_karras,
            stage_karras_steps,
            stage_sigma_min,
            stage_sigma_max,
            stage_s_churn,
            stage_key_filter,
        ) in zip(
            self.models,
            self.diffusions,
            self.num_points,
            self.guidance_scale,
            self.use_karras,
            self.karras_steps,
            self.sigma_min,
            self.sigma_max,
            self.s_churn,
            self.model_kwargs_key_filter,
        ):
            stage_model_kwargs = model_kwargs.copy()
            if stage_key_filter != "*":
                use_keys = set(stage_key_filter.split(","))
                stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items() if k in use_keys}
            if samples is not None:
                # Condition this stage on the output of the previous one.
                stage_model_kwargs["low_res"] = samples
            if hasattr(model, "cached_model_kwargs"):
                stage_model_kwargs = model.cached_model_kwargs(batch_size, stage_model_kwargs)
            sample_shape = (batch_size, 3 + len(self.aux_channels), stage_num_points)

            # A scale of 0 means "no guidance" here, and a scale of 1 makes
            # the guided output equal the conditional one, so only other
            # scales need the doubled (conditional + zeroed) batch. The same
            # test is used for the kwargs and for the model wrapper below so
            # that their batch sizes always agree.
            use_guidance = stage_guidance_scale != 1 and stage_guidance_scale != 0
            if use_guidance:
                # Build a new dict instead of mutating the one returned by
                # cached_model_kwargs().
                stage_model_kwargs = {
                    k: torch.cat([v, torch.zeros_like(v)], dim=0)
                    for k, v in stage_model_kwargs.items()
                }

            if stage_use_karras:
                samples_it = karras_sample_progressive(
                    diffusion=diffusion,
                    model=model,
                    shape=sample_shape,
                    steps=stage_karras_steps,
                    clip_denoised=self.clip_denoised,
                    model_kwargs=stage_model_kwargs,
                    device=self.device,
                    sigma_min=stage_sigma_min,
                    sigma_max=stage_sigma_max,
                    s_churn=stage_s_churn,
                    guidance_scale=stage_guidance_scale,
                )
            else:
                internal_batch_size = batch_size
                if use_guidance:
                    model = self._uncond_guide_model(model, stage_guidance_scale)
                    internal_batch_size *= 2
                samples_it = diffusion.p_sample_loop_progressive(
                    model,
                    shape=(internal_batch_size, *sample_shape[1:]),
                    model_kwargs=stage_model_kwargs,
                    device=self.device,
                    clip_denoised=self.clip_denoised,
                )
            for x in samples_it:
                # Drop the extra half of the batch used for guidance, if any.
                samples = x["pred_xstart"][:batch_size]
                if "low_res" in stage_model_kwargs:
                    samples = torch.cat(
                        [stage_model_kwargs["low_res"][: len(samples)], samples], dim=-1
                    )
                yield samples

    @classmethod
    def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler":
        """
        Chain the stages of several samplers into a single sampler.

        All samplers must share the same device, aux_channels and
        clip_denoised; the per-stage settings are concatenated in order.

        :param samplers: one or more samplers to chain.
        :return: a new sampler containing every stage of every input.
        """
        assert samplers, "at least one sampler is required"
        assert all(x.device == samplers[0].device for x in samplers[1:])
        assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:])
        assert all(x.clip_denoised == samplers[0].clip_denoised for x in samplers[1:])
        return cls(
            device=samplers[0].device,
            models=[x for y in samplers for x in y.models],
            diffusions=[x for y in samplers for x in y.diffusions],
            num_points=[x for y in samplers for x in y.num_points],
            aux_channels=samplers[0].aux_channels,
            model_kwargs_key_filter=[x for y in samplers for x in y.model_kwargs_key_filter],
            guidance_scale=[x for y in samplers for x in y.guidance_scale],
            clip_denoised=samplers[0].clip_denoised,
            use_karras=[x for y in samplers for x in y.use_karras],
            karras_steps=[x for y in samplers for x in y.karras_steps],
            sigma_min=[x for y in samplers for x in y.sigma_min],
            sigma_max=[x for y in samplers for x in y.sigma_max],
            s_churn=[x for y in samplers for x in y.s_churn],
        )

    def _uncond_guide_model(
        self, model: Callable[..., torch.Tensor], scale: float
    ) -> Callable[..., torch.Tensor]:
        """
        Wrap a model so that its first three output channels are guided.

        The wrapper expects a batch whose kwargs hold the conditional inputs
        in the first half and the unconditional inputs in the second half.

        :param model: the model to wrap.
        :param scale: the guidance scale.
        :return: a callable with the signature (x_t, ts, **kwargs).
        """

        def model_fn(x_t, ts, **kwargs):
            # Both halves of the batch must see the same noisy input.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            # Duplicate the guided result so the batch size is preserved.
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        return model_fn

    def split_model_output(
        self,
        output: torch.Tensor,
        rescale_colors: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Split a sample tensor into positions and named auxiliary channels.

        :param output: a tensor whose second dimension holds three spatial
                       channels followed by one channel per aux_channels
                       entry.
        :param rescale_colors: if True, divide the "R", "G", "B" and "A"
                               channels by 255 after clamping and rounding.
        :return: a tuple (pos, aux) where pos is output[:, :3] and aux maps
                 each aux channel name to its slice of output.
        """
        assert (
            len(self.aux_channels) + 3 == output.shape[1]
        ), "there must be three spatial channels before aux"
        pos, joined_aux = output[:, :3], output[:, 3:]

        color_names = {"R", "G", "B", "A"}
        aux = {}
        for i, name in enumerate(self.aux_channels):
            v = joined_aux[:, i]
            if name in color_names:
                # Colors are snapped to whole values in [0, 255].
                v = v.clamp(0, 255).round()
                if rescale_colors:
                    v = v / 255.0
            aux[name] = v
        return pos, aux

    def output_to_point_clouds(self, output: torch.Tensor) -> "List[PointCloud]":
        """
        Convert a batch of samples into one PointCloud per batch element.

        :param output: a batch of samples in the layout accepted by
                       split_model_output().
        :return: a list of point clouds whose coords are the transposed
                 positions and whose channels are the rescaled aux channels,
                 both as NumPy arrays.
        """
        res = []
        for sample in output:
            # split_model_output() expects a batch dimension.
            xyz, aux = self.split_model_output(sample[None], rescale_colors=True)
            res.append(
                PointCloud(
                    coords=xyz[0].t().cpu().numpy(),
                    channels={k: v[0].cpu().numpy() for k, v in aux.items()},
                )
            )
        return res

    def with_options(
        self,
        guidance_scale: float,
        clip_denoised: bool,
        use_karras: Sequence[bool] = (True, True),
        karras_steps: Sequence[int] = (64, 64),
        sigma_min: Sequence[float] = (1e-3, 1e-3),
        sigma_max: Sequence[float] = (120, 160),
        s_churn: Sequence[float] = (3, 0),
    ) -> "PointCloudSampler":
        """
        Create a sampler with the same models but different sampling options.

        :param guidance_scale: a single scale, or a sequence with one scale
                               per stage. A single scale is applied to the
                               first stage; for multi-stage samplers the
                               constructor gives the other stages 1.0.
        :param clip_denoised: forwarded to the new sampler.
        :param use_karras: see the constructor.
        :param karras_steps: see the constructor.
        :param sigma_min: see the constructor.
        :param sigma_max: see the constructor.
        :param s_churn: see the constructor.
        :return: a new PointCloudSampler; this sampler is left untouched.

        The sequence defaults have two entries, so samplers that do not have
        exactly two stages must pass those arguments explicitly.
        """
        if isinstance(guidance_scale, (int, float)):
            # The constructor calls len() on this argument, so a bare number
            # has to be wrapped before it is passed on.
            guidance_scale = [guidance_scale]
        return PointCloudSampler(
            device=self.device,
            models=self.models,
            diffusions=self.diffusions,
            num_points=self.num_points,
            aux_channels=self.aux_channels,
            model_kwargs_key_filter=self.model_kwargs_key_filter,
            guidance_scale=guidance_scale,
            clip_denoised=clip_denoised,
            use_karras=use_karras,
            karras_steps=karras_steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            s_churn=s_churn,
        )
|
|