Spaces:
Running
on
Zero
Running
on
Zero
import math | |
from collections import defaultdict | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import craftsman | |
from craftsman.utils.typing import * | |
def dot(x, y):
    """Return the inner product of ``x`` and ``y`` over the last dimension.

    The reduced dimension is kept with size 1, so the result broadcasts back
    against ``x`` and ``y``.
    """
    elementwise = x * y
    return torch.sum(elementwise, dim=-1, keepdim=True)
def reflect(x, n):
    """Return ``2 * (x . n) * n - x`` with the dot product taken over the last dimension.

    This is the reflection of ``x`` about the direction ``n`` when ``n`` has unit
    length; ``n`` is not normalised here.
    """
    projection = torch.sum(x * n, -1, keepdim=True)
    return 2 * projection * n - x
# Accepted form of a value range for scale_tensor: either a tensor of shape
# (2, D) or a plain pair of floats.
ValidScale = Union[Num[Tensor, "2 D"], Tuple[float, float]]
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly remap ``dat`` from the range ``inp_scale`` to the range ``tgt_scale``.

    A value equal to ``inp_scale[0]`` maps to ``tgt_scale[0]``, and a value equal to
    ``inp_scale[1]`` maps to ``tgt_scale[1]``. Passing ``None`` for either scale
    means the range ``(0, 1)``.

    When ``tgt_scale`` is a tensor, its last dimension must equal the last dimension
    of ``dat``. This is checked with ``assert``.
    """
    source = (0, 1) if inp_scale is None else inp_scale
    target = (0, 1) if tgt_scale is None else tgt_scale

    if isinstance(target, Tensor):
        assert dat.shape[-1] == target.shape[-1]

    src_origin, src_end = source[0], source[1]
    dst_origin, dst_end = target[0], target[1]

    unit = (dat - src_origin) / (src_end - src_origin)
    return unit * (dst_end - dst_origin) + dst_origin
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Call ``func`` on dim-0 slices of its tensor arguments and concatenate the results.

    Every ``torch.Tensor`` in ``args``/``kwargs`` is sliced into consecutive pieces of
    at most ``chunk_size`` rows along dim 0. Non-tensor arguments are passed unchanged
    to every call. The batch size is read from the first tensor argument, with
    positional arguments checked before keyword arguments.

    Args:
        func: Callable returning one of:
            - a tensor;
            - a tuple/list whose entries are tensors or ``None``;
            - a dict whose values are tensors or ``None``;
            - ``None`` (that chunk is skipped).
        chunk_size: Maximum rows per call. A value ``<= 0`` disables chunking and
            returns ``func(*args, **kwargs)`` directly.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        A result with the same container kind as ``func``'s return. Each tensor
        entry is the ``torch.cat`` (dim 0) of that entry across chunks. An entry that
        was ``None`` in every chunk stays ``None``. If every call returned ``None``,
        the result is ``None``. When gradient tracking is disabled, returned tensors
        are detached.

    Raises:
        AssertionError: No tensor argument was found, so the batch size is unknown.
        TypeError: ``func`` returned an unsupported type, or an entry mixes tensors
            with non-tensor values across chunks.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)

    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."

    def _take(arg, start):
        # Slice tensors along dim 0; every other argument is shared by all chunks.
        return arg[start : start + chunk_size] if isinstance(arg, torch.Tensor) else arg

    out = defaultdict(list)
    out_type = None
    chunk_length = 0
    # max(1, B) so func is still called once (with empty slices) when B == 0
    for start in range(0, max(1, B), chunk_size):
        out_chunk = func(
            *[_take(arg, start) for arg in args],
            **{k: _take(arg, start) for k, arg in kwargs.items()},
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            chunk_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif not isinstance(out_chunk, dict):
            # Raise rather than print + exit(1): a library helper must not kill the process.
            raise TypeError(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
        detach = not torch.is_grad_enabled()
        for k, v in out_chunk.items():
            # None entries are allowed (merged back to None below), so only tensors
            # are detached; calling .detach() on None would raise AttributeError.
            if detach and isinstance(v, torch.Tensor):
                v = v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all(vv is None for vv in v):
            # allow None in return value
            out_merged[k] = None
        elif all(isinstance(vv, torch.Tensor) for vv in v):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    # issubclass (not identity) so subclasses such as nn.Parameter or OrderedDict
    # do not fall through every branch and silently return None.
    if issubclass(out_type, torch.Tensor):
        return out_merged[0]
    if issubclass(out_type, (tuple, list)):
        values = [out_merged[i] for i in range(chunk_length)]
        if out_type in (tuple, list):
            return out_type(values)
        # Subclasses (e.g. namedtuples) may not accept a single iterable; use the builtin base.
        return list(values) if issubclass(out_type, list) else tuple(values)
    return out_merged
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Draw a standard-normal tensor of ``shape`` on ``device`` with the given ``dtype``/``layout``.

    Args:
        shape: Output shape (tuple or list). When ``generator`` is a list,
            ``shape[0]`` is the number of per-sample draws.
        generator: A single ``torch.Generator``, or a list with one generator per
            batch element so each sample can be seeded individually. A list must
            hold at least ``shape[0]`` generators. A one-element list is treated
            like a single generator.
        device: Target device, as a ``torch.device`` or a device string. Defaults
            to CPU.
        dtype: Forwarded to ``torch.randn``.
        layout: Forwarded to ``torch.randn``. Defaults to ``torch.strided``.

    Returns:
        The sampled tensor, moved to ``device``.

    Raises:
        ValueError: The generator lives on CUDA but ``device`` is not a CUDA device.

    If the generator lives on the CPU while ``device`` does not, sampling happens on
    the CPU and the result is moved to ``device``. Except for MPS targets, this is
    logged at INFO level.
    """
    import logging

    # Sample on the requested device unless a CPU generator forces CPU sampling below.
    rand_device = device
    batch_size = shape[0]
    layout = layout or torch.strided
    # Normalise to torch.device so `.type` works when callers pass a string like "cuda".
    device = torch.device(device) if device else torch.device("cpu")

    if generator is not None:
        first_generator = generator[0] if isinstance(generator, list) else generator
        gen_device_type = first_generator.device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a torch.device never equals the string "mps".
            if device.type != "mps":
                logging.getLogger(__name__).info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One draw per generator; tuple() so list-typed shapes are accepted too.
        per_sample_shape = (1,) + tuple(shape[1:])
        samples = [
            torch.randn(per_sample_shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(samples, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
    return latents
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_depth: int,
    indexing: str = "ij"
):
    """Build a regular 3-D lattice of points spanning the box ``[bbox_min, bbox_max]``.

    Each axis is split into ``int(2 ** octree_depth)`` cells, giving that number
    plus one evenly spaced float32 samples, with both box faces included.

    Args:
        bbox_min: Per-axis minimum corner; entries 0..2 are used.
        bbox_max: Per-axis maximum corner; entries 0..2 are used.
        octree_depth: Exponent controlling the per-axis resolution.
        indexing: Passed to ``np.meshgrid``.

    Returns:
        A tuple ``(xyz, grid_size, length)``:
        - ``xyz``: float32 array of shape ``(N, 3)``, the flattened meshgrid.
        - ``grid_size``: the per-axis point counts as a 3-element list.
        - ``length``: ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    points_per_axis = int(np.exp2(octree_depth)) + 1

    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], points_per_axis, dtype=np.float32)
        for axis in range(3)
    ]
    grids = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack(grids, axis=-1).reshape(-1, 3)

    grid_size = [points_per_axis] * 3
    return xyz, grid_size, length